diff --git a/README.md b/README.md index 5ae93af9..d16f317c 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,10 @@ All rights reserved. SPDX-License-Identifier: Apache-2.0 --> +> [!Note] +> Cached and Deplot are deprecated, docker-compose now points to a beta version of the yolox-graphic-elements container instead. That model and container is slated for full release in March. +> With this change, you should now be able to run on a single 80GB A100 or H100 GPU. +> If you want to continue using the old pipeline with Cached and Deplot, please use the [24.12.1 release](https://github.com/NVIDIA/nv-ingest/tree/24.12.1). ## NVIDIA-Ingest: Multi-modal data extraction @@ -44,8 +48,8 @@ A service that: | GPU | Family | Memory | # of GPUs (min.) | | ------ | ------ | ------ | ------ | -| H100 | SXM or PCIe | 80GB | 2 | -| A100 | SXM or PCIe | 80GB | 2 | +| H100 | SXM or PCIe | 80GB | 1 | +| A100 | SXM or PCIe | 80GB | 1 | ### Software diff --git a/client/src/nv_ingest_client/client/client.py b/client/src/nv_ingest_client/client/client.py index 1f2e3bdd..74884d7c 100644 --- a/client/src/nv_ingest_client/client/client.py +++ b/client/src/nv_ingest_client/client/client.py @@ -93,10 +93,13 @@ def __init__( self._message_client_hostname = message_client_hostname or "localhost" self._message_client_port = message_client_port or 7670 self._message_counter_id = msg_counter_id or "nv-ingest-message-id" + self._message_client_kwargs = message_client_kwargs or {} logger.debug("Instantiate NvIngestClient:\n%s", str(self)) self._message_client = message_client_allocator( - host=self._message_client_hostname, port=self._message_client_port + host=self._message_client_hostname, + port=self._message_client_port, + **self._message_client_kwargs, ) # Initialize the worker pool with the specified size diff --git a/client/src/nv_ingest_client/message_clients/rest/rest_client.py b/client/src/nv_ingest_client/message_clients/rest/rest_client.py index 11598cf7..e47bccef 100644 --- a/client/src/nv_ingest_client/message_clients/rest/rest_client.py +++ b/client/src/nv_ingest_client/message_clients/rest/rest_client.py @@ -230,7 +230,7 @@ def fetch_message(self, job_id: str, timeout: float = 10) -> ResponseSchema: except RuntimeError as rte: raise rte - except requests.HTTPError as err: + except (ConnectionError, requests.HTTPError, requests.exceptions.ConnectionError) as err: logger.error(f"Error during fetching, retrying... Error: {err}") self._client = None # Invalidate client to force reconnection try: diff --git a/docker-compose.yaml b/docker-compose.yaml index e265ebb2..7b02948c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -34,12 +34,12 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] runtime: nvidia - deplot: - image: ${DEPLOT_IMAGE:-nvcr.io/nvidia/nemo-microservices/deplot}:${DEPLOT_TAG:-1.0.0} + yolox-graphic-elements: + image: ${YOLOX_GRAPHIC_ELEMENTS_IMAGE:-nvcr.io/nvidia/nemo-microservices/nemoretriever-graphic-elements-v1}:${YOLOX_GRAPHIC_ELEMENTS_TAG:-1.1} ports: - "8003:8000" - "8004:8001" @@ -59,28 +59,6 @@ services: capabilities: [gpu] runtime: nvidia - cached: - image: ${CACHED_IMAGE:-nvcr.io/nvidia/nemo-microservices/cached}:${CACHED_TAG:-0.2.1} - shm_size: 2gb - ports: - - "8006:8000" - - "8007:8001" - - "8008:8002" - user: root - environment: - - NIM_HTTP_API_PORT=8000 - - NIM_TRITON_LOG_VERBOSE=1 - - NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}} - - CUDA_VISIBLE_DEVICES=0 - deploy: - resources: - reservations: - devices: - - driver: nvidia - device_ids: ["1"] - capabilities: [gpu] - runtime: nvidia - paddle: image: ${PADDLE_IMAGE:-nvcr.io/nvidia/nemo-microservices/paddleocr}:${PADDLE_TAG:-1.0.0} shm_size: 2gb @@ -99,13 +77,13 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] runtime: nvidia embedding: # NIM ON - image: ${EMBEDDING_IMAGE:-nvcr.io/nim/nvidia/nv-embedqa-e5-v5}:${EMBEDDING_TAG:-1.1.0} + image: ${EMBEDDING_IMAGE:-nvcr.io/nim/nvidia/llama-3.2-nv-embedqa-1b-v2}:${EMBEDDING_TAG:-1.3.0} shm_size: 16gb ports: - "8012:8000" @@ -121,7 +99,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] runtime: nvidia @@ -141,22 +119,9 @@ services: cap_add: - sys_nice environment: - # Self-hosted cached endpoints. - - CACHED_GRPC_ENDPOINT=cached:8001 - - CACHED_HTTP_ENDPOINT=http://cached:8000/v1/infer - - CACHED_INFER_PROTOCOL=grpc - # build.nvidia.com hosted cached endpoints. - #- CACHED_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/cv/university-at-buffalo/cached - #- CACHED_INFER_PROTOCOL=http - CUDA_VISIBLE_DEVICES=0 - #- DEPLOT_GRPC_ENDPOINT="" - # Self-hosted deplot endpoints. - - DEPLOT_HTTP_ENDPOINT=http://deplot:8000/v1/chat/completions - # build.nvidia.com hosted deplot - #- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot - - DEPLOT_INFER_PROTOCOL=http - DOUGHNUT_GRPC_TRITON=triton-doughnut:8001 - - EMBEDDING_NIM_MODEL_NAME=${EMBEDDING_NIM_MODEL_NAME:-nvidia/nv-embedqa-e5-v5} + - EMBEDDING_NIM_MODEL_NAME=${EMBEDDING_NIM_MODEL_NAME:-nvidia/llama-3.2-nv-embedqa-1b-v2} - INGEST_LOG_LEVEL=DEFAULT # Message client for development #- MESSAGE_CLIENT_HOST=0.0.0.0 @@ -187,6 +152,9 @@ services: # build.nvidia.com hosted yolox endpoints. #- YOLOX_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/cv/nvidia/nv-yolox-page-elements-v1 #- YOLOX_INFER_PROTOCOL=http + - YOLOX_GRAPHIC_ELEMENTS_GRPC_ENDPOINT=yolox-graphic-elements:8001 + - YOLOX_GRAPHIC_ELEMENTS_HTTP_ENDPOINT=http://yolox-graphic-elements:8000/v1/infer + - YOLOX_GRAPHIC_ELEMENTS_INFER_PROTOCOL=grpc - VLM_CAPTION_ENDPOINT=https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions - VLM_CAPTION_MODEL_NAME=meta/llama-3.2-90b-vision-instruct healthcheck: @@ -199,7 +167,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] otel-collector: @@ -321,7 +289,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] depends_on: - "etcd" diff --git a/src/nv_ingest/api/v1/health.py b/src/nv_ingest/api/v1/health.py index df813b60..075132e3 100644 --- a/src/nv_ingest/api/v1/health.py +++ b/src/nv_ingest/api/v1/health.py @@ -64,20 +64,26 @@ async def get_ready_state() -> dict: # We give the users an option to disable checking all distributed services for "readiness" check_all_components = os.getenv("READY_CHECK_ALL_COMPONENTS", "True").lower() if check_all_components in ["1", "true", "yes"]: - yolox_ready = is_ready(os.getenv("YOLOX_HTTP_ENDPOINT", None), "/v1/health/ready") - deplot_ready = is_ready(os.getenv("DEPLOT_HTTP_ENDPOINT", None), "/v1/health/ready") - cached_ready = is_ready(os.getenv("CACHED_HTTP_ENDPOINT", None), "/v1/health/ready") + yolox_page_elements_ready = is_ready(os.getenv("YOLOX_HTTP_ENDPOINT", None), "/v1/health/ready") + yolox_graphic_elements_ready = is_ready( + os.getenv("YOLOX_GRAPHIC_ELEMENTS_HTTP_ENDPOINT", None), "/v1/health/ready" + ) paddle_ready = is_ready(os.getenv("PADDLE_HTTP_ENDPOINT", None), "/v1/health/ready") - if ingest_ready and morpheus_pipeline_ready and yolox_ready and deplot_ready and cached_ready and paddle_ready: + if ( + ingest_ready + and morpheus_pipeline_ready + and yolox_page_elements_ready + and yolox_graphic_elements_ready + and paddle_ready + ): return JSONResponse(content={"ready": True}, status_code=200) else: ready_statuses = { "ingest_ready": ingest_ready, "morpheus_pipeline_ready": morpheus_pipeline_ready, - "yolox_ready": yolox_ready, - "deplot_ready": deplot_ready, - "cached_ready": cached_ready, + "yolox_page_elemenst_ready": yolox_page_elements_ready, + "yolox_graphic_elements_ready": yolox_graphic_elements_ready, "paddle_ready": paddle_ready, } logger.debug(f"Ready Statuses: {ready_statuses}") diff --git a/src/nv_ingest/extraction_workflows/image/image_handlers.py b/src/nv_ingest/extraction_workflows/image/image_handlers.py index 9accecf3..315ce2d0 100644 --- a/src/nv_ingest/extraction_workflows/image/image_handlers.py +++ b/src/nv_ingest/extraction_workflows/image/image_handlers.py @@ -158,7 +158,7 @@ def extract_table_and_chart_images( objects = annotation_dict[label] for idx, bboxes in enumerate(objects): *bbox, _ = bboxes - h1, w1, h2, w2 = np.array(bbox) * np.array([height, width, height, width]) + h1, w1, h2, w2 = bbox base64_img = crop_image(original_image, (int(h1), int(w1), int(h2), int(w2))) diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py index 9ca0edf7..b1608ddc 100644 --- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py @@ -173,15 +173,15 @@ def extract_table_and_chart_images( objects = annotation_dict[label] for idx, bboxes in enumerate(objects): *bbox, _ = bboxes - h1, w1, h2, w2 = bbox * np.array([height, width, height, width]) + h1, w1, h2, w2 = bbox - cropped = crop_image(original_image, (h1, w1, h2, w2)) + cropped = crop_image(original_image, (int(h1), int(w1), int(h2), int(w2))) base64_img = numpy_to_base64(cropped) table_data = CroppedImageWithContent( content="", image=base64_img, - bbox=(w1, h1, w2, h2), + bbox=(int(w1), int(h1), int(w2), int(h2)), max_width=width, max_height=height, type_string=label, diff --git a/src/nv_ingest/schemas/chart_extractor_schema.py b/src/nv_ingest/schemas/chart_extractor_schema.py index 7c652e1d..2c56fd7b 100644 --- a/src/nv_ingest/schemas/chart_extractor_schema.py +++ b/src/nv_ingest/schemas/chart_extractor_schema.py @@ -20,12 +20,8 @@ class ChartExtractorConfigSchema(BaseModel): auth_token : Optional[str], default=None Authentication token required for secure services. - cached_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) - A tuple containing the gRPC and HTTP services for the cached endpoint. - Either the gRPC or HTTP service can be empty, but not both. - - deplot_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) - A tuple containing the gRPC and HTTP services for the deplot endpoint. + yolox_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) + A tuple containing the gRPC and HTTP services for the yolox endpoint. Either the gRPC or HTTP service can be empty, but not both. paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) @@ -50,13 +46,9 @@ class ChartExtractorConfigSchema(BaseModel): auth_token: Optional[str] = None - cached_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) - cached_infer_protocol: str = "" - - deplot_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) - deplot_infer_protocol: str = "" + yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + yolox_infer_protocol: str = "" - ## NOTE: Paddle isn't currently called independently of the cached NIM, but will be in the future. paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) paddle_infer_protocol: str = "" @@ -94,7 +86,7 @@ def clean_service(service): return None return service - for endpoint_name in ["cached_endpoints", "deplot_endpoints", "paddle_endpoints"]: + for endpoint_name in ["yolox_endpoints", "paddle_endpoints"]: grpc_service, http_service = values.get(endpoint_name, (None, None)) grpc_service = clean_service(grpc_service) http_service = clean_service(http_service) @@ -125,7 +117,7 @@ class ChartExtractorSchema(BaseModel): A flag indicating whether to raise an exception if a failure occurs during chart extraction. stage_config : Optional[ChartExtractorConfigSchema], default=None - Configuration for the chart extraction stage, including cached, deplot, and paddle service endpoints. + Configuration for the chart extraction stage, including yolox and paddle service endpoints. """ max_queue_size: int = 1 diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index fcdc5985..e744a0c2 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -5,34 +5,41 @@ import functools import logging from concurrent.futures import ThreadPoolExecutor -from typing import Any, List +from typing import Any from typing import Dict +from typing import List from typing import Optional from typing import Tuple +import numpy as np import pandas as pd from morpheus.config import Config from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage -from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output -from nv_ingest.util.nim.cached import CachedModelInterface -from nv_ingest.util.nim.deplot import DeplotModelInterface -from nv_ingest.util.nim.helpers import create_inference_client +from nv_ingest.util.image_processing.table_and_chart import join_yolox_and_paddle_output +from nv_ingest.util.image_processing.table_and_chart import process_yolox_graphic_elements +from nv_ingest.util.image_processing.transforms import base64_to_numpy from nv_ingest.util.nim.helpers import NimClient +from nv_ingest.util.nim.helpers import create_inference_client +from nv_ingest.util.nim.paddle import PaddleOCRModelInterface +from nv_ingest.util.nim.yolox import YoloxGraphicElementsModelInterface logger = logging.getLogger(f"morpheus.{__name__}") +PADDLE_MIN_WIDTH = 32 +PADDLE_MIN_HEIGHT = 32 + def _update_metadata( base64_images: List[str], - cached_client: NimClient, - deplot_client: NimClient, + yolox_client: NimClient, + paddle_client: NimClient, trace_info: Dict, worker_pool_size: int = 8, # Not currently used. ) -> List[Tuple[str, Dict]]: """ - Given a list of base64-encoded chart images, this function calls both the Cached and Deplot + Given a list of base64-encoded chart images, this function calls both the Yolox and Paddle inference services concurrently to extract chart data for all images. For each base64-encoded image, returns: @@ -40,86 +47,99 @@ def _update_metadata( """ logger.debug("Running chart extraction using updated concurrency handling.") + valid_images: List[str] = [] + valid_arrays: List[np.ndarray] = [] + + # Pre-decode image dimensions and filter valid images. + for i, img in enumerate(base64_images): + array = base64_to_numpy(img) + height, width = array.shape[0], array.shape[1] + if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT: + valid_images.append(img) + valid_arrays.append(array) + # Prepare data payloads for both clients. - data_cached = {"base64_images": base64_images} - data_deplot = {"base64_images": base64_images} + data_yolox = {"images": valid_arrays} + data_paddle = {"base64_images": valid_images} _ = worker_pool_size with ThreadPoolExecutor(max_workers=2) as executor: - future_cached = executor.submit( - cached_client.infer, - data=data_cached, - model_name="cached", + future_yolox = executor.submit( + yolox_client.infer, + data=data_yolox, + model_name="yolox", stage_name="chart_data_extraction", - max_batch_size=1 if cached_client.protocol == "grpc" else 2, + max_batch_size=8, trace_info=trace_info, ) - future_deplot = executor.submit( - deplot_client.infer, - data=data_deplot, - model_name="deplot", + future_paddle = executor.submit( + paddle_client.infer, + data=data_paddle, + model_name="paddle", stage_name="chart_data_extraction", - max_batch_size=1, + max_batch_size=1 if paddle_client.protocol == "grpc" else 2, trace_info=trace_info, ) try: - cached_results = future_cached.result() + yolox_results = future_yolox.result() except Exception as e: - logger.error(f"Error calling cached_client.infer: {e}", exc_info=True) + logger.error(f"Error calling yolox_client.infer: {e}", exc_info=True) raise try: - deplot_results = future_deplot.result() + paddle_results = future_paddle.result() except Exception as e: - logger.error(f"Error calling deplot_client.infer: {e}", exc_info=True) + logger.error(f"Error calling yolox_client.infer: {e}", exc_info=True) raise # Ensure both clients returned lists of results matching the number of input images. - if not (isinstance(cached_results, list) and isinstance(deplot_results, list)): - raise ValueError("Expected list results from both cached_client and deplot_client infer calls.") + if not (isinstance(yolox_results, list) and isinstance(paddle_results, list)): + raise ValueError("Expected list results from both yolox_client and paddle_client infer calls.") - if len(cached_results) != len(base64_images): - raise ValueError(f"Expected {len(base64_images)} cached results, got {len(cached_results)}") - if len(deplot_results) != len(base64_images): - raise ValueError(f"Expected {len(base64_images)} deplot results, got {len(deplot_results)}") + if len(yolox_results) != len(base64_images): + raise ValueError(f"Expected {len(base64_images)} yolox results, got {len(yolox_results)}") + if len(paddle_results) != len(base64_images): + raise ValueError(f"Expected {len(base64_images)} paddle results, got {len(paddle_results)}") # Join the corresponding results from both services for each image. results = [] - for img_str, cached_res, deplot_res in zip(base64_images, cached_results, deplot_results): - joined_chart_content = join_cached_and_deplot_output(cached_res, deplot_res) - results.append((img_str, joined_chart_content)) + for img_str, yolox_res, paddle_res in zip(base64_images, yolox_results, paddle_results): + bounding_boxes, text_predictions = paddle_res + yolox_elements = join_yolox_and_paddle_output(yolox_res, bounding_boxes, text_predictions) + chart_content = process_yolox_graphic_elements(yolox_elements) + results.append((img_str, chart_content)) return results def _create_clients( - cached_endpoints: Tuple[str, str], - cached_protocol: str, - deplot_endpoints: Tuple[str, str], - deplot_protocol: str, + yolox_endpoints: Tuple[str, str], + yolox_protocol: str, + paddle_endpoints: Tuple[str, str], + paddle_protocol: str, auth_token: str, ) -> Tuple[NimClient, NimClient]: - cached_model_interface = CachedModelInterface() - deplot_model_interface = DeplotModelInterface() + yolox_model_interface = YoloxGraphicElementsModelInterface() + paddle_model_interface = PaddleOCRModelInterface() - logger.debug(f"Inference protocols: cached={cached_protocol}, deplot={deplot_protocol}") + logger.debug(f"Inference protocols: yolox={yolox_protocol}, paddle={paddle_protocol}") - cached_client = create_inference_client( - endpoints=cached_endpoints, - model_interface=cached_model_interface, + yolox_client = create_inference_client( + endpoints=yolox_endpoints, + model_interface=yolox_model_interface, auth_token=auth_token, - infer_protocol=cached_protocol, + infer_protocol=yolox_protocol, ) - deplot_client = create_inference_client( - endpoints=deplot_endpoints, - model_interface=deplot_model_interface, + paddle_client = create_inference_client( + endpoints=paddle_endpoints, + model_interface=paddle_model_interface, auth_token=auth_token, - infer_protocol=deplot_protocol, + infer_protocol=paddle_protocol, ) - return cached_client, deplot_client + return yolox_client, paddle_client def _extract_chart_data( @@ -159,11 +179,11 @@ def _extract_chart_data( return df, trace_info stage_config = validated_config.stage_config - cached_client, deplot_client = _create_clients( - stage_config.cached_endpoints, - stage_config.cached_infer_protocol, - stage_config.deplot_endpoints, - stage_config.deplot_infer_protocol, + yolox_client, paddle_client = _create_clients( + stage_config.yolox_endpoints, + stage_config.yolox_infer_protocol, + stage_config.paddle_endpoints, + stage_config.paddle_infer_protocol, stage_config.auth_token, ) @@ -204,8 +224,8 @@ def meets_criteria(row): # 3) Call our bulk update_metadata to get all results bulk_results = _update_metadata( base64_images=base64_images, - cached_client=cached_client, - deplot_client=deplot_client, + yolox_client=yolox_client, + paddle_client=paddle_client, worker_pool_size=stage_config.workers_per_progress_engine, trace_info=trace_info, ) @@ -214,7 +234,7 @@ def meets_criteria(row): # The order of base64_images in bulk_results should match their original # indices if we process them in the same order. for row_id, idx in enumerate(valid_indices): - (_, chart_content) = bulk_results[row_id] + _, chart_content = bulk_results[row_id] df.at[idx, "metadata"]["table_metadata"]["table_content"] = chart_content return df, {"trace_info": trace_info} @@ -223,8 +243,8 @@ def meets_criteria(row): logger.error("Error occurred while extracting chart data.", exc_info=True) raise finally: - cached_client.close() - deplot_client.close() + yolox_client.close() + paddle_client.close() def generate_chart_extractor_stage( diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index b980e847..000902c1 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -4,16 +4,23 @@ import functools import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple import pandas as pd - from morpheus.config import Config +from nv_ingest.schemas.metadata_schema import TableFormatEnum from nv_ingest.schemas.table_extractor_schema import TableExtractorSchema from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage +from nv_ingest.util.image_processing.table_and_chart import convert_paddle_response_to_psuedo_markdown from nv_ingest.util.image_processing.transforms import base64_to_numpy -from nv_ingest.util.nim.helpers import create_inference_client, NimClient, get_version +from nv_ingest.util.nim.helpers import NimClient +from nv_ingest.util.nim.helpers import create_inference_client +from nv_ingest.util.nim.helpers import get_version from nv_ingest.util.nim.paddle import PaddleOCRModelInterface logger = logging.getLogger(__name__) @@ -33,7 +40,7 @@ def _update_metadata( size requirements and then calls the PaddleOCR model via paddle_client.infer to extract table data. For each base64-encoded image, the result is: - (base64_image, (table_content, table_content_format)) + (base64_image, (text_predictions, bounding_boxes)) Images that do not meet the minimum size are skipped (resulting in ("", "") for that image). The paddle_client is expected to handle any necessary batching and concurrency. @@ -56,7 +63,7 @@ def _update_metadata( valid_indices.append(i) else: # Image is too small; mark as skipped. - results[i] = (img, ("", "")) + results[i] = (img, (None, None)) if valid_images: data = {"base64_images": valid_images} @@ -83,7 +90,7 @@ def _update_metadata( except Exception as e: logger.error(f"Error processing images. Error: {e}", exc_info=True) for i in valid_indices: - results[i] = (base64_images[i], ("", "")) + results[i] = (base64_images[i], (None, None)) raise return results @@ -104,7 +111,7 @@ def _create_paddle_client(stage_config) -> NimClient: logger.warning("Failed to get PaddleOCR version after 30 seconds. Falling back to the latest version.") paddle_version = None - paddle_model_interface = PaddleOCRModelInterface(paddle_version=paddle_version) + paddle_model_interface = PaddleOCRModelInterface() paddle_client = create_inference_client( endpoints=stage_config.paddle_endpoints, @@ -189,10 +196,21 @@ def meets_criteria(row): trace_info=trace_info, ) - # 4) Write the results (table_content, table_content_format) back + # 4) Write the results (bounding_boxes, text_predictions) back + table_content_format = df.at[valid_indices[0], "metadata"]["table_metadata"].get( + "table_content_format", TableFormatEnum.PSEUDO_MARKDOWN + ) + for row_id, idx in enumerate(valid_indices): - # unpack (base64_image, (content, format)) - _, (table_content, table_content_format) = bulk_results[row_id] + # unpack (base64_image, (bounding boxes, text_predictions)) + _, (bounding_boxes, text_predictions) = bulk_results[row_id] + + if table_content_format == TableFormatEnum.SIMPLE: + table_content = " ".join(text_predictions) + elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN: + table_content = convert_paddle_response_to_psuedo_markdown(bounding_boxes, text_predictions) + else: + raise ValueError(f"Unexpected table format: {table_content_format}") df.at[idx, "metadata"]["table_metadata"]["table_content"] = table_content df.at[idx, "metadata"]["table_metadata"]["table_content_format"] = table_content_format diff --git a/src/nv_ingest/util/image_processing/table_and_chart.py b/src/nv_ingest/util/image_processing/table_and_chart.py index 3a016aab..4e3d6e91 100644 --- a/src/nv_ingest/util/image_processing/table_and_chart.py +++ b/src/nv_ingest/util/image_processing/table_and_chart.py @@ -3,77 +3,184 @@ # SPDX-License-Identifier: Apache-2.0 -import json import logging +import re + +import numpy as np +import pandas as pd +from sklearn.cluster import DBSCAN + logger = logging.getLogger(__name__) -def join_cached_and_deplot_output(cached_text, deplot_text): +def process_yolox_graphic_elements(yolox_text_dict): """ - Process the inference results from cached and deplot models. + Process the inference results from yolox-graphic-elements model. Parameters ---------- - cached_text : str - The result from the cached model inference, expected to be a JSON string or plain text. - deplot_text : str - The result from the deplot model inference, expected to be plain text. + yolox_text : str + The result from the yolox model inference. Returns ------- str The concatenated and processed chart content as a string. - - Notes - ----- - This function attempts to parse the `cached_text` as JSON to extract specific fields. - If parsing fails, it falls back to using the raw `cached_text`. The `deplot_text` is then - appended to this content. - - Examples - -------- - >>> cached_text = '{"chart_title": "Sales Over Time"}' - >>> deplot_text = "This chart shows the sales over time." - >>> result = join_cached_and_deplot_output(cached_text, deplot_text) - >>> print(result) - "Sales Over Time This chart shows the sales over time." """ chart_content = "" - if cached_text is not None: - try: - if isinstance(cached_text, str): - cached_text_dict = json.loads(cached_text) - elif isinstance(cached_text, dict): - cached_text_dict = cached_text - else: - cached_text_dict = {} - - chart_content += cached_text_dict.get("chart_title", "") - - if deplot_text is not None: - chart_content += f" {deplot_text}" - - chart_content += " " + cached_text_dict.get("caption", "") - chart_content += " " + cached_text_dict.get("info_deplot", "") - chart_content += " " + cached_text_dict.get("x_title", "") - chart_content += " " + cached_text_dict.get("xlabel", "") - chart_content += " " + cached_text_dict.get("y_title", "") - chart_content += " " + cached_text_dict.get("ylabel", "") - chart_content += " " + cached_text_dict.get("legend_label", "") - chart_content += " " + cached_text_dict.get("legend_title", "") - chart_content += " " + cached_text_dict.get("mark_label", "") - chart_content += " " + cached_text_dict.get("value_label", "") - chart_content += " " + cached_text_dict.get("other", "") - except json.JSONDecodeError: - chart_content += cached_text - - if deplot_text is not None: - chart_content += f" {deplot_text}" - - else: - if deplot_text is not None: - chart_content += f" {deplot_text}" - - return chart_content + chart_content += yolox_text_dict.get("chart_title", "") + + chart_content += " " + yolox_text_dict.get("caption", "") + chart_content += " " + yolox_text_dict.get("x_title", "") + chart_content += " " + yolox_text_dict.get("xlabel", "") + chart_content += " " + yolox_text_dict.get("y_title", "") + chart_content += " " + yolox_text_dict.get("ylabel", "") + chart_content += " " + yolox_text_dict.get("legend_label", "") + chart_content += " " + yolox_text_dict.get("legend_title", "") + chart_content += " " + yolox_text_dict.get("mark_label", "") + chart_content += " " + yolox_text_dict.get("value_label", "") + chart_content += " " + yolox_text_dict.get("other", "") + + return chart_content.strip() + + +def match_bboxes(yolox_box, paddle_ocr_boxes, already_matched=None, delta=2.0): + """ + Associates a yolox-graphic-elements box to PaddleOCR bboxes, by taking overlapping boxes. + Criterion is iou > max_iou / delta where max_iou is the biggest found overlap. + Boxes are expeceted in format (x0, y0, x1, y1) + Args: + yolox_box (np array [4]): Cached Bbox. + paddle_ocr_boxes (np array [n x 4]): PaddleOCR boxes + already_matched (list or None, Optional): Already matched ids to ignore. + delta (float, Optional): IoU delta for considering several boxes. Defaults to 2.. + Returns: + np array or list: Indices of the match bboxes + """ + x0_1, y0_1, x1_1, y1_1 = yolox_box + x0_2, y0_2, x1_2, y1_2 = ( + paddle_ocr_boxes[:, 0], + paddle_ocr_boxes[:, 1], + paddle_ocr_boxes[:, 2], + paddle_ocr_boxes[:, 3], + ) + + # Intersection + inter_y0 = np.maximum(y0_1, y0_2) + inter_y1 = np.minimum(y1_1, y1_2) + inter_x0 = np.maximum(x0_1, x0_2) + inter_x1 = np.minimum(x1_1, x1_2) + inter_area = np.maximum(0, inter_y1 - inter_y0) * np.maximum(0, inter_x1 - inter_x0) + + # Union + area_1 = (y1_1 - y0_1) * (x1_1 - x0_1) + area_2 = (y1_2 - y0_2) * (x1_2 - x0_2) + union_area = area_1 + area_2 - inter_area + + # IoU + ious = inter_area / union_area + + max_iou = np.max(ious) + if max_iou <= 0.01: + return [] + + matches = np.where(ious > (max_iou / delta))[0] + if already_matched is not None: + matches = np.array([m for m in matches if m not in already_matched]) + return matches + + +def join_yolox_and_paddle_output(yolox_output, paddle_boxes, paddle_txts): + """ + Matching boxes + We need to associate a text to the paddle detections. + For each class and for each CACHED detections, we look for overlapping text bboxes + with IoU > max_iou / delta where max_iou is the biggest found overlap. + Found texts are added to the class representation, and removed from the texts to match + """ + KEPT_CLASSES = [ # Used CACHED classes, corresponds to YoloX classes + "chart_title", + "x_title", + "y_title", + "xlabel", + "ylabel", + "other", + "legend_label", + "legend_title", + "mark_label", + "value_label", + ] + + paddle_txts = np.array(paddle_txts) + paddle_boxes = np.array(paddle_boxes) + + if (paddle_txts.size == 0) or (paddle_boxes.size == 0): + return {} + + paddle_boxes = np.array( + [ + paddle_boxes[:, :, 0].min(-1), + paddle_boxes[:, :, 1].min(-1), + paddle_boxes[:, :, 0].max(-1), + paddle_boxes[:, :, 1].max(-1), + ] + ).T + + already_matched = [] + results = {} + + for k in KEPT_CLASSES: + if not len(yolox_output.get(k, [])): # No bounding boxes + continue + + texts = [] + for yolox_box in yolox_output[k]: + # if there's a score at the end, drop the score. + yolox_box = yolox_box[:4] + paddle_ids = match_bboxes(yolox_box, paddle_boxes, already_matched=already_matched, delta=4) + + if len(paddle_ids) > 0: + text = " ".join(paddle_txts[paddle_ids].tolist()) + texts.append(text) + + processed_texts = [] + for t in texts: + t = re.sub(r"\s+", " ", t) + t = re.sub(r"\.+", ".", t) + processed_texts.append(t) + + if "title" in k: + processed_texts = " ".join(processed_texts) + else: + processed_texts = " - ".join(processed_texts) # Space ? + + results[k] = processed_texts + + return results + + +def convert_paddle_response_to_psuedo_markdown(bboxes, texts): + if (not bboxes) or (not texts): + return "" + + bboxes = np.array(bboxes).astype(int) + bboxes = bboxes.reshape(-1, 8)[:, [0, 1, 2, -1]] + + preds_df = pd.DataFrame( + {"x0": bboxes[:, 0], "y0": bboxes[:, 1], "x1": bboxes[:, 2], "y1": bboxes[:, 3], "text": texts} + ) + preds_df = preds_df.sort_values("y0") + + dbscan = DBSCAN(eps=10, min_samples=1) + dbscan.fit(preds_df["y0"].values[:, None]) + + preds_df["cluster"] = dbscan.labels_ + preds_df = preds_df.sort_values(["cluster", "x0"]) + + results = "" + for _, dfg in preds_df.groupby("cluster"): + results += "| " + " | ".join(dfg["text"].values.tolist()) + " |\n" + + return results diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 40dc10ec..e2d03abd 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -16,7 +16,6 @@ import numpy as np import requests import tritonclient.grpc as grpcclient -from packaging import version as pkgversion from nv_ingest.util.image_processing.transforms import normalize_image from nv_ingest.util.image_processing.transforms import pad_image @@ -468,7 +467,7 @@ def create_inference_client( return NimClient(model_interface, infer_protocol, endpoints, auth_token) -def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str] = None) -> np.ndarray: +def preprocess_image_for_paddle(array: np.ndarray, image_max_dimension: int = 960) -> np.ndarray: """ Preprocesses an input image to be suitable for use with PaddleOCR by resizing, normalizing, padding, and transposing it into the required format. @@ -501,11 +500,8 @@ def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str] a requirement for PaddleOCR. - The normalized pixel values are scaled between 0 and 1 before padding and transposing the image. """ - if (not paddle_version) or (pkgversion.parse(paddle_version) < pkgversion.parse("0.2.0-rc1")): - return array - height, width = array.shape[:2] - scale_factor = 960 / max(height, width) + scale_factor = image_max_dimension / max(height, width) new_height = int(height * scale_factor) new_width = int(width * scale_factor) resized = cv2.resize(array, (new_width, new_height)) @@ -515,14 +511,25 @@ def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str] # PaddleOCR NIM (GRPC) requires input shapes to be multiples of 32. new_height = (normalized.shape[0] + 31) // 32 * 32 new_width = (normalized.shape[1] + 31) // 32 * 32 - padded, _ = pad_image( + padded, (pad_width, pad_height) = pad_image( normalized, target_height=new_height, target_width=new_width, background_color=0, dtype=np.float32 ) # PaddleOCR NIM (GRPC) requires input to be (channel, height, width). transposed = padded.transpose((2, 0, 1)) - return transposed + # Metadata can used for inverting transformations on the resulting bounding boxes. + metadata = { + "original_height": height, + "original_width": width, + "scale_factor": scale_factor, + "new_height": transposed.shape[1], + "new_width": transposed.shape[2], + "pad_height": pad_height, + "pad_width": pad_width, + } + + return transposed, metadata def remove_url_endpoints(url) -> str: diff --git a/src/nv_ingest/util/nim/paddle.py b/src/nv_ingest/util/nim/paddle.py index da5c2c6b..98b99efd 100644 --- a/src/nv_ingest/util/nim/paddle.py +++ b/src/nv_ingest/util/nim/paddle.py @@ -5,11 +5,7 @@ from typing import Optional import numpy as np -import pandas as pd -from packaging import version as pkgversion -from sklearn.cluster import DBSCAN -from nv_ingest.schemas.metadata_schema import TableFormatEnum from nv_ingest.util.image_processing.transforms import base64_to_numpy from nv_ingest.util.nim.helpers import ModelInterface from nv_ingest.util.nim.helpers import preprocess_image_for_paddle @@ -22,17 +18,6 @@ class PaddleOCRModelInterface(ModelInterface): An interface for handling inference with a PaddleOCR model, supporting both gRPC and HTTP protocols. """ - def __init__(self, paddle_version: Optional[str] = None) -> None: - """ - Initialize the PaddleOCR model interface. - - Parameters - ---------- - paddle_version : str, optional - The version of the PaddleOCR model (default is None). - """ - self.paddle_version: Optional[str] = paddle_version - def name(self) -> str: """ Get the name of the model interface. @@ -40,9 +25,9 @@ def name(self) -> str: Returns ------- str - The name of the model interface, including the PaddleOCR version. + The name of the model interface. """ - return f"PaddleOCR - {self.paddle_version}" + return "PaddleOCR" def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: """ @@ -76,20 +61,16 @@ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.") image_arrays: List[np.ndarray] = [] - dims: List[Tuple[int, int]] = [] for b64 in base64_list: img = base64_to_numpy(b64) image_arrays.append(img) - dims.append((img.shape[0], img.shape[1])) data["image_arrays"] = image_arrays - data["image_dims"] = dims elif "base64_image" in data: # Single-image fallback img = base64_to_numpy(data["base64_image"]) data["image_arrays"] = [img] - data["image_dims"] = [(img.shape[0], img.shape[1])] else: raise KeyError("Input data must include 'base64_image' or 'base64_images'.") @@ -126,6 +107,12 @@ def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, If an invalid protocol is specified. """ + images = data["image_arrays"] + + dims: List[Dict[str, Any]] = [] + data["image_dims"] = dims + + # Helper function to split a list into chunks of size up to chunk_size. def chunk_list(lst, chunk_size): return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] @@ -139,7 +126,9 @@ def chunk_list(lst, chunk_size): logger.debug("Formatting input for gRPC PaddleOCR model (batched).") processed: List[np.ndarray] = [] for img in images: - arr = preprocess_image_for_paddle(img, self.paddle_version).astype(np.float32) + arr, _dims = preprocess_image_for_paddle(img) + dims.append(_dims) + arr = arr.astype(np.float32) arr = np.expand_dims(arr, axis=0) # => shape (1, H, W, C) processed.append(arr) @@ -162,44 +151,23 @@ def chunk_list(lst, chunk_size): else: base64_list = [data["base64_image"]] - if self._is_version_early_access_legacy_api(): - content_list: List[Dict[str, Any]] = [] - for b64 in base64_list: - image_url = f"data:image/png;base64,{b64}" - image_obj = {"type": "image_url", "image_url": {"url": image_url}} - content_list.append(image_obj) - - batches = [] - batch_data_list = [] - for content_chunk, orig_chunk, dims_chunk in zip( - chunk_list(content_list, max_batch_size), - chunk_list(images, max_batch_size), - chunk_list(dims, max_batch_size), - ): - message = {"content": content_chunk} - payload = {"messages": [message]} - batches.append(payload) - batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk}) - return batches, batch_data_list + input_list: List[Dict[str, Any]] = [] + for b64 in base64_list: + image_url = f"data:image/png;base64,{b64}" + image_obj = {"type": "image_url", "url": image_url} + input_list.append(image_obj) - else: - input_list: List[Dict[str, Any]] = [] - for b64 in base64_list: - image_url = f"data:image/png;base64,{b64}" - image_obj = {"type": "image_url", "url": image_url} - input_list.append(image_obj) - - batches = [] - batch_data_list = [] - for input_chunk, orig_chunk, dims_chunk in zip( - chunk_list(input_list, max_batch_size), - chunk_list(images, max_batch_size), - chunk_list(dims, max_batch_size), - ): - payload = {"input": input_chunk} - batches.append(payload) - batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk}) - return batches, batch_data_list + batches = [] + batch_data_list = [] + for input_chunk, orig_chunk, dims_chunk in zip( + chunk_list(input_list, max_batch_size), + chunk_list(images, max_batch_size), + chunk_list(dims, max_batch_size), + ): + payload = {"input": input_chunk} + batches.append(payload) + batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk}) + return batches, batch_data_list else: raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") @@ -230,29 +198,16 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, An ValueError If an invalid protocol is specified. """ - default_table_content_format = ( - TableFormatEnum.SIMPLE if self._is_version_early_access_legacy_api() else TableFormatEnum.PSEUDO_MARKDOWN - ) - table_content_format = kwargs.get("table_content_format", default_table_content_format) - - # Enforce legacy constraints - if self._is_version_early_access_legacy_api() and table_content_format != TableFormatEnum.SIMPLE: - logger.warning( - f"Paddle version {self.paddle_version} does not support {table_content_format} format. " - "The table content will be in `simple` format." - ) - table_content_format = TableFormatEnum.SIMPLE - # Retrieve image dimensions if available dims: Optional[List[Tuple[int, int]]] = data.get("image_dims") if data else None if protocol == "grpc": logger.debug("Parsing output from gRPC PaddleOCR model (batched).") - return self._extract_content_from_paddle_grpc_response(response, table_content_format, dims) + return self._extract_content_from_paddle_grpc_response(response, dims) elif protocol == "http": logger.debug("Parsing output from HTTP PaddleOCR model (batched).") - return self._extract_content_from_paddle_http_response(response, table_content_format, dims) + return self._extract_content_from_paddle_http_response(response) else: raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") @@ -276,18 +231,6 @@ def process_inference_results(self, output: Any, **kwargs: Any) -> Any: """ return output - def _is_version_early_access_legacy_api(self) -> bool: - """ - Determine if the current PaddleOCR version is considered "early access" and thus uses - the legacy API format. - - Returns - ------- - bool - True if the version is < 0.2.1-rc2; False otherwise. - """ - return self.paddle_version is not None and pkgversion.parse(self.paddle_version) < pkgversion.parse("0.2.1-rc2") - def _prepare_paddle_payload(self, base64_img: str) -> Dict[str, Any]: """ DEPRECATED by batch logic in format_input. Kept here if you need single-image direct calls. @@ -304,18 +247,15 @@ def _prepare_paddle_payload(self, base64_img: str) -> Dict[str, Any]: """ image_url = f"data:image/png;base64,{base64_img}" - if self._is_version_early_access_legacy_api(): - image = {"type": "image_url", "image_url": {"url": image_url}} - message = {"content": [image]} - payload = {"messages": [message]} - else: - image = {"type": "image_url", "url": image_url} - payload = {"input": [image]} + image = {"type": "image_url", "url": image_url} + payload = {"input": [image]} return payload def _extract_content_from_paddle_http_response( - self, json_response: Dict[str, Any], table_content_format: Optional[str], dims: Optional[List[Tuple[int, int]]] + self, + json_response: Dict[str, Any], + table_content_format: Optional[str], ) -> List[Tuple[str, str]]: """ Extract content from the JSON response of a PaddleOCR HTTP API request. @@ -326,9 +266,6 @@ def _extract_content_from_paddle_http_response( The JSON response returned by the PaddleOCR endpoint. table_content_format : str or None The specified format for table content (e.g., 'simple' or 'pseudo_markdown'). - dims : list of (int, int), optional - A list of (height, width) for each corresponding image, used for bounding box - scaling if not None. Returns ------- @@ -347,33 +284,21 @@ def _extract_content_from_paddle_http_response( results: List[str] = [] for item_idx, item in enumerate(json_response["data"]): - if self._is_version_early_access_legacy_api(): - content = item.get("content", "") - else: - text_detections = item.get("text_detections", []) - text_predictions: List[str] = [] - bounding_boxes: List[List[Tuple[float, float]]] = [] - - for td in text_detections: - text_predictions.append(td["text_prediction"]["text"]) - bounding_boxes.append([(pt["x"], pt["y"]) for pt in td["bounding_box"]["points"]]) - - if table_content_format == TableFormatEnum.SIMPLE: - content = " ".join(text_predictions) - elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN: - content = self._convert_paddle_response_to_psuedo_markdown( - bounding_boxes, text_predictions, img_index=item_idx, dims=dims - ) - else: - raise ValueError(f"Unexpected table format: {table_content_format}") + text_detections = item.get("text_detections", []) + text_predictions = [] + bounding_boxes = [] + for td in text_detections: + text_predictions.append(td["text_prediction"]["text"]) + bounding_boxes.append([[pt["x"], pt["y"]] for pt in td["bounding_box"]["points"]]) - results.append(content) + results.append([bounding_boxes, text_predictions]) - # Convert each content into a tuple (content, format). - return [(content, table_content_format) for content in results] + return results def _extract_content_from_paddle_grpc_response( - self, response: np.ndarray, table_content_format: str, dims: Optional[List[Tuple[int, int]]] + self, + response: np.ndarray, + dimensions: List[Dict[str, Any]], ) -> List[Tuple[str, str]]: """ Parse a gRPC response for one or more images. The response can have two possible shapes: @@ -391,8 +316,8 @@ def _extract_content_from_paddle_grpc_response( The raw NumPy array from gRPC. Expected shape: (3,) or (3, n). table_content_format : str The format of the output text content, e.g. 'simple' or 'pseudo_markdown'. - dims : list of (int, int), optional - A list of (height, width) for each corresponding image, used for bounding box scaling. + dims : list of dict, optional + A list of dict for each corresponding image, used for bounding box scaling. Returns ------- @@ -436,29 +361,27 @@ def _extract_content_from_paddle_grpc_response( if isinstance(text_predictions, list) and len(text_predictions) == 1: text_predictions = text_predictions[0] - # Construct the content string - if table_content_format == TableFormatEnum.SIMPLE: - content = " ".join(text_predictions) - elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN: - content = self._convert_paddle_response_to_psuedo_markdown( - bounding_boxes, text_predictions, img_index=i, dims=dims - ) - else: - raise ValueError(f"Unexpected table format: {table_content_format}") + bounding_boxes, text_predictions = self._postprocess_paddle_response( + bounding_boxes, + text_predictions, + dimensions, + img_index=i, + ) - results.append((content, table_content_format)) + results.append([bounding_boxes, text_predictions]) return results @staticmethod - def _convert_paddle_response_to_psuedo_markdown( + def _postprocess_paddle_response( bounding_boxes: List[Any], text_predictions: List[str], + dims: Optional[List[Dict[str, Any]]] = None, img_index: int = 0, - dims: Optional[List[Tuple[int, int]]] = None, - ) -> str: + ) -> Tuple[List[Any], List[str]]: """ - Convert bounding boxes & text to pseudo-markdown format. For multiple images, + Convert bounding boxes with normalized coordinates to pixel cooridnates by using + the dimensions. Also shift the coorindates if the inputs were padded. For multiple images, the correct image dimensions (height, width) are retrieved from `dims[img_index]`. Parameters @@ -469,14 +392,19 @@ def _convert_paddle_response_to_psuedo_markdown( A list of text predictions, one for each bounding box. img_index : int, optional The index of the image for which bounding boxes are being converted. Default is 0. - dims : list of (int, int), optional - A list of (height, width) for each corresponding image. + dims : list of dict, optional + A list of dictionaries, where each dictionary contains image-specific dimensions + and scaling information: + - "new_width" (int): The width of the image after processing. + - "new_height" (int): The height of the image after processing. + - "pad_width" (int, optional): The width of padding added to the image. + - "pad_height" (int, optional): The height of padding added to the image. + - "scale_factor" (float, optional): The scaling factor applied to the image. Returns ------- - str - The pseudo-markdown representation of detected text lines and bounding boxes. - Each cluster of text is placed on its own line, with text columns separated by '|'. + Tuple[List[Any], List[str]] + Bounding boxes scaled backed to the original dimensions and detected text lines. Notes ----- @@ -484,16 +412,19 @@ def _convert_paddle_response_to_psuedo_markdown( """ # Default to no scaling if dims are missing or out of range if not dims: - logger.warning("No image_dims provided; bounding boxes will not be scaled.") - target_h, target_w = 1, 1 + raise ValueError("No image_dims provided.") else: if img_index >= len(dims): logger.warning("Image index out of range for stored dimensions. Using first image dims by default.") - target_h, target_w = dims[0] - else: - target_h, target_w = dims[img_index] + img_index = 0 + + max_width = dims[img_index]["new_width"] + max_height = dims[img_index]["new_height"] + pad_width = dims[img_index].get("pad_width", 0) + pad_height = dims[img_index].get("pad_height", 0) + scale_factor = dims[img_index].get("scale_factor", 1.0) - scaled_boxes: List[List[float]] = [] + bboxes: List[List[float]] = [] texts: List[str] = [] # Convert normalized coords back to actual pixel coords @@ -502,43 +433,14 @@ def _convert_paddle_response_to_psuedo_markdown( continue points: List[List[float]] = [] for point in box: - x = float(point[0]) * target_w - y = float(point[1]) * target_h - points.append([x, y]) - scaled_boxes.append(points) + # Convert normalized coords back to actual pixel coords, + # and shift them back to their original positions if padded. + x_pixels = float(point[0]) * max_width - pad_width + y_pixels = float(point[1]) * max_height - pad_height + x_original = x_pixels / scale_factor + y_original = y_pixels / scale_factor + points.append([x_original, y_original]) + bboxes.append(points) texts.append(txt) - if not scaled_boxes or not texts: - return "" - - # Convert bounding boxes to a simplified (x0, y0, x1, y1) representation - # by taking only the top-left and bottom-right corners - bboxes_array = np.array(scaled_boxes).astype(int) - # Reshape => (N, 4) by taking [0,1, 2,3] from the original (N, 4, 2) - # but we have 4 corners => 8 values. So shape => (N, 8). Then keep indices [0, 1, 2, 7]. - bboxes_array = bboxes_array.reshape(-1, 8)[:, [0, 1, 2, -1]] - - preds_df = pd.DataFrame( - { - "x0": bboxes_array[:, 0], - "y0": bboxes_array[:, 1], - "x1": bboxes_array[:, 2], - "y1": bboxes_array[:, 3], - "text": texts, - } - ) - # Sort by top position - preds_df = preds_df.sort_values("y0") - - dbscan = DBSCAN(eps=10, min_samples=1) - dbscan.fit(preds_df["y0"].values[:, None]) - - preds_df["cluster"] = dbscan.labels_ - # Sort by cluster and then by x0 to group text lines from left to right - preds_df = preds_df.sort_values(["cluster", "x0"]) - - lines = [] - for _, dfg in preds_df.groupby("cluster"): - lines.append("| " + " | ".join(dfg["text"].values.tolist()) + " |") - - return "\n".join(lines) + return bboxes, texts diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py index ce32bf83..47df3569 100644 --- a/src/nv_ingest/util/nim/yolox.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -15,6 +15,7 @@ import cv2 import numpy as np +import pandas as pd import torch import torchvision from PIL import Image @@ -24,53 +25,80 @@ logger = logging.getLogger(__name__) -YOLOX_MAX_BATCH_SIZE = 8 -YOLOX_MAX_WIDTH = 1536 -YOLOX_MAX_HEIGHT = 1536 -YOLOX_NUM_CLASSES = 3 -YOLOX_CONF_THRESHOLD = 0.01 -YOLOX_IOU_THRESHOLD = 0.5 -YOLOX_MIN_SCORE = 0.1 -YOLOX_FINAL_SCORE = 0.48 -YOLOX_NIM_MAX_IMAGE_SIZE = 512_000 - -YOLOX_IMAGE_PREPROC_HEIGHT = 1024 -YOLOX_IMAGE_PREPROC_WIDTH = 1024 - - -def chunkify_linearly(lst, chunk_size): - for i in range(0, len(lst), chunk_size): - yield lst[i : i + chunk_size] - - -def chunkify_geometrically(lst, max_size): - # TRT engine in Yolox NIM (gRPC) only allows a batch size in multiples of 2. - i = 0 - while i < len(lst): - chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size) - yield lst[i : i + chunk_size] - i += chunk_size - - -# Implementing YoloxPageElemenetsModelInterface with required methods -class YoloxPageElementsModelInterface(ModelInterface): +# yolox-page-elements-v1 contants +YOLOX_PAGE_NUM_CLASSES = 3 +YOLOX_PAGE_CONF_THRESHOLD = 0.01 +YOLOX_PAGE_IOU_THRESHOLD = 0.5 +YOLOX_PAGE_MIN_SCORE = 0.1 +YOLOX_PAGE_FINAL_SCORE = 0.48 +YOLOX_PAGE_NIM_MAX_IMAGE_SIZE = 512_000 + +YOLOX_PAGE_IMAGE_PREPROC_HEIGHT = 1024 +YOLOX_PAGE_IMAGE_PREPROC_WIDTH = 1024 + +YOLOX_PAGE_CLASS_LABELS = [ + "table", + "chart", + "title", +] + +# yolox-graphic-elements-v1 contants +YOLOX_GRAPHIC_NUM_CLASSES = 10 +YOLOX_GRAPHIC_CONF_THRESHOLD = 0.01 +YOLOX_GRAPHIC_IOU_THRESHOLD = 0.25 +YOLOX_GRAPHIC_MIN_SCORE = 0.1 +YOLOX_GRAPHIC_FINAL_SCORE = 0.0 +YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE = 512_000 + +YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 768 +YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 768 + +YOLOX_GRAPHIC_CLASS_LABELS = [ + "chart_title", + "x_title", + "y_title", + "xlabel", + "ylabel", + "other", + "legend_label", + "legend_title", + "mark_label", + "value_label", +] + + +# YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements +class YoloxModelInterfaceBase(ModelInterface): """ An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols. """ - def name( + def __init__( self, - ) -> str: + image_preproc_width: Optional[int] = None, + image_preproc_height: Optional[int] = None, + nim_max_image_size: Optional[int] = None, + num_classes: Optional[int] = None, + conf_threshold: Optional[float] = None, + iou_threshold: Optional[float] = None, + min_score: Optional[float] = None, + final_score: Optional[float] = None, + class_labels: Optional[List[str]] = None, + ): """ - Returns the name of the Yolox model interface. - - Returns - ------- - str - The name of the model interface. + Initialize the YOLOX model interface. + Parameters + ---------- """ - - return "yolox-page-elements" + self.image_preproc_width = image_preproc_width + self.image_preproc_height = image_preproc_height + self.nim_max_image_size = nim_max_image_size + self.num_classes = num_classes + self.conf_threshold = conf_threshold + self.iou_threshold = iou_threshold + self.min_score = min_score + self.final_score = final_score + self.class_labels = class_labels def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: """ @@ -133,20 +161,30 @@ def format_input( If the protocol is invalid. """ - # Helper to chunk a list into sublists of length up to chunk_size. + # Helper functions to chunk a list into sublists of length up to chunk_size. def chunk_list(lst: list, chunk_size: int) -> List[list]: return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + def chunk_list_geometrically(lst: list, max_size: int) -> List[list]: + # TRT engine in Yolox NIM (gRPC) only allows a batch size in powers of 2. + chunks = [] + i = 0 + while i < len(lst): + chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size) + chunks.append(lst[i : i + chunk_size]) + i += chunk_size + return chunks + if protocol == "grpc": logger.debug("Formatting input for gRPC Yolox model") # Resize images for model input (Yolox expects 1024x1024). resized_images = [ - resize_image(image, (YOLOX_IMAGE_PREPROC_WIDTH, YOLOX_IMAGE_PREPROC_HEIGHT)) for image in data["images"] + resize_image(image, (self.image_preproc_width, self.image_preproc_height)) for image in data["images"] ] # Chunk the resized images, the original images, and their shapes. - resized_chunks = chunk_list(resized_images, max_batch_size) - original_chunks = chunk_list(data["images"], max_batch_size) - shape_chunks = chunk_list(data["original_image_shapes"], max_batch_size) + resized_chunks = chunk_list_geometrically(resized_images, max_batch_size) + original_chunks = chunk_list_geometrically(data["images"], max_batch_size) + shape_chunks = chunk_list_geometrically(data["original_image_shapes"], max_batch_size) batched_inputs = [] formatted_batch_data = [] @@ -172,7 +210,7 @@ def chunk_list(lst: list, chunk_size: int) -> List[list]: # Scale the image if necessary. scaled_image_b64, new_size = scale_image_to_encoding_size( - image_b64, max_base64_size=YOLOX_NIM_MAX_IMAGE_SIZE + image_b64, max_base64_size=self.nim_max_image_size ) if new_size != original_size: logger.debug(f"Image was scaled from {original_size} to {new_size}.") @@ -229,7 +267,7 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, An batch_results = response.get("data", []) for detections in batch_results: - new_bounding_boxes = {"table": [], "chart": [], "title": []} + new_bounding_boxes = {label: [] for label in self.class_labels} bounding_boxes = detections.get("bounding_boxes", []) for obj_type, bboxes in bounding_boxes.items(): @@ -265,11 +303,6 @@ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Lis A list of annotation dictionaries for each image in the batch. """ original_image_shapes = kwargs.get("original_image_shapes", []) - num_classes = kwargs.get("num_classes", YOLOX_NUM_CLASSES) - conf_thresh = kwargs.get("conf_thresh", YOLOX_CONF_THRESHOLD) - iou_thresh = kwargs.get("iou_thresh", YOLOX_IOU_THRESHOLD) - min_score = kwargs.get("min_score", YOLOX_MIN_SCORE) - final_thresh = kwargs.get("final_thresh", YOLOX_FINAL_SCORE) if protocol == "http": # For http, the output already has postprocessing applied. Skip to table/chart expansion. @@ -277,11 +310,88 @@ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Lis elif protocol == "grpc": # For grpc, apply the same NIM postprocessing. - pred = postprocess_model_prediction(output, num_classes, conf_thresh, iou_thresh, class_agnostic=True) - results = postprocess_results(pred, original_image_shapes, min_score=min_score) + pred = postprocess_model_prediction( + output, self.num_classes, self.conf_threshold, self.iou_threshold, class_agnostic=True + ) + results = postprocess_results( + pred, + original_image_shapes, + self.image_preproc_width, + self.image_preproc_height, + self.class_labels, + min_score=self.min_score, + ) + + inference_results = self.postprocess_annotations(results, **kwargs) + + return inference_results + + def postprocess_annotations(self, annotation_dicts, **kwargs): + raise NotImplementedError() + + def transform_normalized_coordinates_to_original(self, results, original_image_shapes): + """ """ + transformed_results = [] + + for annotation_dict, shape in zip(results, original_image_shapes): + new_dict = {} + for label, bboxes_and_scores in annotation_dict.items(): + new_dict[label] = [] + for bbox_and_score in bboxes_and_scores: + bbox = bbox_and_score[:4] + transformed_bbox = [ + bbox[0] * shape[1], + bbox[1] * shape[0], + bbox[2] * shape[1], + bbox[3] * shape[0], + ] + transformed_bbox += bbox_and_score[4:] + new_dict[label].append(transformed_bbox) + transformed_results.append(new_dict) + + return transformed_results + + +class YoloxPageElementsModelInterface(YoloxModelInterfaceBase): + """ + An interface for handling inference with yolox-page-elements model, supporting both gRPC and HTTP protocols. + """ + + def __init__(self): + """ + Initialize the yolox-page-elements model interface. + """ + super().__init__( + image_preproc_width=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT, + image_preproc_height=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT, + nim_max_image_size=YOLOX_PAGE_NIM_MAX_IMAGE_SIZE, + num_classes=YOLOX_PAGE_NUM_CLASSES, + conf_threshold=YOLOX_PAGE_CONF_THRESHOLD, + iou_threshold=YOLOX_PAGE_IOU_THRESHOLD, + min_score=YOLOX_PAGE_MIN_SCORE, + final_score=YOLOX_PAGE_FINAL_SCORE, + class_labels=YOLOX_PAGE_CLASS_LABELS, + ) + + def name( + self, + ) -> str: + """ + Returns the name of the Yolox model interface. + + Returns + ------- + str + The name of the model interface. + """ + + return "yolox-page-elements" + + def postprocess_annotations(self, annotation_dicts, **kwargs): + original_image_shapes = kwargs.get("original_image_shapes", []) # Table/chart expansion is "business logic" specific to nv-ingest - annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in results] + annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in annotation_dicts] annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in annotation_dicts] inference_results = [] @@ -290,13 +400,74 @@ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Lis for annotation_dict in annotation_dicts: new_dict = {} if "table" in annotation_dict: - new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= final_thresh] + new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= self.final_score] if "chart" in annotation_dict: - new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= final_thresh] + new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= self.final_score] if "title" in annotation_dict: new_dict["title"] = annotation_dict["title"] inference_results.append(new_dict) + inference_results = self.transform_normalized_coordinates_to_original(inference_results, original_image_shapes) + + return inference_results + + +class YoloxGraphicElementsModelInterface(YoloxModelInterfaceBase): + """ + An interface for handling inference with yolox-graphic-elemenents model, supporting both gRPC and HTTP protocols. + """ + + def __init__(self): + """ + Initialize the yolox-graphic-elements model interface. + """ + super().__init__( + image_preproc_width=YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT, + image_preproc_height=YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT, + nim_max_image_size=YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE, + num_classes=YOLOX_GRAPHIC_NUM_CLASSES, + conf_threshold=YOLOX_GRAPHIC_CONF_THRESHOLD, + iou_threshold=YOLOX_GRAPHIC_IOU_THRESHOLD, + min_score=YOLOX_GRAPHIC_MIN_SCORE, + final_score=YOLOX_GRAPHIC_FINAL_SCORE, + class_labels=YOLOX_GRAPHIC_CLASS_LABELS, + ) + + def name( + self, + ) -> str: + """ + Returns the name of the Yolox model interface. + + Returns + ------- + str + The name of the model interface. + """ + + return "yolox-graphic-elements" + + def postprocess_annotations(self, annotation_dicts, **kwargs): + original_image_shapes = kwargs.get("original_image_shapes", []) + + annotation_dicts = self.transform_normalized_coordinates_to_original(annotation_dicts, original_image_shapes) + + inference_results = [] + + # bbox extraction: additional postprocessing speicifc to nv-ingest + for pred, shape in zip(annotation_dicts, original_image_shapes): + bbox_dict = get_bbox_dict_yolox_graphic( + pred, + shape, + self.class_labels, + self.min_score, + ) + # convert numpy arrays to list + bbox_dict = { + label: array.tolist() if isinstance(array, np.ndarray) else array for label, array in bbox_dict.items() + } + inference_results.append(bbox_dict) + return inference_results @@ -359,7 +530,9 @@ def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thr return output -def postprocess_results(results, original_image_shapes, min_score=0.0): +def postprocess_results( + results, original_image_shapes, image_preproc_width, image_preproc_height, class_labels, min_score=0.0 +): """ For each item (==image) in results, computes annotations in the form @@ -371,7 +544,6 @@ def postprocess_results(results, original_image_shapes, min_score=0.0): Keep only bboxes with high enough confidence. """ - class_labels = ["table", "chart", "title"] out = [] for original_image_shape, result in zip(original_image_shapes, results): @@ -388,8 +560,8 @@ def postprocess_results(results, original_image_shapes, min_score=0.0): # ratio is used when image was padded ratio = min( - YOLOX_IMAGE_PREPROC_WIDTH / original_image_shape[0], - YOLOX_IMAGE_PREPROC_HEIGHT / original_image_shape[1], + image_preproc_width / original_image_shape[0], + image_preproc_height / original_image_shape[1], ) bboxes = result[:, :4] / ratio @@ -907,3 +1079,108 @@ def get_weighted_box(boxes, conf_type="avg"): box[3] = -1 # model index field is retained for consistency but is not used. box[4:] /= conf return box + + +def batched_overlaps(A, B): + """ + Calculate the Intersection over Union (IoU) between + two sets of bounding boxes in a batched manner. + Normalization is modified to only use the area of A boxes, hence computing the overlaps. + Args: + A (ndarray): Array of bounding boxes of shape (N, 4) in format [x1, y1, x2, y2]. + B (ndarray): Array of bounding boxes of shape (M, 4) in format [x1, y1, x2, y2]. + Returns: + ndarray: Array of IoU values of shape (N, M) representing the overlaps + between each pair of bounding boxes. + """ + A = A.copy() + B = B.copy() + + A = A[None].repeat(B.shape[0], 0) + B = B[:, None].repeat(A.shape[1], 1) + + low = np.s_[..., :2] + high = np.s_[..., 2:] + + A, B = A.copy(), B.copy() + A[high] += 1 + B[high] += 1 + + intrs = (np.maximum(0, np.minimum(A[high], B[high]) - np.maximum(A[low], B[low]))).prod(-1) + ious = intrs / (A[high] - A[low]).prod(-1) + + return ious + + +def find_boxes_inside(boxes, boxes_to_check, threshold=0.9): + """ + Find all boxes that are inside another box based on + the intersection area divided by the area of the smaller box, + and removes them. + """ + overlaps = batched_overlaps(boxes_to_check, boxes) + to_keep = (overlaps >= threshold).sum(0) <= 1 + return boxes_to_check[to_keep] + + +def get_bbox_dict_yolox_graphic(preds, shape, class_labels, threshold_=0.1) -> Dict[str, np.ndarray]: + """ + Extracts bounding boxes from YOLOX model predictions: + - Applies thresholding + - Reformats boxes + - Cleans the `other` detections: removes the ones that are included in other detections. + - If no title is found, the biggest `other` box is used if it is larger than 0.3*img_w. + Args: + preds (np.ndarray): YOLOX model predictions including bounding boxes, scores, and labels. + shape (tuple): Original image shape. + threshold_ (float): Score threshold to filter bounding boxes. + Returns: + Dict[str, np.ndarray]: Dictionary of bounding boxes, organized by class. + """ + bbox_dict = {label: np.array([]) for label in class_labels} + + for i, label in enumerate(class_labels): + bboxes_class = np.array(preds[label]) + + if bboxes_class.size == 0: + continue + + # Try to find a chart_title box + threshold = threshold_ if label != "chart_title" else min(threshold_, bboxes_class[:, -1].max()) + bboxes_class = bboxes_class[bboxes_class[:, -1] >= threshold][:, :4].astype(int) + + sort = ["x0", "y0"] if label != "ylabel" else ["y0", "x0"] + idxs = ( + pd.DataFrame( + { + "y0": bboxes_class[:, 1], + "x0": bboxes_class[:, 0], + } + ) + .sort_values(sort, ascending=label != "ylabel") + .index + ) + bboxes_class = bboxes_class[idxs] + bbox_dict[label] = bboxes_class + + # Remove other included + if len(bbox_dict.get("other", [])): + other = find_boxes_inside( + np.concatenate(list([v for v in bbox_dict.values() if len(v)])), bbox_dict["other"], threshold=0.7 + ) + del bbox_dict["other"] + if len(other): + bbox_dict["other"] = other + + # Biggest other is title if no title + if not len(bbox_dict.get("chart_title", [])) and len(bbox_dict.get("other", [])): + boxes = bbox_dict["other"] + ws = boxes[:, 2] - boxes[:, 0] + if np.max(ws) > shape[1] * 0.3: + bbox_dict["chart_title"] = boxes[np.argmax(ws)][None].copy() + bbox_dict["other"] = np.delete(boxes, (np.argmax(ws)), axis=0) + + # Make sure other key not lost + bbox_dict["other"] = bbox_dict.get("other", []) + + return bbox_dict diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py index 6824c431..d140e1c9 100644 --- a/src/nv_ingest/util/pipeline/stage_builders.py +++ b/src/nv_ingest/util/pipeline/stage_builders.py @@ -225,21 +225,15 @@ def add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, def def add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): - _, _, yolox_auth, _ = get_table_detection_service("yolox") - - deplot_grpc, deplot_http, deplot_auth, deplot_protocol = get_table_detection_service("deplot") - cached_grpc, cached_http, cached_auth, cached_protocol = get_table_detection_service("cached") - # NOTE: Paddle isn't currently used directly by the chart extraction stage, but will be in the future. + yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox_graphic_elements") paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle") table_content_extractor_config = ingest_config.get( "table_content_extraction_module", { "stage_config": { - "cached_endpoints": (cached_grpc, cached_http), - "cached_infer_protocol": cached_protocol, - "deplot_endpoints": (deplot_grpc, deplot_http), - "deplot_infer_protocol": deplot_protocol, + "yolox_endpoints": (yolox_grpc, yolox_http), + "yolox_infer_protocol": yolox_protocol, "paddle_endpoints": (paddle_grpc, paddle_http), "paddle_infer_protocol": paddle_protocol, "auth_token": yolox_auth, diff --git a/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py b/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py index da0c358e..eb2f047a 100644 --- a/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py +++ b/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py @@ -138,7 +138,7 @@ def test_extract_table_and_chart_images_empty_annotations(): def test_extract_table_and_chart_images_single_table(): """Test extraction with a single table bounding box.""" - annotation_dict = {"table": [[0.1, 0.1, 0.3, 0.3, 0.8]], "chart": []} + annotation_dict = {"table": [[64, 64, 192, 192, 0.8]], "chart": []} original_image = np.random.rand(640, 640, 3) tables_and_charts = [] @@ -161,7 +161,7 @@ def test_extract_table_and_chart_images_single_table(): def test_extract_table_and_chart_images_single_chart(): """Test extraction with a single chart bounding box.""" - annotation_dict = {"table": [], "chart": [[0.4, 0.4, 0.6, 0.6, 0.9]]} + annotation_dict = {"table": [], "chart": [[256, 256, 384, 384, 0.9]]} original_image = np.random.rand(640, 640, 3) tables_and_charts = [] @@ -198,7 +198,7 @@ def test_extract_table_and_chart_images_multiple_objects(): def test_extract_table_and_chart_images_invalid_bounding_box(): """Test with an invalid bounding box to check handling of incorrect coordinates.""" - annotation_dict = {"table": [[1.1, 1.1, 1.5, 1.5, 0.9]], "chart": []} # Out of bounds + annotation_dict = {"table": [[704, 704, 960, 960, 0.9]], "chart": []} # Out of bounds original_image = np.random.rand(640, 640, 3) tables_and_charts = [] diff --git a/tests/nv_ingest/schemas/test_chart_extractor_schema.py b/tests/nv_ingest/schemas/test_chart_extractor_schema.py index 64d2f0b2..917fcafc 100644 --- a/tests/nv_ingest/schemas/test_chart_extractor_schema.py +++ b/tests/nv_ingest/schemas/test_chart_extractor_schema.py @@ -11,54 +11,45 @@ def test_valid_config_with_grpc_only(): config = ChartExtractorConfigSchema( auth_token="valid_token", - cached_endpoints=("grpc://cached_service", None), - deplot_endpoints=("grpc://deplot_service", None), + yolox_endpoints=("grpc://yolox_service", None), paddle_endpoints=("grpc://paddle_service", None), ) assert config.auth_token == "valid_token" - assert config.cached_endpoints == ("grpc://cached_service", None) - assert config.deplot_endpoints == ("grpc://deplot_service", None) + assert config.yolox_endpoints == ("grpc://yolox_service", None) assert config.paddle_endpoints == ("grpc://paddle_service", None) def test_valid_config_with_http_only(): config = ChartExtractorConfigSchema( auth_token="valid_token", - cached_endpoints=(None, "http://cached_service"), - deplot_endpoints=(None, "http://deplot_service"), + yolox_endpoints=(None, "http://yolox_service"), paddle_endpoints=(None, "http://paddle_service"), ) assert config.auth_token == "valid_token" - assert config.cached_endpoints == (None, "http://cached_service") - assert config.deplot_endpoints == (None, "http://deplot_service") + assert config.yolox_endpoints == (None, "http://yolox_service") assert config.paddle_endpoints == (None, "http://paddle_service") def test_invalid_config_with_empty_services(): with pytest.raises(ValidationError) as excinfo: - ChartExtractorConfigSchema( - cached_endpoints=(None, None), deplot_endpoints=(None, None), paddle_endpoints=(None, None) - ) + ChartExtractorConfigSchema(yolox_endpoints=(None, None), paddle_endpoints=(None, None)) assert "Both gRPC and HTTP services cannot be empty" in str(excinfo.value) def test_valid_config_with_both_grpc_and_http(): config = ChartExtractorConfigSchema( auth_token="another_token", - cached_endpoints=("grpc://cached_service", "http://cached_service"), - deplot_endpoints=("grpc://deplot_service", "http://deplot_service"), + yolox_endpoints=("grpc://yolox_service", "http://yolox_service"), paddle_endpoints=("grpc://paddle_service", "http://paddle_service"), ) assert config.auth_token == "another_token" - assert config.cached_endpoints == ("grpc://cached_service", "http://cached_service") - assert config.deplot_endpoints == ("grpc://deplot_service", "http://deplot_service") + assert config.yolox_endpoints == ("grpc://yolox_service", "http://yolox_service") assert config.paddle_endpoints == ("grpc://paddle_service", "http://paddle_service") def test_invalid_auth_token_none(): config = ChartExtractorConfigSchema( - cached_endpoints=("grpc://cached_service", None), - deplot_endpoints=("grpc://deplot_service", None), + yolox_endpoints=("grpc://yolox_service", None), paddle_endpoints=("grpc://paddle_service", None), ) assert config.auth_token is None @@ -67,7 +58,8 @@ def test_invalid_auth_token_none(): def test_invalid_endpoint_format(): with pytest.raises(ValidationError): ChartExtractorConfigSchema( - cached_endpoints=("invalid_endpoint", None), deplot_endpoints=(None, "invalid_endpoint") + yolox_endpoints=("invalid_endpoint", None), + deplot_endpoints=(None, "invalid_endpoint"), ) @@ -82,8 +74,7 @@ def test_chart_extractor_schema_defaults(): def test_chart_extractor_schema_with_custom_values(): stage_config = ChartExtractorConfigSchema( - cached_endpoints=("grpc://cached_service", "http://cached_service"), - deplot_endpoints=("grpc://deplot_service", None), + yolox_endpoints=("grpc://yolox_service", "http://yolox_service"), paddle_endpoints=(None, "http://paddle_service"), ) config = ChartExtractorSchema(max_queue_size=10, n_workers=5, raise_on_failure=True, stage_config=stage_config) diff --git a/tests/nv_ingest/stages/nims/test_chart_extraction.py b/tests/nv_ingest/stages/nims/test_chart_extraction.py index 47b83275..8f71df05 100644 --- a/tests/nv_ingest/stages/nims/test_chart_extraction.py +++ b/tests/nv_ingest/stages/nims/test_chart_extraction.py @@ -2,6 +2,7 @@ import pytest import pandas as pd +import numpy as np from ....import_checks import MORPHEUS_IMPORT_OK from ....import_checks import CUDA_DRIVER_OK @@ -15,6 +16,7 @@ from nv_ingest.schemas.chart_extractor_schema import ChartExtractorConfigSchema from nv_ingest.stages.nim.chart_extraction import _update_metadata, _create_clients from nv_ingest.stages.nim.chart_extraction import _extract_chart_data + from nv_ingest.util.image_processing.transforms import base64_to_numpy MODULE_UNDER_TEST = "nv_ingest.stages.nim.chart_extraction" @@ -27,16 +29,19 @@ def valid_chart_extractor_config(): """ return ChartExtractorConfigSchema( auth_token="fake_token", - cached_endpoints=("cached_grpc_url", "cached_http_url"), - cached_infer_protocol="grpc", - deplot_endpoints=("deplot_grpc_url", "deplot_http_url"), - deplot_infer_protocol="http", + yolox_endpoints=("yolox_grpc_url", "yolox_http_url"), + yolox_infer_protocol="grpc", paddle_endpoints=("paddle_grpc_url", "paddle_http_url"), paddle_infer_protocol="grpc", workers_per_progress_engine=5, ) +@pytest.fixture +def base64_image(): + return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" + + @pytest.fixture def validated_config(valid_chart_extractor_config): """ @@ -57,18 +62,18 @@ def test_update_metadata_empty_list(): With the updated implementation, both clients are still invoked (with an empty list) so we set their return values to [] and then verify the calls. """ - cached_mock = MagicMock() - deplot_mock = MagicMock() + yolox_mock = MagicMock() + paddle_mock = MagicMock() trace_info = {} # When given an empty list, both clients return an empty list. - cached_mock.infer.return_value = [] - deplot_mock.infer.return_value = [] + yolox_mock.infer.return_value = [] + paddle_mock.infer.return_value = [] result = _update_metadata( base64_images=[], - cached_client=cached_mock, - deplot_client=deplot_mock, + yolox_client=yolox_mock, + paddle_client=paddle_mock, trace_info=trace_info, worker_pool_size=1, ) @@ -76,183 +81,188 @@ def test_update_metadata_empty_list(): assert result == [] # Each client's infer should be called once with an empty list. - cached_mock.infer.assert_called_once_with( - data={"base64_images": []}, - model_name="cached", + yolox_mock.infer.assert_called_once_with( + data={"images": []}, + model_name="yolox", stage_name="chart_data_extraction", - max_batch_size=2, + max_batch_size=8, trace_info=trace_info, ) - deplot_mock.infer.assert_called_once_with( + paddle_mock.infer.assert_called_once_with( data={"base64_images": []}, - model_name="deplot", + model_name="paddle", stage_name="chart_data_extraction", - max_batch_size=1, + max_batch_size=2, trace_info=trace_info, ) -def test_update_metadata_single_batch_single_worker(mocker): +def test_update_metadata_single_batch_single_worker(mocker, base64_image): """ Test a simple scenario with a small list of base64_images using worker_pool_size=1. - In the updated _update_metadata implementation, both the cached and deplot clients are + In the updated _update_metadata implementation, both the yolox and paddle clients are called once with the full list of images. The join function is applied per image. """ # Mock out the clients - cached_mock = MagicMock() - deplot_mock = MagicMock() + yolox_mock = MagicMock() + paddle_mock = MagicMock() - # For 2 images, cached.infer should return a list with 2 results. - cached_mock.infer.return_value = ["cached_res1", "cached_res2"] + # Suppose yolox returns ["yolox_res1", "yolox_res2"] for 2 images + yolox_mock.infer.return_value = ["yolox_res1", "yolox_res2"] - # deplot.infer should also return a list with 2 results. - deplot_mock.infer.return_value = ["deplot_res1", "deplot_res2"] + # Suppose paddle returns ["paddle_res1", "paddle_res2"] for 2 images + paddle_mock.infer.return_value = [[(), "paddle_res1"], [(), "paddle_res2"]] - # Patch the join function to return expected joined outputs. mock_join = mocker.patch( - f"{MODULE_UNDER_TEST}.join_cached_and_deplot_output", - side_effect=["joined_1", "joined_2"], + f"{MODULE_UNDER_TEST}.join_yolox_and_paddle_output", + side_effect=[{"chart_title": "joined_1"}, {"chart_title": "joined_2"}], ) - base64_images = ["img1", "img2"] + base64_images = [ + base64_image, + base64_image, + ] trace_info = {} - result = _update_metadata(base64_images, cached_mock, deplot_mock, trace_info, worker_pool_size=1) + result = _update_metadata(base64_images, yolox_mock, paddle_mock, trace_info, batch_size=2, worker_pool_size=1) # Expect the result to combine each original image with its corresponding joined output. - assert result == [("img1", "joined_1"), ("img2", "joined_2")] - - # cached.infer should be called once with the full list of images. - cached_mock.infer.assert_called_once_with( - data={"base64_images": ["img1", "img2"]}, - model_name="cached", + assert len(result) == 2 + assert result[0] == (base64_image, "joined_1") + assert result[1] == (base64_image, "joined_2") + + # yolox.infer should be called once with the full list of arrays. + assert yolox_mock.infer.call_count == 1 + assert np.all(yolox_mock.infer.call_args.kwargs["data"]["images"][0] == base64_to_numpy(base64_image)) + assert np.all(yolox_mock.infer.call_args.kwargs["data"]["images"][1] == base64_to_numpy(base64_image)) + assert yolox_mock.infer.call_args.kwargs["model_name"] == "yolox" + assert yolox_mock.infer.call_args.kwargs["stage_name"] == "chart_data_extraction" + assert yolox_mock.infer.call_args.kwargs["trace_info"] == trace_info + + # paddle.infer should be called once with the full list of images. + paddle_mock.infer.assert_called_once_with( + data={"base64_images": [base64_image, base64_image]}, + model_name="paddle", stage_name="chart_data_extraction", max_batch_size=2, trace_info=trace_info, ) - # deplot.infer should be called once with the full list of images. - deplot_mock.infer.assert_called_once_with( - data={"base64_images": ["img1", "img2"]}, - model_name="deplot", - stage_name="chart_data_extraction", - max_batch_size=1, - trace_info=trace_info, - ) - # The join function should be invoked once per image. assert mock_join.call_count == 2 -def test_update_metadata_multiple_batches_multi_worker(mocker): +def test_update_metadata_multiple_batches_multi_worker(mocker, base64_image): """ With the new _update_metadata implementation, both cached_client.infer and deplot_client.infer are called once with the full list of images. Their results are expected to be lists with one item per image. The join function is still invoked for each image. """ - cached_mock = MagicMock() - deplot_mock = MagicMock() + yolox_mock = MagicMock() + paddle_mock = MagicMock() mock_join = mocker.patch( - f"{MODULE_UNDER_TEST}.join_cached_and_deplot_output", - side_effect=["joined_1", "joined_2", "joined_3"], + f"{MODULE_UNDER_TEST}.join_yolox_and_paddle_output", + side_effect=[{"chart_title": "joined_1"}, {"chart_title": "joined_2"}, {"chart_title": "joined_3"}], ) - # Each client should return a list with an entry per input image. - def cached_side_effect(**kwargs): - images = kwargs["data"]["base64_images"] - return [f"cached_{img}" for img in images] + # Suppose every yolox.infer call returns a 1-element list + def yolox_side_effect(**kwargs): + images = kwargs["data"]["images"] + return [f"yolox_{images[0]}"] - cached_mock.infer.side_effect = cached_side_effect + yolox_mock.infer.side_effect = yolox_side_effect - def deplot_side_effect(**kwargs): - images = kwargs["data"]["base64_images"] - return [f"deplot_{img}" for img in images] + # Suppose paddle.infer returns e.g. ["paddle_img1"], etc. + def paddle_side_effect(**kwargs): + img = kwargs["data"]["base64_images"] + return [([], f"paddle_{img}")] - deplot_mock.infer.side_effect = deplot_side_effect + paddle_mock.infer.side_effect = paddle_side_effect - base64_images = ["imgA", "imgB", "imgC"] + base64_images = [base64_image, base64_image, base64_image] trace_info = {} result = _update_metadata( base64_images, - cached_mock, - deplot_mock, + yolox_mock, + paddle_mock, trace_info, worker_pool_size=2, ) - # Expect 3 results corresponding to the input images and the joined outputs. - assert result == [("imgA", "joined_1"), ("imgB", "joined_2"), ("imgC", "joined_3")] + # Expect 3 results: [("imgA", "joined_1"), ("imgB", "joined_2"), ("imgC", "joined_3")] + assert result == [(base64_image, "joined_1"), (base64_image, "joined_2"), (base64_image, "joined_3")] - # Now, with the new implementation, each infer method is called only once. - assert cached_mock.infer.call_count == 1 - assert deplot_mock.infer.call_count == 1 - # The join function is still called once per image. + # We should have 3 calls to yolox.infer, each with one image + assert yolox_mock.infer.call_count == 3 + # Also 3 calls to paddle.infer + assert paddle_mock.infer.call_count == 3 + # 3 calls to join assert mock_join.call_count == 3 -def test_update_metadata_exception_in_cached_call(caplog): +def test_update_metadata_exception_in_yolox_call(base64_image, caplog): """ - If the cached call fails, we expect an exception to bubble up and the error to be logged. + If the yolox call fails, we expect an exception to bubble up and the error to be logged. """ - cached_mock = MagicMock() - deplot_mock = MagicMock() - cached_mock.infer.side_effect = Exception("Cached call error") + yolox_mock = MagicMock() + paddle_mock = MagicMock() + yolox_mock.infer.side_effect = Exception("Yolox call error") - with pytest.raises(Exception, match="Cached call error"): - _update_metadata(["some_img"], cached_mock, deplot_mock, trace_info={}, worker_pool_size=1) + with pytest.raises(Exception, match="Yolox call error"): + _update_metadata([base64_image], yolox_mock, paddle_mock, trace_info={}, batch_size=1, worker_pool_size=1) # Verify that the error message from the cached client is logged. - assert "Error calling cached_client.infer: Cached call error" in caplog.text + assert "Error calling yolox_client.infer: Cached call error" in caplog.text -def test_update_metadata_exception_in_deplot_call(caplog): +def test_update_metadata_exception_in_paddle_call(base64_image, caplog): """ - If the deplot call fails, we expect an exception to bubble up and the error to be logged. + If the paddle call fails, we expect an exception to bubble up and the error to be logged. """ - cached_mock = MagicMock() - cached_mock.infer.return_value = ["cached_result"] # Single-element list for one image - deplot_mock = MagicMock() - deplot_mock.infer.side_effect = Exception("Deplot error") + yolox_mock = MagicMock() + yolox_mock.infer.return_value = ["yolox_result"] # Single-element list for one image + paddle_mock = MagicMock() + paddle_mock.infer.side_effect = Exception("Paddle error") - with pytest.raises(Exception, match="Deplot error"): - _update_metadata(["some_img"], cached_mock, deplot_mock, trace_info={}, worker_pool_size=2) + with pytest.raises(Exception, match="Paddle error"): + _update_metadata([base64_image], yolox_mock, paddle_mock, trace_info={}, batch_size=1, worker_pool_size=2) # Verify that the error message from the deplot client is logged. - assert "Error calling deplot_client.infer: Deplot error" in caplog.text + assert "Error calling paddle_client.infer: Deplot error" in caplog.text def test_create_clients(mocker): """ Verify that _create_clients calls create_inference_client for - both cached and deplot endpoints, returning the pair of NimClient mocks. + both yolox and paddle endpoints, returning the pair of NimClient mocks. """ mock_create_inference_client = mocker.patch(f"{MODULE_UNDER_TEST}.create_inference_client") # Suppose it returns different mocks each time - cached_mock = MagicMock() - deplot_mock = MagicMock() - mock_create_inference_client.side_effect = [cached_mock, deplot_mock] + yolox_mock = MagicMock() + paddle_mock = MagicMock() + mock_create_inference_client.side_effect = [yolox_mock, paddle_mock] result = _create_clients( - cached_endpoints=("cached_grpc", "cached_http"), - cached_protocol="grpc", - deplot_endpoints=("deplot_grpc", "deplot_http"), - deplot_protocol="http", + yolox_endpoints=("yolox_grpc", "yolox_http"), + yolox_protocol="grpc", + paddle_endpoints=("paddle_grpc", "paddle_http"), + paddle_protocol="http", auth_token="xyz", ) - # result => (cached_mock, deplot_mock) - assert result == (cached_mock, deplot_mock) + # result => (yolox_mock, paddle_mock) + assert result == (yolox_mock, paddle_mock) # Check calls assert mock_create_inference_client.call_count == 2 mock_create_inference_client.assert_any_call( - endpoints=("cached_grpc", "cached_http"), model_interface=mocker.ANY, auth_token="xyz", infer_protocol="grpc" + endpoints=("yolox_grpc", "yolox_http"), model_interface=mocker.ANY, auth_token="xyz", infer_protocol="grpc" ) mock_create_inference_client.assert_any_call( - endpoints=("deplot_grpc", "deplot_http"), model_interface=mocker.ANY, auth_token="xyz", infer_protocol="http" + endpoints=("paddle_grpc", "paddle_http"), model_interface=mocker.ANY, auth_token="xyz", infer_protocol="http" ) @@ -307,8 +317,8 @@ def test_extract_chart_data_all_valid(validated_config, mocker): All rows meet criteria => pass them all to _update_metadata in order. """ # Mock out clients - cached_mock, deplot_mock = MagicMock(), MagicMock() - mock_create_clients = mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=(cached_mock, deplot_mock)) + yolox_mock, paddle_mock = MagicMock(), MagicMock() + mock_create_clients = mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=(yolox_mock, paddle_mock)) # Suppose _update_metadata returns chart content for each image mock_update_metadata = mocker.patch( @@ -342,18 +352,18 @@ def test_extract_chart_data_all_valid(validated_config, mocker): assert df_out.at[1, "metadata"]["table_metadata"]["table_content"] == {"joined": "contentB"} mock_create_clients.assert_called_once_with( - validated_config.stage_config.cached_endpoints, - validated_config.stage_config.cached_infer_protocol, - validated_config.stage_config.deplot_endpoints, - validated_config.stage_config.deplot_infer_protocol, + validated_config.stage_config.yolox_endpoints, + validated_config.stage_config.yolox_infer_protocol, + validated_config.stage_config.paddle_endpoints, + validated_config.stage_config.paddle_infer_protocol, validated_config.stage_config.auth_token, ) # Check _update_metadata call mock_update_metadata.assert_called_once_with( base64_images=["imgA", "imgB"], - cached_client=cached_mock, - deplot_client=deplot_mock, + yolox_client=yolox_mock, + paddle_client=paddle_mock, worker_pool_size=validated_config.stage_config.workers_per_progress_engine, trace_info=ti.get("trace_info"), ) @@ -364,8 +374,8 @@ def test_extract_chart_data_mixed_rows(validated_config, mocker): Some rows are valid, some not. We only pass valid images to _update_metadata, and only those rows get updated. """ - cached_mock, deplot_mock = MagicMock(), MagicMock() - mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=(cached_mock, deplot_mock)) + yolox_mock, paddle_mock = MagicMock(), MagicMock() + mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=(yolox_mock, paddle_mock)) mock_update = mocker.patch( f"{MODULE_UNDER_TEST}._update_metadata", @@ -410,8 +420,8 @@ def test_extract_chart_data_mixed_rows(validated_config, mocker): mock_update.assert_called_once_with( base64_images=["base64img1", "base64img2"], - cached_client=cached_mock, - deplot_client=deplot_mock, + yolox_client=yolox_mock, + paddle_client=paddle_mock, worker_pool_size=validated_config.stage_config.workers_per_progress_engine, trace_info=trace.get("trace_info"), ) diff --git a/tests/nv_ingest/stages/nims/test_table_extraction.py b/tests/nv_ingest/stages/nims/test_table_extraction.py index 3656f548..d033f249 100644 --- a/tests/nv_ingest/stages/nims/test_table_extraction.py +++ b/tests/nv_ingest/stages/nims/test_table_extraction.py @@ -107,8 +107,8 @@ def test_extract_table_data_all_valid(mocker, validated_config): mock_update_metadata = mocker.patch( f"{MODULE_UNDER_TEST}._update_metadata", return_value=[ - ("imgA", ("tableA", "fmtA")), - ("imgB", ("tableB", "fmtB")), + ("imgA", [[], ["tableA"]]), + ("imgB", [[], ["tableB"]]), ], ) @@ -117,14 +117,14 @@ def test_extract_table_data_all_valid(mocker, validated_config): { "metadata": { "content_metadata": {"type": "structured", "subtype": "table"}, - "table_metadata": {}, + "table_metadata": {"table_content_format": "simple"}, "content": "imgA", } }, { "metadata": { "content_metadata": {"type": "structured", "subtype": "table"}, - "table_metadata": {}, + "table_metadata": {"table_content_format": "simple"}, "content": "imgB", } }, @@ -135,9 +135,9 @@ def test_extract_table_data_all_valid(mocker, validated_config): # Each valid row updated assert df_out.at[0, "metadata"]["table_metadata"]["table_content"] == "tableA" - assert df_out.at[0, "metadata"]["table_metadata"]["table_content_format"] == "fmtA" + assert df_out.at[0, "metadata"]["table_metadata"]["table_content_format"] == "simple" assert df_out.at[1, "metadata"]["table_metadata"]["table_content"] == "tableB" - assert df_out.at[1, "metadata"]["table_metadata"]["table_content_format"] == "fmtB" + assert df_out.at[1, "metadata"]["table_metadata"]["table_content_format"] == "simple" # Check calls mock_create_client.assert_called_once() @@ -158,7 +158,7 @@ def test_extract_table_data_mixed_rows(mocker, validated_config): mock_create_client = mocker.patch(f"{MODULE_UNDER_TEST}._create_paddle_client", return_value=mock_client) mock_update_metadata = mocker.patch( f"{MODULE_UNDER_TEST}._update_metadata", - return_value=[("good1", ("table1", "fmt1")), ("good2", ("table2", "fmt2"))], + return_value=[("good1", [[], ["table1"]]), ("good2", [[], ["table2"]])], ) df_in = pd.DataFrame( @@ -166,7 +166,7 @@ def test_extract_table_data_mixed_rows(mocker, validated_config): { "metadata": { "content_metadata": {"type": "structured", "subtype": "table"}, - "table_metadata": {}, + "table_metadata": {"table_content_format": "simple"}, "content": "good1", } }, @@ -190,14 +190,14 @@ def test_extract_table_data_mixed_rows(mocker, validated_config): df_out, trace_info = _extract_table_data(df_in, {}, validated_config) - # row0 => updated with table1/fmt1 + # row0 => updated with table1/txt1 assert df_out.at[0, "metadata"]["table_metadata"]["table_content"] == "table1" - assert df_out.at[0, "metadata"]["table_metadata"]["table_content_format"] == "fmt1" + assert df_out.at[0, "metadata"]["table_metadata"]["table_content_format"] == "simple" # row1 => invalid => no table_content assert "table_content" not in df_out.at[1, "metadata"]["table_metadata"] - # row2 => updated => table2/fmt2 + # row2 => updated => table2/txt2 assert df_out.at[2, "metadata"]["table_metadata"]["table_content"] == "table2" - assert df_out.at[2, "metadata"]["table_metadata"]["table_content_format"] == "fmt2" + assert df_out.at[2, "metadata"]["table_metadata"]["table_content_format"] == "simple" mock_update_metadata.assert_called_once_with( base64_images=["good1", "good2"], @@ -310,7 +310,7 @@ def test_update_metadata_skip_small(mocker, paddle_mock): res = _update_metadata(imgs, paddle_mock) assert len(res) == 2 # The first image is too small and is skipped. - assert res[0] == ("imgSmall", ("", "")) + assert res[0] == ("imgSmall", (None, None)) # The second image is valid and processed. assert res[1] == ("imgBig", ("valid_table", "valid_fmt")) @@ -421,8 +421,8 @@ def test_update_metadata_all_small(mocker, paddle_mock): mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_HEIGHT", 30) res = _update_metadata(imgs, paddle_mock) - assert res[0] == ("imgA", ("", "")) - assert res[1] == ("imgB", ("", "")) + assert res[0] == ("imgA", (None, None)) + assert res[1] == ("imgB", (None, None)) # No calls to infer paddle_mock.infer.assert_not_called() diff --git a/tests/nv_ingest/util/nim/test_paddle.py b/tests/nv_ingest/util/nim/test_paddle.py index 5de7b9ea..5300968c 100644 --- a/tests/nv_ingest/util/nim/test_paddle.py +++ b/tests/nv_ingest/util/nim/test_paddle.py @@ -55,12 +55,7 @@ def create_valid_grpc_response_batched(text="mock_text"): @pytest.fixture def paddle_ocr_model(): - return PaddleOCRModelInterface(paddle_version="0.2.1") - - -@pytest.fixture -def legacy_paddle_ocr_model(): - return PaddleOCRModelInterface(paddle_version="0.2.0") + return PaddleOCRModelInterface() @pytest.fixture @@ -111,7 +106,6 @@ def test_prepare_data_for_inference(paddle_ocr_model): assert "image_arrays" in result assert len(result["image_arrays"]) == 1 assert result["image_arrays"][0].shape == (100, 100, 3) - assert data["image_dims"][0] == (100, 100) @@ -170,34 +164,6 @@ def test_format_input_http(paddle_ocr_model): assert "image_arrays" in bd and "image_dims" in bd -def test_format_input_http_legacy(legacy_paddle_ocr_model): - """ - For legacy mode (<0.2.1-rc2), after preparing the data, the formatted payload should - use the legacy structure: - {"messages": [ {"content": [ {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."} } ] } ]} - As before, the accompanying batch data should contain the original images and dimensions. - """ - valid_b64 = create_valid_base64_image() - data = {"base64_image": valid_b64} - data = legacy_paddle_ocr_model.prepare_data_for_inference(data) - batches, batch_data = legacy_paddle_ocr_model.format_input(data, protocol="http", max_batch_size=1) - result = batches[0] - assert "messages" in result - assert isinstance(result["messages"], list) - assert len(result["messages"]) == 1 - content_list = result["messages"][0]["content"] - assert isinstance(content_list, list) - assert len(content_list) == 1 - item = content_list[0] - assert item["type"] == "image_url" - assert item["image_url"]["url"].startswith("data:image/png;base64,") - # Also check the returned batch data for legacy. - assert isinstance(batch_data, list) - assert len(batch_data) == 1 - bd = batch_data[0] - assert "image_arrays" in bd and "image_dims" in bd - - def test_parse_output_http_pseudo_markdown(paddle_ocr_model, mock_paddle_http_response): """ parse_output should now return a list of (content, table_content_format) tuples. @@ -227,55 +193,11 @@ def test_parse_output_http_simple(paddle_ocr_model, mock_paddle_http_response): data = {"base64_image": "mock_base64_string"} _ = paddle_ocr_model.prepare_data_for_inference(data) - result = paddle_ocr_model.parse_output(mock_paddle_http_response, protocol="http", table_content_format="simple") - # Should be [("mock_text", "simple")] - assert len(result) == 1 - assert result[0][0] == "mock_text" - assert result[0][1] == "simple" - - -def test_parse_output_http_simple_legacy(legacy_paddle_ocr_model): - """ - For the legacy version, parse_output also returns a list of (content, format), - but it forces 'simple' format if the user requested something else. - """ - with patch(f"{_MODULE_UNDER_TEST}.base64_to_numpy") as mock_base64_to_numpy: - mock_base64_to_numpy.return_value = np.zeros((100, 100, 3)) - - data = {"base64_image": "mock_base64_string"} - _ = legacy_paddle_ocr_model.prepare_data_for_inference(data) - - mock_legacy_paddle_http_response = {"data": [{"content": "mock_text"}]} - - result = legacy_paddle_ocr_model.parse_output( - mock_legacy_paddle_http_response, protocol="http", table_content_format="foo" - ) - # Expect => [("mock_text", "simple")] - assert len(result) == 1 - assert result[0][0] == "mock_text" - assert result[0][1] == "simple" - - -def test_parse_output_grpc_pseudo_markdown(paddle_ocr_model): - """ - Provide a valid (1,2) shape for bounding boxes & text. - The interface should parse them into [("| mock_text |\n", "pseudo_markdown")]. - """ - valid_b64 = create_valid_base64_image() - data = {"base64_image": valid_b64} - paddle_ocr_model.prepare_data_for_inference(data) - - # Create a valid shape => (1,2), with bounding-box JSON, text JSON - grpc_response = create_valid_grpc_response_batched("mock_text") - - # parse_output with default => pseudo_markdown for non-legacy - result = paddle_ocr_model.parse_output(grpc_response, protocol="grpc") - + result = paddle_ocr_model.parse_output(mock_paddle_http_response, protocol="http") + # Should be (bounding_boxes, text_predictions) assert len(result) == 1 - content, fmt = result[0] - # content might contain a markdown table row => "| mock_text |" - assert "mock_text" in content - assert fmt == "pseudo_markdown" + assert result[0][0] == [[[0.1, 0.2], [0.2, 0.2], [0.2, 0.3], [0.1, 0.3]]] + assert result[0][1] == ["mock_text"] def test_parse_output_grpc_simple(paddle_ocr_model): @@ -286,28 +208,18 @@ def test_parse_output_grpc_simple(paddle_ocr_model): data = {"base64_image": valid_b64} paddle_ocr_model.prepare_data_for_inference(data) + data["_dimensions"] = [ + { + "new_width": 1, + "new_height": 1, + "pad_width": 0, + "pad_height": 0, + "scale_factor": 1.0, + } + ] grpc_response = create_valid_grpc_response_batched("mock_text") - result = paddle_ocr_model.parse_output(grpc_response, protocol="grpc", table_content_format="simple") - - assert len(result) == 1 - content, fmt = result[0] - assert content == "mock_text" - assert fmt == "simple" - - -def test_parse_output_grpc_legacy(legacy_paddle_ocr_model): - """ - For legacy gRPC, we also return a list of (content, format). - We force 'simple' if table_content_format was not 'simple'. - """ - valid_b64 = create_valid_base64_image() - data = {"base64_image": valid_b64} - legacy_paddle_ocr_model.prepare_data_for_inference(data) - - grpc_response = create_valid_grpc_response_batched("mock_text") + result = paddle_ocr_model.parse_output(grpc_response, protocol="grpc", data=data) - # Pass a non-"simple" format => should be forced to 'simple' in legacy - result = legacy_paddle_ocr_model.parse_output(grpc_response, protocol="grpc", table_content_format="foo") assert len(result) == 1 - assert result[0][0] == "mock_text" - assert result[0][1] == "simple" + assert result[0][0] == [[[0.1, 0.2], [0.2, 0.2], [0.2, 0.3], [0.1, 0.3]]] + assert result[0][1] == ["mock_text"] diff --git a/tests/nv_ingest/util/nim/test_yolox.py b/tests/nv_ingest/util/nim/test_yolox.py index f44aadaa..89a8085c 100644 --- a/tests/nv_ingest/util/nim/test_yolox.py +++ b/tests/nv_ingest/util/nim/test_yolox.py @@ -229,11 +229,6 @@ def test_process_inference_results_grpc(model_interface): output_array, "grpc", original_image_shapes=original_image_shapes, - num_classes=3, - conf_thresh=0.5, - iou_thresh=0.4, - min_score=0.3, - final_thresh=0.6, ) assert isinstance(inference_results, list) assert len(inference_results) == 2 @@ -241,10 +236,10 @@ def test_process_inference_results_grpc(model_interface): assert isinstance(result, dict) if "table" in result: for bbox in result["table"]: - assert bbox[4] >= 0.6 + assert bbox[4] >= 0.48 if "chart" in result: for bbox in result["chart"]: - assert bbox[4] >= 0.6 + assert bbox[4] >= 0.48 if "title" in result: assert isinstance(result["title"], list) @@ -258,14 +253,11 @@ def test_process_inference_results_http(model_interface): } for _ in range(10) ] + original_image_shapes = [[100, 100] for _ in range(10)] inference_results = model_interface.process_inference_results( output, "http", - num_classes=3, - conf_thresh=0.5, - iou_thresh=0.4, - min_score=0.3, - final_thresh=0.6, + original_image_shapes=original_image_shapes, ) assert isinstance(inference_results, list) assert len(inference_results) == 10 @@ -273,9 +265,9 @@ def test_process_inference_results_http(model_interface): assert isinstance(result, dict) if "table" in result: for bbox in result["table"]: - assert bbox[4] >= 0.6 + assert bbox[4] >= 0.48 if "chart" in result: for bbox in result["chart"]: - assert bbox[4] >= 0.6 + assert bbox[4] >= 0.48 if "title" in result: assert isinstance(result["title"], list)