diff --git a/docker-compose.yaml b/docker-compose.yaml
index 25dac7fc..a1aac726 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -59,6 +59,29 @@ services:
           capabilities: [gpu]
     runtime: nvidia

+  yolox-table-structure:
+    image: ${YOLOX_TABLE_STRUCTURE_IMAGE:-set-image-name-in-dot-env}:${YOLOX_TABLE_STRUCTURE_TAG:-set-image-tag-in-dot-env}
+    ports:
+      - "8006:8000"
+      - "8007:8001"
+      - "8008:8002"
+    user: root
+    environment:
+      - NIM_HTTP_API_PORT=8000
+      - NIM_TRITON_LOG_VERBOSE=1
+      - NGC_API_KEY=${STAGING_NIM_NGC_API_KEY}
+      - CUDA_VISIBLE_DEVICES=0
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              device_ids: ["1"]
+              capabilities: [gpu]
+    runtime: nvidia
+    profiles:
+      - yolox-table-structure
+
   paddle:
     image: ${PADDLE_IMAGE:-nvcr.io/nvidia/nemo-microservices/paddleocr}:${PADDLE_TAG:-1.0.0}
     shm_size: 2gb
@@ -155,6 +178,9 @@ services:
      - YOLOX_GRAPHIC_ELEMENTS_GRPC_ENDPOINT=yolox-graphic-elements:8001
      - YOLOX_GRAPHIC_ELEMENTS_HTTP_ENDPOINT=http://yolox-graphic-elements:8000/v1/infer
      - YOLOX_GRAPHIC_ELEMENTS_INFER_PROTOCOL=grpc
+      - YOLOX_TABLE_STRUCTURE_GRPC_ENDPOINT=yolox-table-structure:8001
+      - YOLOX_TABLE_STRUCTURE_HTTP_ENDPOINT=http://yolox-table-structure:8000/v1/infer
+      - YOLOX_TABLE_STRUCTURE_INFER_PROTOCOL=grpc
      - VLM_CAPTION_ENDPOINT=https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct/chat/completions
      - VLM_CAPTION_MODEL_NAME=meta/llama-3.2-11b-vision-instruct
    healthcheck:
diff --git a/src/nv_ingest/schemas/table_extractor_schema.py b/src/nv_ingest/schemas/table_extractor_schema.py
index 6c97665f..6e9b5875 100644
--- a/src/nv_ingest/schemas/table_extractor_schema.py
+++ b/src/nv_ingest/schemas/table_extractor_schema.py
@@ -44,6 +44,9 @@ class TableExtractorConfigSchema(BaseModel):

     auth_token: Optional[str] = None

+    yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+    yolox_infer_protocol: str = ""
+
     paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
     paddle_infer_protocol: str = ""

@@ -78,14 +81,15 @@ def clean_service(service):
                 return None
             return service

-        grpc_service, http_service = values.get("paddle_endpoints", (None, None))
-        grpc_service = clean_service(grpc_service)
-        http_service = clean_service(http_service)
+        for endpoint_name in ["yolox_endpoints", "paddle_endpoints"]:
+            grpc_service, http_service = values.get(endpoint_name, (None, None))
+            grpc_service = clean_service(grpc_service)
+            http_service = clean_service(http_service)

-        if not grpc_service and not http_service:
-            raise ValueError("Both gRPC and HTTP services cannot be empty for paddle_endpoints.")
+            if not grpc_service and not http_service:
+                raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.")

-        values["paddle_endpoints"] = (grpc_service, http_service)
+            values[endpoint_name] = (grpc_service, http_service)

         return values
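For context on the schema change, a minimal usage sketch (endpoint values are illustrative, not defaults): the root validator now applies the same cleaning to both endpoint pairs, and rejects a pair only when both its gRPC and HTTP entries are empty.

```python
from nv_ingest.schemas.table_extractor_schema import TableExtractorConfigSchema

# Illustrative endpoints only; one non-empty service per pair satisfies the validator.
config = TableExtractorConfigSchema(
    yolox_endpoints=("yolox-table-structure:8001", "http://yolox-table-structure:8000/v1/infer"),
    yolox_infer_protocol="grpc",
    paddle_endpoints=(None, "http://paddle:8000/v1/infer"),
    paddle_infer_protocol="http",
)

# An all-empty pair now fails with the offending field named in the message:
#   ValueError: Both gRPC and HTTP services cannot be empty for yolox_endpoints.
```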
diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py
index da3f0e4c..10771714 100644
--- a/src/nv_ingest/stages/nim/chart_extraction.py
+++ b/src/nv_ingest/stages/nim/chart_extraction.py
@@ -17,7 +17,7 @@
 from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema
 from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
-from nv_ingest.util.image_processing.table_and_chart import join_yolox_and_paddle_output
+from nv_ingest.util.image_processing.table_and_chart import join_yolox_graphic_elements_and_paddle_output
 from nv_ingest.util.image_processing.table_and_chart import process_yolox_graphic_elements
 from nv_ingest.util.image_processing.transforms import base64_to_numpy
 from nv_ingest.util.nim.helpers import NimClient
@@ -116,7 +116,7 @@ def _update_metadata(
     # Join the corresponding results from both services for each image.
     for idx, (yolox_res, paddle_res) in enumerate(zip(yolox_results, paddle_results)):
         bounding_boxes, text_predictions = paddle_res
-        yolox_elements = join_yolox_and_paddle_output(yolox_res, bounding_boxes, text_predictions)
+        yolox_elements = join_yolox_graphic_elements_and_paddle_output(yolox_res, bounding_boxes, text_predictions)
         chart_content = process_yolox_graphic_elements(yolox_elements)
         original_index = valid_indices[idx]
         results[original_index] = (base64_images[original_index], chart_content)
diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py
index 5077e832..c914ab5a 100644
--- a/src/nv_ingest/stages/nim/table_extraction.py
+++ b/src/nv_ingest/stages/nim/table_extraction.py
@@ -4,24 +4,28 @@

 import functools
 import logging
+import traceback
+from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple

+import numpy as np
 import pandas as pd
 from morpheus.config import Config

 from nv_ingest.schemas.metadata_schema import TableFormatEnum
 from nv_ingest.schemas.table_extractor_schema import TableExtractorSchema
 from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
-from nv_ingest.util.image_processing.table_and_chart import convert_paddle_response_to_psuedo_markdown
 from nv_ingest.util.image_processing.transforms import base64_to_numpy
 from nv_ingest.util.nim.helpers import NimClient
 from nv_ingest.util.nim.helpers import create_inference_client
-from nv_ingest.util.nim.helpers import get_version
 from nv_ingest.util.nim.paddle import PaddleOCRModelInterface
+from nv_ingest.util.nim.yolox import YoloxTableStructureModelInterface
+from nv_ingest.util.image_processing.table_and_chart import join_yolox_table_structure_and_paddle_output
+from nv_ingest.util.image_processing.table_and_chart import convert_paddle_response_to_psuedo_markdown

 logger = logging.getLogger(__name__)
@@ -31,8 +35,10 @@
 def _update_metadata(
     base64_images: List[str],
+    yolox_client: NimClient,
     paddle_client: NimClient,
     worker_pool_size: int = 8,  # Not currently used
+    enable_yolox: bool = False,
     trace_info: Dict = None,
-) -> List[Tuple[str, Tuple[Any, Any]]]:
+) -> List[Tuple[str, Any, Any, Any]]:
     """
@@ -48,79 +54,121 @@
     logger.debug(f"Running table extraction using protocol {paddle_client.protocol}")

     # Initialize the results list in the same order as base64_images.
-    results: List[Optional[Tuple[str, Tuple[Any, Any]]]] = [None] * len(base64_images)
+    results: List[Tuple[str, Any, Any, Any]] = [("", None, None, None)] * len(base64_images)

     valid_images: List[str] = []
     valid_indices: List[int] = []
+    valid_arrays: List[np.ndarray] = []

-    _ = worker_pool_size
     # Pre-decode image dimensions and filter valid images.
     for i, img in enumerate(base64_images):
         array = base64_to_numpy(img)
         height, width = array.shape[0], array.shape[1]
         if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT:
             valid_images.append(img)
+            valid_arrays.append(array)
             valid_indices.append(i)
         else:
             # Image is too small; mark as skipped.
-            results[i] = (img, (None, None))
+            results[i] = ("", None, None, None)

-    if valid_images:
-        data = {"base64_images": valid_images}
-        try:
-            # Call infer once for all valid images. The NimClient will handle batching internally.
-            paddle_result = paddle_client.infer(
-                data=data,
-                model_name="paddle",
-                stage_name="table_data_extraction",
-                max_batch_size=1 if paddle_client.protocol == "grpc" else 2,
-                trace_info=trace_info,
-            )
-
-            if not isinstance(paddle_result, list):
-                raise ValueError(f"Expected a list of tuples, got {type(paddle_result)}")
-            if len(paddle_result) != len(valid_images):
-                raise ValueError(f"Expected {len(valid_images)} results, got {len(paddle_result)}")
-
-            # Assign each result back to its original position.
-            for idx, result in enumerate(paddle_result):
-                original_index = valid_indices[idx]
-                results[original_index] = (base64_images[original_index], result)
-
-        except Exception as e:
-            logger.error(f"Error processing images. Error: {e}", exc_info=True)
-            for i in valid_indices:
-                results[i] = (base64_images[i], (None, None))
-            raise
-
-    return results
+    if not valid_images:
+        return results
+
+    # Prepare data payloads for both clients.
+    if enable_yolox:
+        data_yolox = {"images": valid_arrays}
+    data_paddle = {"base64_images": valid_images}
+
+    _ = worker_pool_size
+    with ThreadPoolExecutor(max_workers=2) as executor:
+        if enable_yolox:
+            future_yolox = executor.submit(
+                yolox_client.infer,
+                data=data_yolox,
+                model_name="yolox",
+                stage_name="table_data_extraction",
+                max_batch_size=8,
+                trace_info=trace_info,
+            )
+        future_paddle = executor.submit(
+            paddle_client.infer,
+            data=data_paddle,
+            model_name="paddle",
+            stage_name="table_data_extraction",
+            max_batch_size=1 if paddle_client.protocol == "grpc" else 2,
+            trace_info=trace_info,
+        )
+
+        if enable_yolox:
+            try:
+                yolox_results = future_yolox.result()
+            except Exception as e:
+                logger.error(f"Error calling yolox_client.infer: {e}", exc_info=True)
+                raise
+        else:
+            yolox_results = [None] * len(valid_images)
+
+        try:
+            paddle_results = future_paddle.result()
+        except Exception as e:
+            logger.error(f"Error calling paddle_client.infer: {e}", exc_info=True)
+            raise
+
+    # Ensure both clients returned lists of results matching the number of input images.
+    if not isinstance(yolox_results, list) or not isinstance(paddle_results, list):
+        logger.warning(
+            "Unexpected result types from inference clients: yolox_results=%s, paddle_results=%s. "
+            "Proceeding with available results.",
+            type(yolox_results).__name__,
+            type(paddle_results).__name__,
+        )
+        # Assign default values for missing results.
+        if not isinstance(yolox_results, list):
+            yolox_results = [None] * len(valid_arrays)
+        if not isinstance(paddle_results, list):
+            paddle_results = [(None, None)] * len(valid_images)  # Default for paddle output
+
+    if len(yolox_results) != len(valid_arrays):
+        raise ValueError(f"Expected {len(valid_arrays)} yolox results, got {len(yolox_results)}")
+    if len(paddle_results) != len(valid_images):
+        raise ValueError(f"Expected {len(valid_images)} paddle results, got {len(paddle_results)}")
+
+    for idx, (yolox_res, paddle_res) in enumerate(zip(yolox_results, paddle_results)):
+        original_index = valid_indices[idx]
+        results[original_index] = (base64_images[original_index], yolox_res, paddle_res[0], paddle_res[1])
+
+    return results
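A consumer-side sketch of the new result shape (the helper name is hypothetical): each entry is now a flat 4-tuple `(base64_image, yolox_cell_predictions, paddle_bounding_boxes, paddle_text_predictions)`, with `("", None, None, None)` marking images filtered out as too small.

```python
from typing import Any, List, Tuple

def collect_text_predictions(bulk_results: List[Tuple[str, Any, Any, Any]]) -> List[Any]:
    """Gather PaddleOCR text predictions, skipping too-small-image sentinels."""
    texts = []
    for _b64, _cell_preds, _boxes, text_predictions in bulk_results:
        if text_predictions is None:  # ("", None, None, None) sentinel
            continue
        texts.append(text_predictions)
    return texts
```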


-def _create_paddle_client(stage_config) -> NimClient:
-    """
-    Helper to create a NimClient for PaddleOCR, retrieving the paddle version from the endpoint.
-    """
-    # Attempt to obtain PaddleOCR version from the second endpoint
-    paddle_endpoint = stage_config.paddle_endpoints[1]
-    try:
-        paddle_version = get_version(paddle_endpoint)
-        if not paddle_version:
-            logger.warning("Failed to obtain PaddleOCR version from the endpoint. Falling back to the latest version.")
-            paddle_version = None
-    except Exception:
-        logger.warning("Failed to get PaddleOCR version after 30 seconds. Falling back to the latest version.")
-        paddle_version = None
+def _create_clients(
+    yolox_endpoints: Tuple[str, str],
+    yolox_protocol: str,
+    paddle_endpoints: Tuple[str, str],
+    paddle_protocol: str,
+    auth_token: str,
+) -> Tuple[NimClient, NimClient]:
+    yolox_model_interface = YoloxTableStructureModelInterface()
     paddle_model_interface = PaddleOCRModelInterface()

+    logger.debug(f"Inference protocols: yolox={yolox_protocol}, paddle={paddle_protocol}")
+
+    yolox_client = create_inference_client(
+        endpoints=yolox_endpoints,
+        model_interface=yolox_model_interface,
+        auth_token=auth_token,
+        infer_protocol=yolox_protocol,
+    )
+
     paddle_client = create_inference_client(
-        endpoints=stage_config.paddle_endpoints,
+        endpoints=paddle_endpoints,
         model_interface=paddle_model_interface,
-        auth_token=stage_config.auth_token,
-        infer_protocol=stage_config.paddle_infer_protocol,
+        auth_token=auth_token,
+        infer_protocol=paddle_protocol,
     )

-    return paddle_client
+    return yolox_client, paddle_client


 def _extract_table_data(
@@ -157,7 +205,13 @@
         return df, trace_info

     stage_config = validated_config.stage_config
-    paddle_client = _create_paddle_client(stage_config)
+    yolox_client, paddle_client = _create_clients(
+        stage_config.yolox_endpoints,
+        stage_config.yolox_infer_protocol,
+        stage_config.paddle_endpoints,
+        stage_config.paddle_infer_protocol,
+        stage_config.auth_token,
+    )

     try:
         # 1) Identify rows that meet criteria (structured, subtype=table, table_metadata != None, content not empty)
@@ -189,26 +243,34 @@ def meets_criteria(row):
             base64_images.append(meta["content"])

         # 3) Call our bulk _update_metadata to get all results
+        table_content_format = (
+            df.at[valid_indices[0], "metadata"]["table_metadata"].get("table_content_format")
+            or TableFormatEnum.PSEUDO_MARKDOWN
+        )
+        enable_yolox = table_content_format == TableFormatEnum.MARKDOWN
+
         bulk_results = _update_metadata(
             base64_images=base64_images,
+            yolox_client=yolox_client,
             paddle_client=paddle_client,
             worker_pool_size=stage_config.workers_per_progress_engine,
+            enable_yolox=enable_yolox,
             trace_info=trace_info,
         )

         # 4) Write the results (bounding_boxes, text_predictions) back
-        table_content_format = df.at[valid_indices[0], "metadata"]["table_metadata"].get(
-            "table_content_format", TableFormatEnum.PSEUDO_MARKDOWN
-        )
-
         for row_id, idx in enumerate(valid_indices):
-            # unpack (base64_image, (bounding_boxes, text_predictions))
-            _, (bounding_boxes, text_predictions) = bulk_results[row_id]
+            # unpack (base64_image, yolox_cell_predictions, paddle_bounding_boxes, paddle_text_predictions)
+            _, cell_predictions, bounding_boxes, text_predictions = bulk_results[row_id]

             if table_content_format == TableFormatEnum.SIMPLE:
                 table_content = " ".join(text_predictions)
             elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN:
                 table_content = convert_paddle_response_to_psuedo_markdown(bounding_boxes, text_predictions)
+            elif table_content_format == TableFormatEnum.MARKDOWN:
+                table_content = join_yolox_table_structure_and_paddle_output(
+                    cell_predictions, bounding_boxes, text_predictions
+                )
             else:
                 raise ValueError(f"Unexpected table format: {table_content_format}")

@@ -219,8 +281,10 @@ def meets_criteria(row):

     except Exception:
         logger.error("Error occurred while extracting table data.", exc_info=True)
+        traceback.print_exc()
         raise
     finally:
+        yolox_client.close()
         paddle_client.close()
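To illustrate the dispatch above: only `TableFormatEnum.MARKDOWN` turns the yolox table-structure call on; the `SIMPLE` and `PSEUDO_MARKDOWN` paths still use PaddleOCR output alone. A hypothetical row payload that would take the MARKDOWN branch (field names follow the criteria comment in step 1; the enum's string value is assumed):

```python
# Hypothetical metadata for a row selected by meets_criteria.
row_metadata = {
    "content": "<base64-encoded table image>",  # not empty
    "content_metadata": {"type": "structured", "subtype": "table"},
    "table_metadata": {"table_content_format": "markdown"},  # assumed TableFormatEnum.MARKDOWN value
}
```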
diff --git a/src/nv_ingest/util/image_processing/table_and_chart.py b/src/nv_ingest/util/image_processing/table_and_chart.py
index 4e3d6e91..5f0337dd 100644
--- a/src/nv_ingest/util/image_processing/table_and_chart.py
+++ b/src/nv_ingest/util/image_processing/table_and_chart.py
@@ -92,7 +92,7 @@ def match_bboxes(yolox_box, paddle_ocr_boxes, already_matched=None, delta=2.0):
     return matches


-def join_yolox_and_paddle_output(yolox_output, paddle_boxes, paddle_txts):
+def join_yolox_graphic_elements_and_paddle_output(yolox_output, paddle_boxes, paddle_txts):
     """
     Matching boxes
     We need to associate a text to the paddle detections.
@@ -184,3 +184,266 @@ def convert_paddle_response_to_psuedo_markdown(bboxes, texts):
         results += "| " + " | ".join(dfg["text"].values.tolist()) + " |\n"

     return results
+
+
+def join_yolox_table_structure_and_paddle_output(yolox_cell_preds, paddle_ocr_boxes, paddle_ocr_txts):
+    """
+    Reconstruct table text from yolox-table-structure predictions and PaddleOCR output.
+
+    Each OCR box is matched to a detected cell, row, and column; texts sharing a
+    cell are merged, in-grid content is rendered as a markdown table, and text
+    outside the grid is emitted as plain lines in reading order.
+    """
+    if (not paddle_ocr_boxes) or (not paddle_ocr_txts):
+        return ""
+
+    paddle_ocr_boxes = np.array(paddle_ocr_boxes)
+    paddle_ocr_boxes_ = np.array(
+        [
+            paddle_ocr_boxes[:, :, 0].min(-1),
+            paddle_ocr_boxes[:, :, 1].min(-1),
+            paddle_ocr_boxes[:, :, 0].max(-1),
+            paddle_ocr_boxes[:, :, 1].max(-1),
+        ]
+    ).T
+
+    assignments = []
+    for i, (b, t) in enumerate(zip(paddle_ocr_boxes_, paddle_ocr_txts)):
+        # Find a cell
+        matches_cell = assign_boxes(b, yolox_cell_preds["cell"], delta=1)
+        cell = yolox_cell_preds["cell"][matches_cell[0]] if len(matches_cell) else b
+
+        # Find a row
+        matches_row = assign_boxes(cell, yolox_cell_preds["row"], delta=1)
+        row_ids = matches_row if len(matches_row) else -1
+
+        # Find a column - or more than one if it is in the first row
+        if isinstance(row_ids, np.ndarray):
+            delta = 2 if row_ids.min() == 0 else 1  # delta=2 if header column
+        else:
+            delta = 1
+        matches_col = assign_boxes(cell, yolox_cell_preds["column"], delta=delta)
+        col_ids = matches_col if len(matches_col) else -1
+
+        assignments.append(
+            {
+                "index": i,
+                "paddle_box": b,
+                "is_table": isinstance(col_ids, np.ndarray) and isinstance(row_ids, np.ndarray),
+                "cell_id": matches_cell[0] if len(matches_cell) else -1,
+                "cell": cell,
+                "col_ids": col_ids,
+                "row_ids": row_ids,
+                "text": t,
+            }
+        )
+
+    df_assign = pd.DataFrame(assignments)
+
+    # Merge cells with several assigned texts
+    dfs = []
+    for cell_id, df_cell in df_assign.groupby("cell_id"):
+        if len(df_cell) > 1 and cell_id > -1:
+            df_cell = merge_text_in_cell(df_cell)
+        dfs.append(df_cell)
+    df_assign = pd.concat(dfs)
+
+    df_text = df_assign[~df_assign["is_table"]].reset_index(drop=True)
+
+    # Table to text
+    df_table = df_assign[df_assign["is_table"]].reset_index(drop=True)
+    if len(df_table):
+        mat = build_markdown(df_table)
+        markdown_table = display_markdown(mat, use_header=False)
+
+        all_boxes = np.stack(df_table.paddle_box.values)
+        table_box = np.concatenate([all_boxes[:, [0, 1]].min(0), all_boxes[:, [2, 3]].max(0)])
+
+        df_table_to_text = pd.DataFrame(
+            [
+                {
+                    "paddle_box": table_box,
+                    "text": markdown_table,
+                    "is_table": True,
+                }
+            ]
+        )
+        # Final text representations dataframe
+        df_text = pd.concat([df_text, df_table_to_text], ignore_index=True)
+
+    df_text = df_text.rename(columns={"paddle_box": "box"})
+
+    # Sort by y and x
+    df_text["x"] = df_text["box"].apply(lambda x: (x[0] + x[2]) / 2)
+    df_text["y"] = df_text["box"].apply(lambda x: (x[1] + x[3]) / 2)
+    df_text["x"] = (df_text["x"] - df_text["x"].min()) // 10
+    df_text["y"] = (df_text["y"] - df_text["y"].min()) // 20
+    df_text = df_text.sort_values(["y", "x"], ignore_index=True)
+
+    # Loop over lines
+    rows_list = []
+    for r, df_row in df_text.groupby("y"):
+        if df_row["is_table"].values.any():  # Add table
+            table = df_row[df_row["is_table"]]
+            df_row = df_row[~df_row["is_table"]]
+        else:
+            table = None
+
+        if len(df_row) > 1:  # Add text
+            df_row = df_row.reset_index(drop=True)
+            df_row["text"] = "\n".join(df_row["text"].values.tolist())
+
+        rows_list.append(df_row.head(1))
+
+        if table is not None:
+            rows_list.append(table)
+
+    df_display = pd.concat(rows_list, ignore_index=True)
+    result = "\n".join(df_display.text.values.tolist())
+
+    return result
+
+
+def assign_boxes(paddle_box, boxes, delta=2.0, min_overlap=0.25):
+    """
+    Assigns the closest bounding boxes to a reference `paddle_box` based on overlap.
+
+    Args:
+        paddle_box (list or numpy.ndarray): Reference bounding box [x_min, y_min, x_max, y_max].
+        boxes (numpy.ndarray): Array of candidate bounding boxes with shape (N, 4).
+        delta (float, optional): Factor for matches relative to the best overlap. Defaults to 2.0.
+        min_overlap (float, optional): Minimum required overlap for a match. Defaults to 0.25.
+
+    Returns:
+        list or numpy.ndarray: Indices of the matched boxes sorted by decreasing overlap.
+            Returns an empty list if no matches are found.
+    """
+    if not len(boxes):
+        return []
+
+    boxes = np.array(boxes)
+
+    x0_1, y0_1, x1_1, y1_1 = paddle_box
+    x0_2, y0_2, x1_2, y1_2 = (
+        boxes[:, 0],
+        boxes[:, 1],
+        boxes[:, 2],
+        boxes[:, 3],
+    )
+
+    # Intersection
+    inter_y0 = np.maximum(y0_1, y0_2)
+    inter_y1 = np.minimum(y1_1, y1_2)
+    inter_x0 = np.maximum(x0_1, x0_2)
+    inter_x1 = np.minimum(x1_1, x1_2)
+    inter_area = np.maximum(0, inter_y1 - inter_y0) * np.maximum(0, inter_x1 - inter_x0)
+
+    # Normalize by paddle_box size
+    area_1 = (y1_1 - y0_1) * (x1_1 - x0_1)
+    ious = inter_area / (area_1 + 1e-6)
+
+    max_iou = np.max(ious)
+    if max_iou <= min_overlap:  # No match
+        return []
+
+    n = len(np.where(ious >= (max_iou / delta))[0])
+    matches = np.argsort(-ious)[:n]
+    return matches
+
+
+def build_markdown(df):
+    """
+    Convert a dataframe into a markdown table.
+
+    Args:
+        df (pandas DataFrame): The dataframe to convert.
+
+    Returns:
+        list[list]: A list of lists representing the markdown table.
+    """
+    df = df.reset_index(drop=True)
+    n_cols = max([np.max(c) for c in df["col_ids"].values])
+    n_rows = max([np.max(c) for c in df["row_ids"].values])
+
+    mat = np.empty((n_rows + 1, n_cols + 1), dtype=str).tolist()
+
+    for i in range(len(df)):
+        if isinstance(df["row_ids"][i], int) or isinstance(df["col_ids"][i], int):
+            continue
+        for r in df["row_ids"][i]:
+            for c in df["col_ids"][i]:
+                mat[r][c] = (mat[r][c] + " " + df["text"][i]).strip()
+
+    # Remove empty rows & columns
+    mat = remove_empty_row(mat)
+    mat = np.array(remove_empty_row(np.array(mat).T.tolist())).T.tolist()
+
+    return mat
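A small, self-contained check of `assign_boxes` as defined above (coordinates invented): overlap is normalized by the reference box's own area, and `delta=1` keeps only candidates tied with the best match.

```python
import numpy as np
from nv_ingest.util.image_processing.table_and_chart import assign_boxes

paddle_box = [10, 10, 50, 30]  # x_min, y_min, x_max, y_max
cells = np.array([[8, 8, 52, 32], [100, 10, 140, 30]])  # second cell is disjoint

matches = assign_boxes(paddle_box, cells, delta=1, min_overlap=0.25)
print(matches)  # [0] -- only the enclosing first cell clears min_overlap
```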
+
+
+def merge_text_in_cell(df_cell):
+    """
+    Merges text from multiple rows into a single cell and recalculates its bounding box.
+    Values are sorted by rounded (y, x) coordinates.
+
+    Args:
+        df_cell (pandas.DataFrame): DataFrame containing the texts assigned to one cell.
+
+    Returns:
+        pandas.DataFrame: Updated DataFrame with merged text and a single bounding box.
+    """
+    paddle_boxes = np.stack(df_cell["paddle_box"].values)
+
+    df_cell["x"] = (paddle_boxes[:, 0] - paddle_boxes[:, 0].min()) // 10
+    df_cell["y"] = (paddle_boxes[:, 1] - paddle_boxes[:, 1].min()) // 10
+    df_cell = df_cell.sort_values(["y", "x"])
+
+    text = " ".join(df_cell["text"].values.tolist())
+    df_cell["text"] = text
+    df_cell = df_cell.head(1)
+    df_cell["paddle_box"] = df_cell["cell"]
+    df_cell.drop(["x", "y"], axis=1, inplace=True)
+
+    return df_cell
+
+
+def remove_empty_row(mat):
+    """
+    Remove empty rows from a matrix.
+
+    Args:
+        mat (list[list]): The matrix to remove empty rows from.
+
+    Returns:
+        list[list]: The matrix with empty rows removed.
+    """
+    mat_filter = []
+    for row in mat:
+        if max([len(c) for c in row]):
+            mat_filter.append(row)
+    return mat_filter
+
+
+def display_markdown(
+    data: list[list[str]],
+    use_header: bool = False,
+) -> str:
+    """
+    Convert a list of lists of strings into a markdown table.
+
+    Parameters:
+        data (list[list[str]]): The table data. The first sublist should contain headers.
+        use_header (bool, optional): Whether to use the first sublist as headers. Defaults to False.
+
+    Returns:
+        str: A markdown-formatted table as a string.
+    """
+    if not len(data):
+        return "EMPTY TABLE"
+
+    max_cols = max(len(row) for row in data)
+    data = [row + [""] * (max_cols - len(row)) for row in data]
+
+    if use_header:
+        header = "| " + " | ".join(data[0]) + " |"
+        separator = "| " + " | ".join(["---"] * max_cols) + " |"
+        body = "\n".join("| " + " | ".join(row) + " |" for row in data[1:])
+        markdown_table = f"{header}\n{separator}\n{body}" if body else f"{header}\n{separator}"
+    else:
+        markdown_table = "\n".join("| " + " | ".join(row) + " |" for row in data)
+
+    return markdown_table
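A quick sanity check of `display_markdown` (inputs invented):

```python
from nv_ingest.util.image_processing.table_and_chart import display_markdown

rows = [["Name", "Qty"], ["Bolt", "4"]]
print(display_markdown(rows, use_header=True))
# | Name | Qty |
# | --- | --- |
# | Bolt | 4 |
```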
diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py
index 47df3569..420c6a2b 100644
--- a/src/nv_ingest/util/nim/yolox.py
+++ b/src/nv_ingest/util/nim/yolox.py
@@ -67,6 +67,26 @@
 ]


+# yolox-table-structure-v1 constants
+YOLOX_TABLE_NUM_CLASSES = 5
+YOLOX_TABLE_CONF_THRESHOLD = 0.01
+YOLOX_TABLE_IOU_THRESHOLD = 0.25
+YOLOX_TABLE_MIN_SCORE = 0.1
+YOLOX_TABLE_FINAL_SCORE = 0.0
+YOLOX_TABLE_NIM_MAX_IMAGE_SIZE = 512_000
+
+YOLOX_TABLE_IMAGE_PREPROC_HEIGHT = 1024
+YOLOX_TABLE_IMAGE_PREPROC_WIDTH = 1024
+
+YOLOX_TABLE_CLASS_LABELS = [
+    "border",
+    "cell",
+    "row",
+    "column",
+    "header",
+]
+
+
 # YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements
 class YoloxModelInterfaceBase(ModelInterface):
     """
@@ -471,6 +491,65 @@
         return inference_results


+class YoloxTableStructureModelInterface(YoloxModelInterfaceBase):
+    """
+    An interface for handling inference with the yolox-table-structure model, supporting both gRPC and HTTP protocols.
+    """
+
+    def __init__(self):
+        """
+        Initialize the yolox-table-structure model interface.
+        """
+        super().__init__(
+            image_preproc_width=YOLOX_TABLE_IMAGE_PREPROC_WIDTH,
+            image_preproc_height=YOLOX_TABLE_IMAGE_PREPROC_HEIGHT,
+            nim_max_image_size=YOLOX_TABLE_NIM_MAX_IMAGE_SIZE,
+            num_classes=YOLOX_TABLE_NUM_CLASSES,
+            conf_threshold=YOLOX_TABLE_CONF_THRESHOLD,
+            iou_threshold=YOLOX_TABLE_IOU_THRESHOLD,
+            min_score=YOLOX_TABLE_MIN_SCORE,
+            final_score=YOLOX_TABLE_FINAL_SCORE,
+            class_labels=YOLOX_TABLE_CLASS_LABELS,
+        )
+
+    def name(
+        self,
+    ) -> str:
+        """
+        Returns the name of the Yolox model interface.
+
+        Returns
+        -------
+        str
+            The name of the model interface.
+        """
+
+        return "yolox-table-structure"
+
+    def postprocess_annotations(self, annotation_dicts, **kwargs):
+        original_image_shapes = kwargs.get("original_image_shapes", [])
+
+        annotation_dicts = self.transform_normalized_coordinates_to_original(annotation_dicts, original_image_shapes)
+
+        inference_results = []
+
+        # bbox extraction: additional postprocessing specific to nv-ingest
+        for pred, shape in zip(annotation_dicts, original_image_shapes):
+            bbox_dict = get_bbox_dict_yolox_table(
+                pred,
+                shape,
+                self.class_labels,
+                self.min_score,
+            )
+            # convert numpy arrays to list
+            bbox_dict = {
+                label: array.tolist() if isinstance(array, np.ndarray) else array for label, array in bbox_dict.items()
+            }
+            inference_results.append(bbox_dict)
+
+        return inference_results
+
+
 def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
     # Convert numpy array to torch tensor
     prediction = torch.from_numpy(prediction.copy())
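A sketch of how the new interface slots into the existing client factory, mirroring `_create_clients` above (endpoint values illustrative; see the docker-compose service at the top of this diff):

```python
from nv_ingest.util.nim.helpers import create_inference_client
from nv_ingest.util.nim.yolox import YoloxTableStructureModelInterface

interface = YoloxTableStructureModelInterface()
assert interface.name() == "yolox-table-structure"

# Illustrative endpoints; only "cell", "row", and "column" detections are
# consumed downstream (see get_bbox_dict_yolox_table below).
client = create_inference_client(
    endpoints=("yolox-table-structure:8001", "http://yolox-table-structure:8000/v1/infer"),
    model_interface=interface,
    auth_token=None,
    infer_protocol="grpc",
)
```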
+ """ + bbox_dict = {label: np.array([]) for label in class_labels} + + for i, label in enumerate(class_labels): + if label not in ["cell", "row", "column"]: + continue # Ignore useless classes + + bboxes_class = np.array(preds[label]) + + if bboxes_class.size == 0: + continue + + # Threshold and clip + bboxes_class = bboxes_class[bboxes_class[:, -1] >= threshold][:, :4].astype(int) + bboxes_class[:, [0, 2]] = np.clip(bboxes_class[:, [0, 2]], 0, shape[1]) + bboxes_class[:, [1, 3]] = np.clip(bboxes_class[:, [1, 3]], 0, shape[0]) + + # Reorder + sort = ["x0", "y0"] if label != "row" else ["y0", "x0"] + df = pd.DataFrame( + { + "y0": (bboxes_class[:, 1] + bboxes_class[:, 3]) / 2, + "x0": (bboxes_class[:, 0] + bboxes_class[:, 2]) / 2, + } + ) + idxs = df.sort_values(sort).index + bboxes_class = bboxes_class[idxs] + + bbox_dict[label] = bboxes_class + + # Enforce spanning the entire table + if len(bbox_dict["row"]): + bbox_dict["row"][:, 0] = 0 + bbox_dict["row"][:, 2] = shape[1] + if len(bbox_dict["column"]): + bbox_dict["column"][:, 1] = 0 + bbox_dict["column"][:, 3] = shape[0] + + # Shift back if cropped + for k in bbox_dict: + if len(bbox_dict[k]): + bbox_dict[k][:, [1, 3]] = np.add(bbox_dict[k][:, [1, 3]], delta, casting="unsafe") + + return bbox_dict diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py index eefce6b2..df046565 100644 --- a/src/nv_ingest/util/pipeline/stage_builders.py +++ b/src/nv_ingest/util/pipeline/stage_builders.py @@ -204,12 +204,14 @@ def add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, defau def add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - _, _, yolox_auth, _ = get_table_detection_service("yolox") + yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox_table_structure") paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle") table_content_extractor_config = ingest_config.get( "table_content_extraction_module", { "stage_config": { + "yolox_endpoints": (yolox_grpc, yolox_http), + "yolox_infer_protocol": yolox_protocol, "paddle_endpoints": (paddle_grpc, paddle_http), "paddle_infer_protocol": paddle_protocol, "auth_token": yolox_auth,