diff --git a/Dockerfile b/Dockerfile
index a9663eab..c4405279 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -49,6 +49,7 @@ RUN source activate nv_ingest \
     && mamba install -y \
       nvidia/label/dev::morpheus-core \
       nvidia/label/dev::morpheus-llm \
+      imagemagick \
       # pin to earlier version of cuda-python until __pyx_capi__ fix is upstreamed.
       cuda-python=12.6.0 \
       -c rapidsai -c pytorch -c nvidia -c conda-forge
diff --git a/client/src/nv_ingest_client/nv_ingest_cli.py b/client/src/nv_ingest_client/nv_ingest_cli.py
index 8b38c10b..d4705273 100644
--- a/client/src/nv_ingest_client/nv_ingest_cli.py
+++ b/client/src/nv_ingest_client/nv_ingest_cli.py
@@ -116,57 +116,61 @@
 \b
 Tasks and Options:
-- split: Divides documents according to specified criteria.
+- caption: Attempts to extract captions for unstructured images extracted from documents.
     Options:
-    - split_by (str): Criteria ('page', 'size', 'word', 'sentence'). No default.
-    - split_length (int): Segment length. No default.
-    - split_overlap (int): Segment overlap. No default.
-    - max_character_length (int): Maximum segment character count. No default.
-    - sentence_window_size (int): Sentence window size. No default.
+    - api_key (str): API key for captioning service.
+        Default: os.environ.get("NVIDIA_BUILD_API_KEY").
+    - endpoint_url (str): Endpoint URL for captioning service.
+        Default: 'https://build.nvidia.com/meta/llama-3.2-90b-vision-instruct'.
+    - prompt (str): Prompt for captioning service.
+        Default: 'Caption the content of this image:'.
+\b
+- dedup: Identifies and optionally filters duplicate images in extraction.
+    Options:
+    - content_type (str): Content type to deduplicate ('image').
+    - filter (bool): When set to True, duplicates will be filtered; otherwise, an info message will be added.
+\b
+- embed: Computes embeddings on multimodal extractions.
+    Options:
+    - filter_errors (bool): Flag to filter embedding errors. Optional.
+    - tables (bool): Flag to create embeddings for table extractions. Optional.
+    - text (bool): Flag to create embeddings for text extractions. Optional.
 \b
 - extract: Extracts content from documents, customizable per document type.
     Can be specified multiple times for different 'document_type' values.
     Options:
     - document_type (str): Document format ('pdf', 'docx', 'pptx', 'html', 'xml', 'excel', 'csv', 'parquet'). Required.
-    - extract_method (str): Extraction technique. Defaults are smartly chosen based on 'document_type'.
-    - extract_text (bool): Enables text extraction. Default: False.
+    - extract_charts (bool): Enables chart extraction. Default: False.
     - extract_images (bool): Enables image extraction. Default: False.
+    - extract_method (str): Extraction technique. Defaults are smartly chosen based on 'document_type'.
     - extract_tables (bool): Enables table extraction. Default: False.
-    - extract_charts (bool): Enables chart extraction. Default: False.
-    - text_depth (str): Text extraction granularity ('document', 'page'). Default: 'document'.
+    - extract_text (bool): Enables text extraction. Default: False.
+    - text_depth (str): Text extraction granularity ('document', 'page'). Default: 'document'.
+        Note: this affects the granularity of text extraction and the associated metadata, i.e., 'page' will
+        extract text per page and you will get page-level metadata, while 'document' will extract text for the
+        entire document, so elements like page numbers will not be associated with individual text elements.
 \b
-- store: Stores any images extracted from documents.
-    Options:
-    - structured (bool): Flag to write extracted charts and tables to object store.
-    - images (bool): Flag to write extracted images to object store.
-    - store_method (str): Storage type ('minio', ). Required.
-\b
-- caption: Attempts to extract captions for images extracted from documents. Note: this is not generative, but rather a
-    simple extraction.
-    Options:
-        N/A
-\b
-- dedup: Identifies and optionally filters duplicate images in extraction.
+- filter: Identifies and optionally filters images above or below scale thresholds.
     Options:
-    - content_type (str): Content type to deduplicate ('image')
-    - filter (bool): When set to True, duplicates will be filtered, otherwise, an info message will be added.
+    - content_type (str): Content type to filter ('image').
+    - filter (bool): When set to True, filtered images will be excluded; otherwise, an info message will be added.
+    - max_aspect_ratio (Union[float, int]): Maximum allowable aspect ratio of extracted image.
+    - min_aspect_ratio (Union[float, int]): Minimum allowable aspect ratio of extracted image.
+    - min_size (int): Minimum allowable size of extracted image.
 \b
-- filter: Identifies and optionally filters images above or below scale thresholds.
+- split: Divides documents according to specified criteria.
     Options:
-    - content_type (str): Content type to deduplicate ('image')
-    - min_size: (Union[float, int]): Minimum allowable size of extracted image.
-    - max_aspect_ratio: (Union[float, int]): Maximum allowable aspect ratio of extracted image.
-    - min_aspect_ratio: (Union[float, int]): Minimum allowable aspect ratio of extracted image.
-    - filter (bool): When set to True, duplicates will be filtered, otherwise, an info message will be added.
+    - max_character_length (int): Maximum segment character count. No default.
+    - sentence_window_size (int): Sentence window size. No default.
+    - split_by (str): Criteria ('page', 'size', 'word', 'sentence'). No default.
+    - split_length (int): Segment length. No default.
+    - split_overlap (int): Segment overlap. No default.
 \b
-- embed: Computes embeddings on multimodal extractions.
+- store: Stores any images extracted from documents.
     Options:
-    - text (bool): Flag to create embeddings for text extractions. Optional.
-    - tables (bool): Flag to creae embeddings for table extractions. Optional.
-    - filter_errors (bool): Flag to filter embedding errors. Optional.
+    - images (bool): Flag to write extracted images to object store.
+    - structured (bool): Flag to write extracted charts and tables to object store.
+    - store_method (str): Storage type ('minio', ). Required.
 \b
 - vdb_upload: Uploads extraction embeddings to vector database.
 \b
diff --git a/client/src/nv_ingest_client/primitives/tasks/caption.py b/client/src/nv_ingest_client/primitives/tasks/caption.py
index 43db01bd..b2883193 100644
--- a/client/src/nv_ingest_client/primitives/tasks/caption.py
+++ b/client/src/nv_ingest_client/primitives/tasks/caption.py
@@ -7,7 +7,7 @@
 # pylint: disable=too-many-arguments

 import logging
-from typing import Dict
+from typing import Dict, Optional

 from pydantic import BaseModel

@@ -17,29 +17,56 @@


 class CaptionTaskSchema(BaseModel):
+    api_key: Optional[str] = None
+    endpoint_url: Optional[str] = None
+    prompt: Optional[str] = None
+
     class Config:
         extra = "forbid"


 class CaptionTask(Task):
     def __init__(
-        self,
+        self,
+        api_key: Optional[str] = None,
+        endpoint_url: Optional[str] = None,
+        prompt: Optional[str] = None,
     ) -> None:
         super().__init__()
+        self._api_key = api_key
+        self._endpoint_url = endpoint_url
+        self._prompt = prompt
+
     def __str__(self) -> str:
         """
         Returns a string with the object's config and run time state
         """
         info = ""
+        info += "Image Caption Task:\n"
+
+        if self._api_key:
+            info += "  api_key: [redacted]\n"
+        if self._endpoint_url:
+            info += f"  endpoint_url: {self._endpoint_url}\n"
+        if self._prompt:
+            info += f"  prompt: {self._prompt}\n"
+
         return info

     def to_dict(self) -> Dict:
         """
         Convert to a dict for submission to redis
         """
-        task_properties = {
-            "content_type": "image",
-        }
+        task_properties = {}
+
+        if self._api_key:
+            task_properties["api_key"] = self._api_key
+
+        if self._endpoint_url:
+            task_properties["endpoint_url"] = self._endpoint_url
+
+        if self._prompt:
+            task_properties["prompt"] = self._prompt

         return {"type": "caption", "task_properties": task_properties}
diff --git a/client/src/nv_ingest_client/primitives/tasks/extract.py b/client/src/nv_ingest_client/primitives/tasks/extract.py
index 36f43d04..9070b6fe 100644
--- a/client/src/nv_ingest_client/primitives/tasks/extract.py
+++ b/client/src/nv_ingest_client/primitives/tasks/extract.py
@@ -34,34 +34,46 @@
 ADOBE_CLIENT_SECRET = os.environ.get("ADOBE_CLIENT_SECRET", None)

 _DEFAULT_EXTRACTOR_MAP = {
-    "pdf": "pdfium",
+    "csv": "pandas",
     "docx": "python_docx",
-    "pptx": "python_pptx",
-    "html": "beautifulsoup",
-    "xml": "lxml",
     "excel": "openpyxl",
-    "csv": "pandas",
+    "html": "beautifulsoup",
+    "jpeg": "image",
+    "jpg": "image",
     "parquet": "pandas",
+    "pdf": "pdfium",
+    "png": "image",
+    "pptx": "python_pptx",
+    "svg": "image",
+    "tiff": "image",
+    "xml": "lxml",
 }

 _Type_Extract_Method_PDF = Literal[
-    "pdfium",
+    "adobe",
     "doughnut",
     "haystack",
+    "llama_parse",
+    "pdfium",
     "tika",
     "unstructured_io",
-    "llama_parse",
-    "adobe",
 ]

 _Type_Extract_Method_DOCX = Literal["python_docx", "haystack", "unstructured_local", "unstructured_service"]

 _Type_Extract_Method_PPTX = Literal["python_pptx", "haystack", "unstructured_local", "unstructured_service"]

+_Type_Extract_Method_Image = Literal["image"]
+
 _Type_Extract_Method_Map = {
-    "pdf": get_args(_Type_Extract_Method_PDF),
     "docx": get_args(_Type_Extract_Method_DOCX),
+    "jpeg": get_args(_Type_Extract_Method_Image),
+    "jpg": get_args(_Type_Extract_Method_Image),
+    "pdf": get_args(_Type_Extract_Method_PDF),
+    "png": get_args(_Type_Extract_Method_Image),
     "pptx": get_args(_Type_Extract_Method_PPTX),
+    "svg": get_args(_Type_Extract_Method_Image),
+    "tiff": get_args(_Type_Extract_Method_Image),
 }

 _Type_Extract_Tables_Method_PDF = Literal["yolox", "pdfium"]
@@ -77,12 +89,14 @@
 }
+
+

 class ExtractTaskSchema(BaseModel):
     document_type: str
     extract_method: str = None  # Initially allow None to set a smart default
-    extract_text: bool = (True,)
-    extract_images: bool = (True,)
-    extract_tables: bool = False
+    extract_text: bool = True
+    extract_images: bool = True
+    extract_tables: bool = True
     extract_tables_method: str = "yolox"
     extract_charts: Optional[bool] = None  # Initially allow None to set a smart default
     text_depth: str = "document"
diff --git a/client/src/nv_ingest_client/util/file_processing/extract.py b/client/src/nv_ingest_client/util/file_processing/extract.py
index 09d20584..97851481 100644
--- a/client/src/nv_ingest_client/util/file_processing/extract.py
+++ b/client/src/nv_ingest_client/util/file_processing/extract.py
@@ -21,16 +21,17 @@
 # Enums
 class DocumentTypeEnum(str, Enum):
-    pdf = "pdf"
-    txt = "text"
+    bmp = "bmp"
     docx = "docx"
-    pptx = "pptx"
+    html = "html"
     jpeg = "jpeg"
-    bmp = "bmp"
+    md = "md"
+    pdf = "pdf"
     png = "png"
+    pptx = "pptx"
     svg = "svg"
-    html = "html"
-    md = "md"
+    tiff = "tiff"
+    txt = "text"


 # Maps MIME types to DocumentTypeEnum
@@ -49,19 +50,20 @@ class DocumentTypeEnum(str, Enum):
 # Maps file extensions to DocumentTypeEnum
 EXTENSION_TO_DOCUMENT_TYPE = {
-    "pdf": DocumentTypeEnum.pdf,
-    "txt": DocumentTypeEnum.txt,
-    "docx": DocumentTypeEnum.docx,
-    "pptx": DocumentTypeEnum.pptx,
-    "jpg": DocumentTypeEnum.jpeg,
-    "jpeg": DocumentTypeEnum.jpeg,
     "bmp": DocumentTypeEnum.bmp,
-    "png": DocumentTypeEnum.png,
-    "svg": DocumentTypeEnum.svg,
+    "docx": DocumentTypeEnum.docx,
     "html": DocumentTypeEnum.html,
+    "jpeg": DocumentTypeEnum.jpeg,
+    "jpg": DocumentTypeEnum.jpeg,
+    "json": DocumentTypeEnum.txt,
     "md": DocumentTypeEnum.txt,
+    "pdf": DocumentTypeEnum.pdf,
+    "png": DocumentTypeEnum.png,
+    "pptx": DocumentTypeEnum.pptx,
     "sh": DocumentTypeEnum.txt,
-    "json": DocumentTypeEnum.txt,
+    "svg": DocumentTypeEnum.svg,
+    "tiff": DocumentTypeEnum.tiff,
+    "txt": DocumentTypeEnum.txt,
     # Add more as needed
 }
diff --git a/docker-compose.yaml b/docker-compose.yaml
index a8e2476e..29506958 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -117,7 +117,7 @@ services:
     runtime: nvidia

   nv-ingest-ms-runtime:
-    image: nvcr.io/ohlfw0olaadg/ea-participants/nv-ingest:24.10
+    image: nvcr.io/ohlfw0olaadg/ea-participants/nv-ingest:24.10.1
     build:
       context: ${NV_INGEST_ROOT:-.}
       dockerfile: "./Dockerfile"
@@ -157,6 +157,7 @@ services:
       - YOLOX_GRPC_ENDPOINT=yolox:8001
       - YOLOX_HTTP_ENDPOINT=http://yolox:8000/v1/infer
       - YOLOX_INFER_PROTOCOL=grpc
+      - VLM_CAPTION_ENDPOINT=https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions
     healthcheck:
       test: curl --fail http://nv-ingest-ms-runtime:7670/v1/health/ready || exit 1
       interval: 10s
diff --git a/requirements.txt b/requirements.txt
index 43c4a6ce..5362d5d5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
 aiohttp==3.9.4
 charset-normalizer
 click
+opencv-python
 dataclasses
 farm-haystack[ocr,inference,pdf,preprocessing,file-conversion]
 fastapi==0.109.1
@@ -35,3 +36,4 @@ tabulate
 torchvision==0.18.0
 unstructured-client==0.23.3
 uvicorn==0.24.0-post.1
+Wand==0.6.13
diff --git a/src/nv_ingest/extraction_workflows/image/__init__.py b/src/nv_ingest/extraction_workflows/image/__init__.py
new file mode 100644
index 00000000..7186f69d
--- /dev/null
+++ b/src/nv_ingest/extraction_workflows/image/__init__.py
@@ -0,0 +1,3 @@
+from .image_handlers import image_data_extractor as image
+
+__all__ = ["image"]
diff --git a/src/nv_ingest/extraction_workflows/image/image_handlers.py b/src/nv_ingest/extraction_workflows/image/image_handlers.py
new file mode 100644
index 00000000..810276f9
--- /dev/null
+++ b/src/nv_ingest/extraction_workflows/image/image_handlers.py
@@ -0,0 +1,447 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import traceback
+from datetime import datetime
+
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+from wand.image import Image as WandImage
+from PIL import Image
+import io
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+
+from nv_ingest.extraction_workflows.pdf.doughnut_utils import crop_image
+import nv_ingest.util.nim.yolox as yolox_utils
+from nv_ingest.schemas.image_extractor_schema import ImageExtractorSchema
+from nv_ingest.schemas.metadata_schema import AccessLevelEnum
+from nv_ingest.util.image_processing.transforms import numpy_to_base64
+from nv_ingest.util.nim.helpers import create_inference_client
+from nv_ingest.util.nim.helpers import perform_model_inference
+from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
+from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_base64
+from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata
+
+logger = logging.getLogger(__name__)
+
+YOLOX_MAX_BATCH_SIZE = 8
+YOLOX_MAX_WIDTH = 1536
+YOLOX_MAX_HEIGHT = 1536
+YOLOX_NUM_CLASSES = 3
+YOLOX_CONF_THRESHOLD = 0.01
+YOLOX_IOU_THRESHOLD = 0.5
+YOLOX_MIN_SCORE = 0.1
+YOLOX_FINAL_SCORE = 0.48
+
+RAW_FILE_FORMATS = ["jpeg", "jpg", "png", "tiff"]
+PREPROC_FILE_FORMATS = ["svg"]
+
+SUPPORTED_FILE_TYPES = RAW_FILE_FORMATS + PREPROC_FILE_FORMATS
+
+
+def load_and_preprocess_image(image_stream: io.BytesIO) -> np.ndarray:
+    """
+    Loads and preprocesses a JPEG, PNG, or TIFF image from a bytestream.
+
+    Parameters
+    ----------
+    image_stream : io.BytesIO
+        A bytestream of the image file.
+
+    Returns
+    -------
+    np.ndarray
+        Preprocessed image as a numpy array.
+    """
+    # Load image from the byte stream
+    processed_image = Image.open(image_stream).convert("RGB")
+
+    # Convert the image to a float32 numpy array
+    image_array = np.asarray(processed_image, dtype=np.float32)
+
+    return image_array
+
+
+def convert_svg_to_bitmap(image_stream: io.BytesIO) -> np.ndarray:
+    """
+    Converts an SVG image from a bytestream to a bitmap format.
+
+    Parameters
+    ----------
+    image_stream : io.BytesIO
+        A bytestream of the SVG file.
+
+    Returns
+    -------
+    np.ndarray
+        Preprocessed image as a numpy array in bitmap format.
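+
+    Examples
+    --------
+    A minimal sketch, assuming ``svg_bytes`` holds the raw bytes of an SVG file:
+
+    >>> image_array = convert_svg_to_bitmap(io.BytesIO(svg_bytes))
+    >>> image_array.dtype
+    dtype('float32')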
+ """ + # Convert SVG to PNG using Wand (ImageMagick) + with WandImage(blob=image_stream.read(), format="svg") as img: + img.format = "png" + png_data = img.make_blob() + + # Reload the PNG as a PIL Image + processed_image = Image.open(io.BytesIO(png_data)).convert("RGB") + + # Convert image to numpy array and normalize pixel values + image_array = np.asarray(processed_image, dtype=np.float32) + + return image_array + + +# TODO(Devin): Move to common file +def process_inference_results( + output_array: np.ndarray, + original_image_shapes: List[Tuple[int, int]], + num_classes: int, + conf_thresh: float, + iou_thresh: float, + min_score: float, + final_thresh: float, +): + """ + Process the model output to generate detection results and expand bounding boxes. + + Parameters + ---------- + output_array : np.ndarray + The raw output from the model inference. + original_image_shapes : List[Tuple[int, int]] + The shapes of the original images before resizing, used for scaling bounding boxes. + num_classes : int + The number of classes the model can detect. + conf_thresh : float + The confidence threshold for detecting objects. + iou_thresh : float + The Intersection Over Union (IoU) threshold for non-maximum suppression. + min_score : float + The minimum score for keeping a detection. + final_thresh: float + Threshold for keeping a bounding box applied after postprocessing. + + + Returns + ------- + List[dict] + A list of dictionaries, each containing processed detection results including expanded bounding boxes. + + Notes + ----- + This function applies non-maximum suppression to the model's output and scales the bounding boxes back to the + original image size. + + Examples + -------- + >>> output_array = np.random.rand(2, 100, 85) + >>> original_image_shapes = [(1536, 1536), (1536, 1536)] + >>> results = process_inference_results(output_array, original_image_shapes, 80, 0.5, 0.5, 0.1) + >>> len(results) + 2 + """ + pred = yolox_utils.postprocess_model_prediction( + output_array, num_classes, conf_thresh, iou_thresh, class_agnostic=True + ) + results = yolox_utils.postprocess_results(pred, original_image_shapes, min_score=min_score) + logger.debug(f"Number of results: {len(results)}") + logger.debug(f"Results: {results}") + + annotation_dicts = [yolox_utils.expand_chart_bboxes(annotation_dict) for annotation_dict in results] + inference_results = [] + + # Filter out bounding boxes below the final threshold + for annotation_dict in annotation_dicts: + new_dict = {} + if "table" in annotation_dict: + new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= final_thresh] + if "chart" in annotation_dict: + new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= final_thresh] + if "title" in annotation_dict: + new_dict["title"] = annotation_dict["title"] + inference_results.append(new_dict) + + return inference_results + + +def extract_table_and_chart_images( + annotation_dict: Dict[str, List[List[float]]], + original_image: np.ndarray, + page_idx: int, + tables_and_charts: List[Tuple[int, "CroppedImageWithContent"]], +) -> None: + """ + Handle the extraction of tables and charts from the inference results and run additional model inference. + + Parameters + ---------- + annotation_dict : dict of {str : list of list of float} + A dictionary containing detected objects and their bounding boxes. Keys should include "table" and "chart", + and each key's value should be a list of bounding boxes, with each bounding box represented as a list of floats. 
+    original_image : np.ndarray
+        The original image from which objects were detected, expected to be in RGB format with shape (H, W, 3).
+    page_idx : int
+        The index of the current page being processed.
+    tables_and_charts : list of tuple of (int, CroppedImageWithContent)
+        A list to which extracted tables and charts will be appended. Each item in the list is a tuple where the first
+        element is the page index, and the second is an instance of CroppedImageWithContent representing a cropped image
+        and associated metadata.
+
+    Returns
+    -------
+    None
+
+    Notes
+    -----
+    This function iterates over detected objects labeled as "table" or "chart". For each object, it crops the original
+    image according to the bounding box coordinates, then creates an instance of `CroppedImageWithContent` containing
+    the cropped image and metadata, and appends it to `tables_and_charts`.
+
+    Examples
+    --------
+    >>> annotation_dict = {"table": [[0.1, 0.1, 0.5, 0.5, 0.8]], "chart": [[0.6, 0.6, 0.9, 0.9, 0.9]]}
+    >>> original_image = np.random.rand(1536, 1536, 3)
+    >>> tables_and_charts = []
+    >>> extract_table_and_chart_images(annotation_dict, original_image, 0, tables_and_charts)
+    >>> len(tables_and_charts)
+    2
+    """
+
+    width, height, *_ = original_image.shape
+    for label in ["table", "chart"]:
+        if not annotation_dict or label not in annotation_dict:
+            continue
+
+        objects = annotation_dict[label]
+        for idx, bboxes in enumerate(objects):
+            *bbox, _ = bboxes
+            h1, w1, h2, w2 = np.array(bbox) * np.array([height, width, height, width])
+
+            base64_img = crop_image(original_image, (int(h1), int(w1), int(h2), int(w2)))
+
+            table_data = CroppedImageWithContent(
+                content="",
+                image=base64_img,
+                bbox=(int(w1), int(h1), int(w2), int(h2)),
+                max_width=width,
+                max_height=height,
+                type_string=label,
+            )
+            tables_and_charts.append((page_idx, table_data))
+
+
+def extract_tables_and_charts_from_image(
+    image: np.ndarray,
+    config: ImageExtractorSchema,
+    num_classes: int = YOLOX_NUM_CLASSES,
+    conf_thresh: float = YOLOX_CONF_THRESHOLD,
+    iou_thresh: float = YOLOX_IOU_THRESHOLD,
+    min_score: float = YOLOX_MIN_SCORE,
+    final_thresh: float = YOLOX_FINAL_SCORE,
+    trace_info: Optional[List] = None,
+) -> List[CroppedImageWithContent]:
+    """
+    Extract tables and charts from a single image using an ensemble of image-based models.
+
+    This function processes a single image to detect and extract tables and charts.
+    It uses a sequence of models hosted on different inference servers to achieve this.
+
+    Parameters
+    ----------
+    image : np.ndarray
+        A preprocessed image array for table and chart detection.
+    config : ImageExtractorSchema
+        Configuration for the inference client, including endpoint URLs and authentication.
+    num_classes : int, optional
+        The number of classes the model is trained to detect (default is 3).
+    conf_thresh : float, optional
+        The confidence threshold for detection (default is 0.01).
+    iou_thresh : float, optional
+        The Intersection Over Union (IoU) threshold for non-maximum suppression (default is 0.5).
+    min_score : float, optional
+        The minimum score threshold for considering a detection valid (default is 0.1).
+    final_thresh : float, optional
+        Threshold for keeping a bounding box, applied after postprocessing (default is 0.48).
+    trace_info : Optional[List], optional
+        Tracing information for logging or debugging purposes.
+
+    Returns
+    -------
+    List[CroppedImageWithContent]
+        A list of `CroppedImageWithContent` objects representing detected tables or charts,
+        each containing metadata about the detected region.
+    """
+    tables_and_charts = []
+
+    yolox_client = None
+    try:
+        yolox_client = create_inference_client(config.yolox_endpoints, config.auth_token)
+
+        input_image = yolox_utils.prepare_images_for_inference([image])
+        image_shape = image.shape
+
+        output_array = perform_model_inference(yolox_client, "yolox", input_image, trace_info=trace_info)
+
+        yolox_annotated_detections = process_inference_results(
+            output_array, [image_shape], num_classes, conf_thresh, iou_thresh, min_score, final_thresh
+        )
+
+        for annotation_dict in yolox_annotated_detections:
+            extract_table_and_chart_images(
+                annotation_dict,
+                image,
+                page_idx=0,  # Single image treated as one page
+                tables_and_charts=tables_and_charts,
+            )
+
+    except Exception as e:
+        logger.error(f"Error during table/chart extraction from image: {str(e)}")
+        traceback.print_exc()
+        raise
+    finally:
+        if isinstance(yolox_client, grpcclient.InferenceServerClient):
+            logger.debug("Closing YOLOX inference client.")
+            yolox_client.close()
+
+    logger.debug(f"Extracted {len(tables_and_charts)} tables and charts from image.")
+
+    return tables_and_charts
+
+
+def image_data_extractor(image_stream,
+                         document_type: str,
+                         extract_text: bool,
+                         extract_images: bool,
+                         extract_tables: bool,
+                         extract_charts: bool,
+                         trace_info: dict = None,
+                         **kwargs):
+    """
+    Helper function to extract text, images, tables, and charts from an image bytestream.
+
+    Parameters
+    ----------
+    image_stream : io.BytesIO
+        A bytestream for the image file.
+    document_type : str
+        Specifies the type of the image document ('png', 'jpeg', 'jpg', 'svg', 'tiff').
+    extract_text : bool
+        Specifies whether to extract text.
+    extract_images : bool
+        Specifies whether to extract images.
+    extract_tables : bool
+        Specifies whether to extract tables.
+    extract_charts : bool
+        Specifies whether to extract charts.
+    trace_info : dict, optional
+        Tracing information for logging or debugging purposes.
+    **kwargs
+        Additional extraction parameters.
+
+    Returns
+    -------
+    list
+        A list of extracted data items.
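+
+    Examples
+    --------
+    A minimal sketch, assuming ``png_bytes`` holds raw PNG data; ``row_data`` here is a
+    hypothetical mapping supplying the source fields this helper reads:
+
+    >>> items = image_data_extractor(
+    ...     io.BytesIO(png_bytes),
+    ...     "png",
+    ...     extract_text=False,
+    ...     extract_images=True,
+    ...     extract_tables=False,
+    ...     extract_charts=False,
+    ...     row_data={"source_id": "doc-1"},
+    ... )
+    >>> len(items)
+    1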
+ """ + logger.debug(f"Extracting {document_type.upper()} image with image extractor.") + + if (document_type not in SUPPORTED_FILE_TYPES): + raise ValueError(f"Unsupported document type: {document_type}") + + row_data = kwargs.get("row_data") + source_id = row_data.get("source_id", "unknown_source") + + # Metadata extraction setup + base_unified_metadata = row_data.get(kwargs.get("metadata_column", "metadata"), {}) + current_iso_datetime = datetime.now().isoformat() + source_metadata = { + "source_name": f"{source_id}_{document_type}", + "source_id": source_id, + "source_location": row_data.get("source_location", ""), + "source_type": document_type, + "collection_id": row_data.get("collection_id", ""), + "date_created": row_data.get("date_created", current_iso_datetime), + "last_modified": row_data.get("last_modified", current_iso_datetime), + "summary": f"Raw {document_type} image extracted from source {source_id}", + "partition_id": row_data.get("partition_id", -1), + "access_level": row_data.get("access_level", AccessLevelEnum.LEVEL_1), + } + + # Prepare for extraction + extracted_data = [] + logger.debug(f"Extract text: {extract_text} (not supported yet for raw images)") + logger.debug(f"Extract images: {extract_images} (not supported yet for raw images)") + logger.debug(f"Extract tables: {extract_tables}") + logger.debug(f"Extract charts: {extract_charts}") + + # Preprocess based on image type + if (document_type in RAW_FILE_FORMATS): + logger.debug(f"Loading and preprocessing {document_type} image.") + image_array = load_and_preprocess_image(image_stream) + elif (document_type in PREPROC_FILE_FORMATS): + logger.debug(f"Converting {document_type} to bitmap.") + image_array = convert_svg_to_bitmap(image_stream) + else: + raise ValueError(f"Unsupported document type: {document_type}") + + # Text extraction stub + if extract_text: + # Future function for text extraction based on document_type + logger.warning("Text extraction is not supported for raw images.") + + # Image extraction stub + if extract_images: + # Placeholder for image-specific extraction process + extracted_data.append( + construct_image_metadata_from_base64( + numpy_to_base64(image_array), + page_idx=0, # Single image treated as one page + page_count=1, + source_metadata=source_metadata, + base_unified_metadata=base_unified_metadata, + ) + ) + + # Table and chart extraction + if extract_tables or extract_charts: + try: + tables_and_charts = extract_tables_and_charts_from_image( + image_array, + config=kwargs.get("image_extraction_config"), + trace_info=trace_info, + ) + logger.debug(f"Extracted table/chart data from image") + for _, table_chart_data in tables_and_charts: + extracted_data.append( + construct_table_and_chart_metadata( + table_chart_data, + page_idx=0, # Single image treated as one page + page_count=1, + source_metadata=source_metadata, + base_unified_metadata=base_unified_metadata, + ) + ) + except Exception as e: + logger.error(f"Error extracting tables/charts from image: {e}") + + logger.debug(f"Extracted {len(extracted_data)} items from the image.") + + return extracted_data diff --git a/src/nv_ingest/extraction_workflows/pdf/__init__.py b/src/nv_ingest/extraction_workflows/pdf/__init__.py index 51ecd39d..ff752ef5 100644 --- a/src/nv_ingest/extraction_workflows/pdf/__init__.py +++ b/src/nv_ingest/extraction_workflows/pdf/__init__.py @@ -6,7 +6,7 @@ from nv_ingest.extraction_workflows.pdf.adobe_helper import adobe from nv_ingest.extraction_workflows.pdf.doughnut_helper import doughnut from 
 from nv_ingest.extraction_workflows.pdf.llama_parse_helper import llama_parse
-from nv_ingest.extraction_workflows.pdf.pdfium_helper import pdfium
+from nv_ingest.extraction_workflows.pdf.pdfium_helper import pdfium_extractor as pdfium
 from nv_ingest.extraction_workflows.pdf.tika_helper import tika
 from nv_ingest.extraction_workflows.pdf.unstructured_io_helper import unstructured_io
diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
index b0239e93..bed95c97 100644
--- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
@@ -41,7 +41,7 @@
 from nv_ingest.util.image_processing.transforms import numpy_to_base64
 from nv_ingest.util.pdf.metadata_aggregators import Base64Image
 from nv_ingest.util.pdf.metadata_aggregators import LatexTable
-from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata
+from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
 from nv_ingest.util.pdf.metadata_aggregators import construct_text_metadata
 from nv_ingest.util.pdf.metadata_aggregators import extract_pdf_metadata
 from nv_ingest.util.pdf.pdfium import pdfium_pages_to_numpy
@@ -221,7 +221,7 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table
     if extract_images:
         for image in accumulated_images:
             extracted_data.append(
-                construct_image_metadata(
+                construct_image_metadata_from_pdf_image(
                     image,
                     page_idx,
                     pdf_metadata.page_count,
diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
index 7a1de0f1..9216c12f 100644
--- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
@@ -27,8 +27,8 @@
 import numpy as np
 import pypdfium2 as libpdfium
 import tritonclient.grpc as grpcclient
+import nv_ingest.util.nim.yolox as yolox_utils

-from nv_ingest.extraction_workflows.pdf import yolox_utils
 from nv_ingest.schemas.metadata_schema import AccessLevelEnum
 from nv_ingest.schemas.metadata_schema import TextTypeEnum
 from nv_ingest.schemas.pdf_extractor_schema import PDFiumConfigSchema
@@ -36,11 +36,12 @@
 from nv_ingest.util.image_processing.transforms import numpy_to_base64
 from nv_ingest.util.nim.helpers import create_inference_client
 from nv_ingest.util.nim.helpers import perform_model_inference
+from nv_ingest.util.nim.yolox import prepare_images_for_inference
 from nv_ingest.util.pdf.metadata_aggregators import Base64Image
-from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
-from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata
+from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
 from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata
 from nv_ingest.util.pdf.metadata_aggregators import construct_text_metadata
+from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
 from nv_ingest.util.pdf.metadata_aggregators import extract_pdf_metadata
 from nv_ingest.util.pdf.pdfium import PDFIUM_PAGEOBJ_MAPPING
 from nv_ingest.util.pdf.pdfium import pdfium_pages_to_numpy
@@ -179,37 +180,6 @@ def extract_tables_and_charts_using_image_ensemble(
     return tables_and_charts


-def prepare_images_for_inference(images: List[np.ndarray]) -> np.ndarray:
-    """
-    Prepare a list of images for model inference by resizing and reordering axes.
-
-    Parameters
-    ----------
-    images : List[np.ndarray]
-        A list of image arrays to be prepared for inference.
-
-    Returns
-    -------
-    np.ndarray
-        A numpy array suitable for model input, with the shape reordered to match the expected input format.
-
-    Notes
-    -----
-    The images are resized to 1024x1024 pixels and the axes are reordered to match the expected input shape for
-    the model (batch, channels, height, width).
-
-    Examples
-    --------
-    >>> images = [np.random.rand(1536, 1536, 3) for _ in range(2)]
-    >>> input_array = prepare_images_for_inference(images)
-    >>> input_array.shape
-    (2, 3, 1024, 1024)
-    """
-
-    resized_images = [yolox_utils.resize_image(image, (1024, 1024)) for image in images]
-    return np.einsum("bijk->bkij", resized_images).astype(np.float32)
-
-
 def process_inference_results(
     output_array: np.ndarray,
     original_image_shapes: List[Tuple[int, int]],
@@ -337,7 +307,7 @@ def extract_table_and_chart_images(

 # Define a helper function to use unstructured-io to extract text from a base64
 # encoded bytestream PDF
-def pdfium(
+def pdfium_extractor(
     pdf_stream,
     extract_text: bool,
     extract_images: bool,
@@ -453,7 +423,7 @@ def pdfium(
             if obj_type == "IMAGE":
                 try:
                     # Attempt to retrieve the image bitmap
-                    image_numpy: np.ndarray = pdfium_try_get_bitmap_as_numpy(obj) # noqa
+                    image_numpy: np.ndarray = pdfium_try_get_bitmap_as_numpy(obj)  # noqa
                     image_base64: str = numpy_to_base64(image_numpy)
                     image_bbox = obj.get_pos()
                     image_size = obj.get_size()
@@ -462,7 +432,7 @@
                         max_width=page_width,
                         max_height=page_height
                     )
-                    extracted_image_data = construct_image_metadata(
+                    extracted_image_data = construct_image_metadata_from_pdf_image(
                         image_data,
                         page_idx,
                         pdf_metadata.page_count,
diff --git a/src/nv_ingest/schemas/image_caption_extraction_schema.py b/src/nv_ingest/schemas/image_caption_extraction_schema.py
index 10584eed..4f6cdb3f 100644
--- a/src/nv_ingest/schemas/image_caption_extraction_schema.py
+++ b/src/nv_ingest/schemas/image_caption_extraction_schema.py
@@ -7,9 +7,9 @@


 class ImageCaptionExtractionSchema(BaseModel):
-    batch_size: int = 8
-    caption_classifier_model_name: str = "deberta_large"
-    endpoint_url: str = "triton:8001"
+    api_key: str
+    endpoint_url: str = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions"
+    prompt: str = "Caption the content of this image:"
     raise_on_failure: bool = False

     class Config:
diff --git a/src/nv_ingest/schemas/image_extractor_schema.py b/src/nv_ingest/schemas/image_extractor_schema.py
new file mode 100644
index 00000000..21e6b4c6
--- /dev/null
+++ b/src/nv_ingest/schemas/image_extractor_schema.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+from typing import Optional
+from typing import Tuple
+
+from pydantic import BaseModel
+from pydantic import root_validator
+
+logger = logging.getLogger(__name__)
+
+
+class ImageConfigSchema(BaseModel):
+    """
+    Configuration schema for image extraction endpoints and options.
+
+    Parameters
+    ----------
+    auth_token : Optional[str], default=None
+        Authentication token required for secure services.
+
+    yolox_endpoints : Tuple[str, str]
+        A tuple containing the gRPC and HTTP services for the yolox endpoint.
+        Either the gRPC or HTTP service can be empty, but not both.
+
+    Methods
+    -------
+    validate_endpoints(values)
+        Validates that at least one of the gRPC or HTTP services is provided for each endpoint.
+
+    Raises
+    ------
+    ValueError
+        If both gRPC and HTTP services are empty for any endpoint.
+
+    Config
+    ------
+    extra : str
+        Pydantic config option to forbid extra fields.
+    """
+
+    auth_token: Optional[str] = None
+
+    yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+    yolox_infer_protocol: str = ""
+
+    @root_validator(pre=True)
+    def validate_endpoints(cls, values):
+        """
+        Validates the gRPC and HTTP services for all endpoints.
+
+        Parameters
+        ----------
+        values : dict
+            Dictionary containing the values of the attributes for the class.
+
+        Returns
+        -------
+        dict
+            The validated dictionary of values.
+
+        Raises
+        ------
+        ValueError
+            If both gRPC and HTTP services are empty for any endpoint.
+        """
+
+        def clean_service(service):
+            """Set service to None if it's an empty string or contains only spaces or quotes."""
+            if service is None or not service.strip() or service.strip(" \"'") == "":
+                return None
+            return service
+
+        for model_name in ["yolox"]:
+            endpoint_name = f"{model_name}_endpoints"
+            grpc_service, http_service = values.get(endpoint_name, (None, None))
+            grpc_service = clean_service(grpc_service)
+            http_service = clean_service(http_service)
+
+            if not grpc_service and not http_service:
+                raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.")
+
+            values[endpoint_name] = (grpc_service, http_service)
+
+            protocol_name = f"{model_name}_infer_protocol"
+            protocol_value = values.get(protocol_name)
+            if not protocol_value:
+                protocol_value = "http" if http_service else "grpc" if grpc_service else ""
+            protocol_value = protocol_value.lower()
+            values[protocol_name] = protocol_value
+
+        return values
+
+    class Config:
+        extra = "forbid"
+
+
+class ImageExtractorSchema(BaseModel):
+    """
+    Configuration schema for the image extractor settings.
+
+    Parameters
+    ----------
+    max_queue_size : int, default=1
+        The maximum number of items allowed in the processing queue.
+
+    n_workers : int, default=16
+        The number of worker threads to use for processing.
+
+    raise_on_failure : bool, default=False
+        A flag indicating whether to raise an exception on processing failure.
+
+    image_extraction_config : Optional[ImageConfigSchema], default=None
+        Configuration schema for the image extraction stage.
+    """
+
+    max_queue_size: int = 1
+    n_workers: int = 16
+    raise_on_failure: bool = False
+
+    image_extraction_config: Optional[ImageConfigSchema] = None
+
+    class Config:
+        extra = "forbid"
diff --git a/src/nv_ingest/schemas/ingest_job_schema.py b/src/nv_ingest/schemas/ingest_job_schema.py
index 8b2dc7ef..9d001202 100644
--- a/src/nv_ingest/schemas/ingest_job_schema.py
+++ b/src/nv_ingest/schemas/ingest_job_schema.py
@@ -24,15 +24,16 @@
 # Enums
 class DocumentTypeEnum(str, Enum):
-    pdf = "pdf"
-    txt = "text"
+    bmp = "bmp"
     docx = "docx"
-    pptx = "pptx"
+    html = "html"
     jpeg = "jpeg"
-    bmp = "bmp"
+    pdf = "pdf"
     png = "png"
+    pptx = "pptx"
     svg = "svg"
-    html = "html"
+    tiff = "tiff"
+    txt = "text"


 class TaskTypeEnum(str, Enum):
@@ -94,9 +95,11 @@ class IngestTaskStoreSchema(BaseModelNoExt):
     params: dict


+# All optional; the captioning stage has its own defaults, and each of these simply overrides them.
 class IngestTaskCaptionSchema(BaseModelNoExt):
-    content_type: str = "image"
-    n_neighbors: int = 5
+    api_key: Optional[str] = None
+    endpoint_url: Optional[str] = None
+    prompt: Optional[str] = None


 class IngestTaskFilterParamsSchema(BaseModelNoExt):
@@ -133,9 +136,11 @@ class IngestTaskVdbUploadSchema(BaseModelNoExt):
 class IngestTaskTableExtraction(BaseModelNoExt):
     params: Dict = {}

+
 class IngestChartTableExtraction(BaseModelNoExt):
     params: Dict = {}

+
 class IngestTaskSchema(BaseModelNoExt):
     type: TaskTypeEnum
     task_properties: Union[
diff --git a/src/nv_ingest/schemas/metadata_schema.py b/src/nv_ingest/schemas/metadata_schema.py
index 6a51c83d..97a64ae6 100644
--- a/src/nv_ingest/schemas/metadata_schema.py
+++ b/src/nv_ingest/schemas/metadata_schema.py
@@ -40,32 +40,33 @@ class ContentTypeEnum(str, Enum):
     TEXT = "text"
     IMAGE = "image"
     STRUCTURED = "structured"
+    UNSTRUCTURED = "unstructured"
     INFO_MSG = "info_message"


 class StdContentDescEnum(str, Enum):
-    PDF_TEXT = "Unstructured text from PDF document."
-    PDF_IMAGE = "Image extracted from PDF document."
-    PDF_TABLE = "Structured table extracted from PDF document."
-    PDF_CHART = "Structured chart extracted from PDF document."
-    DOCX_TEXT = "Unstructured text from DOCX document."
     DOCX_IMAGE = "Image extracted from DOCX document."
     DOCX_TABLE = "Structured table extracted from DOCX document."
-    PPTX_TEXT = "Unstructured text from PPTX presentation."
+    DOCX_TEXT = "Unstructured text from DOCX document."
+    PDF_CHART = "Structured chart extracted from PDF document."
+    PDF_IMAGE = "Image extracted from PDF document."
+    PDF_TABLE = "Structured table extracted from PDF document."
+    PDF_TEXT = "Unstructured text from PDF document."
     PPTX_IMAGE = "Image extracted from PPTX presentation."
     PPTX_TABLE = "Structured table extracted from PPTX presentation."
+    PPTX_TEXT = "Unstructured text from PPTX presentation."

 class TextTypeEnum(str, Enum):
-    HEADER = "header"
-    BODY = "body"
-    SPAN = "span"
-    LINE = "line"
     BLOCK = "block"
-    PAGE = "page"
+    BODY = "body"
     DOCUMENT = "document"
+    HEADER = "header"
+    LINE = "line"
     NEARBY_BLOCK = "nearby_block"
     OTHER = "other"
+    PAGE = "page"
+    SPAN = "span"


 class LanguageEnum(str, Enum):
@@ -132,10 +133,10 @@ def has_value(cls, value):

 class ImageTypeEnum(str, Enum):
-    JPEG = "jpeg"
-    PNG = "png"
     BMP = "bmp"
     GIF = "gif"
+    JPEG = "jpeg"
+    PNG = "png"
     TIFF = "tiff"

     image_type_1 = "image_type_1"  # until classifier developed
@@ -148,9 +149,9 @@ def has_value(cls, value):

 class TableFormatEnum(str, Enum):
     HTML = "html"
-    MARKDOWN = "markdown"
-    LATEX = "latex"
     IMAGE = "image"
+    LATEX = "latex"
+    MARKDOWN = "markdown"


 class TaskTypeEnum(str, Enum):
@@ -249,13 +250,6 @@ class TextMetadataSchema(BaseModelNoExt):
     text_location: tuple = (0, 0, 0, 0)


-import logging
-from pydantic import validator
-
-# Set up logging
-logger = logging.getLogger(__name__)
-
-
 class ImageMetadataSchema(BaseModelNoExt):
     image_type: Union[ImageTypeEnum, str]
     structured_image_type: ImageTypeEnum = ImageTypeEnum.image_type_1
@@ -317,6 +311,7 @@ class InfoMessageMetadataSchema(BaseModelNoExt):
 # Main metadata schema
 class MetadataSchema(BaseModelNoExt):
     content: str = ""
+    content_url: str = ""
     embedding: Optional[List[float]] = None
     source_metadata: Optional[SourceMetadataSchema] = None
     content_metadata: Optional[ContentMetadataSchema] = None
diff --git a/tests/stages/__init__.py b/src/nv_ingest/stages/extractors/__init__.py
similarity index 100%
rename from tests/stages/__init__.py
rename to src/nv_ingest/stages/extractors/__init__.py
diff --git a/src/nv_ingest/stages/extractors/image_extractor_stage.py b/src/nv_ingest/stages/extractors/image_extractor_stage.py
new file mode 100644
index 00000000..85afa4c2
--- /dev/null
+++ b/src/nv_ingest/stages/extractors/image_extractor_stage.py
@@ -0,0 +1,210 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import base64
+import functools
+import io
+import logging
+import traceback
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import pandas as pd
+import nv_ingest.extraction_workflows.image as image_helpers
+from morpheus.config import Config
+from nv_ingest.schemas.image_extractor_schema import ImageExtractorSchema
+from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
+
+logger = logging.getLogger(f"morpheus.{__name__}")
+
+
+def decode_and_extract(
+    base64_row: pd.Series,
+    task_props: Dict[str, Any],
+    validated_config: Any,
+    default: str = "image",
+    trace_info: Optional[List] = None,
+) -> Any:
+    """
+    Decodes base64 content from a row and extracts data from it using the specified extraction method.
+
+    Parameters
+    ----------
+    base64_row : pd.Series
+        A series containing the base64-encoded content and other relevant data.
+        The key "content" should contain the base64 string, and the key "source_id" is optional.
+    task_props : dict
+        A dictionary containing task properties. It should have the keys:
+        - "method" (str): The extraction method to use (e.g., "image").
+        - "params" (dict): Parameters to pass to the extraction function.
+    validated_config : Any
+        Configuration object that contains `image_extraction_config`. Used if the `image` method is selected.
+    default : str, optional
+        The default extraction method to use if the specified method in `task_props` is not available
+        (default is "image").
+
+    Returns
+    -------
+    Any
+        The extracted data from the decoded content. The exact return type depends on the extraction method used.
+
+    Raises
+    ------
+    KeyError
+        If the "content" key is missing from `base64_row`.
+    Exception
+        For any other unhandled exceptions during extraction, an error is logged, and the exception is re-raised.
+    """
+
+    document_type = base64_row["document_type"]
+    source_id = None
+    try:
+        base64_content = base64_row["content"]
+    except KeyError:
+        log_error_message = f"Unhandled error processing row, no content was found:\n{base64_row}"
+        logger.error(log_error_message)
+        raise
+
+    try:
+        # Row data to include in extraction
+        bool_index = base64_row.index.isin(("content",))
+        row_data = base64_row[~bool_index]
+        task_props["params"]["row_data"] = row_data
+
+        # Get source_id
+        source_id = base64_row["source_id"] if "source_id" in base64_row.index else None
+        # Decode the base64 content
+        image_bytes = base64.b64decode(base64_content)
+
+        # Load the image bytes into a stream
+        image_stream = io.BytesIO(image_bytes)
+
+        # Type of extraction method to use
+        extract_method = task_props.get("method", "image")
+        extract_params = task_props.get("params", {})
+
+        logger.debug(
+            f">>> Extracting image content, image_extraction_config: {validated_config.image_extraction_config}")
+        if validated_config.image_extraction_config is not None:
+            extract_params["image_extraction_config"] = validated_config.image_extraction_config
+
+        if trace_info is not None:
+            extract_params["trace_info"] = trace_info
+
+        if not hasattr(image_helpers, extract_method):
+            extract_method = default
+
+        func = getattr(image_helpers, extract_method)
+        logger.debug("Running extraction method: %s", extract_method)
+        extracted_data = func(image_stream, document_type, **extract_params)
+
+        return extracted_data
+
+    except Exception as e:
+        traceback.print_exc()
+        err_msg = f"Unhandled exception in decode_and_extract for '{source_id}':\n{e}"
+        logger.error(err_msg)
+
+        raise
+
+    # Propagate error back and tag message as failed.
+    # exception_tag = create_exception_tag(error_message=log_error_message, source_id=source_id)
+
+
+def process_image(
+    df: pd.DataFrame,
+    task_props: Dict[str, Any],
+    validated_config: Any,
+    trace_info: Optional[Dict[str, Any]] = None
+) -> Tuple[pd.DataFrame, Dict[str, Any]]:
+    """
+    Processes a pandas DataFrame containing image files in base64 encoding.
+    Each image's content is replaced with its extracted components.
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        The input DataFrame with columns 'source_id' and 'content' (base64-encoded image data).
+    task_props : dict
+        Dictionary containing instructions and parameters for the image processing task.
+    validated_config : Any
+        Configuration object validated for processing images.
+    trace_info : dict, optional
+        Dictionary for tracing and logging additional information during processing (default is None).
+
+    Returns
+    -------
+    Tuple[pd.DataFrame, Dict[str, Any]]
+        A tuple containing:
+        - A pandas DataFrame with the processed image content, including columns 'document_type', 'metadata',
+          and 'uuid'.
+        - A dictionary with trace information collected during processing.
+
+    Raises
+    ------
+    Exception
+        If an error occurs during the image processing stage.
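+
+    Examples
+    --------
+    A minimal sketch, assuming ``b64_png`` is a base64-encoded PNG string and
+    ``validated_config`` is a populated ``ImageExtractorSchema``:
+
+    >>> df = pd.DataFrame(
+    ...     {"source_id": ["img-1"], "content": [b64_png], "document_type": ["png"]}
+    ... )
+    >>> task_props = {"method": "image", "params": {"extract_text": False, "extract_images": True,
+    ...                                             "extract_tables": False, "extract_charts": False}}
+    >>> extracted_df, trace = process_image(df, task_props, validated_config)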
+ """ + logger.debug("Processing image content") + if trace_info is None: + trace_info = {} + + try: + # Apply the helper function to each row in the 'content' column + _decode_and_extract = functools.partial( + decode_and_extract, task_props=task_props, validated_config=validated_config, trace_info=trace_info + ) + logger.debug(f"Processing method: {task_props.get('method', None)}") + sr_extraction = df.apply(_decode_and_extract, axis=1) + sr_extraction = sr_extraction.explode().dropna() + + if not sr_extraction.empty: + extracted_df = pd.DataFrame(sr_extraction.to_list(), columns=["document_type", "metadata", "uuid"]) + else: + extracted_df = pd.DataFrame({"document_type": [], "metadata": [], "uuid": []}) + + return extracted_df, {"trace_info": trace_info} + + except Exception as e: + err_msg = f"Unhandled exception in image extractor stage's process_image: {e}" + logger.error(err_msg) + raise + + +def generate_image_extractor_stage( + c: Config, + extractor_config: Dict[str, Any], + task: str = "extract", + task_desc: str = "image_content_extractor", + pe_count: int = 24, +): + """ + Helper function to generate a multiprocessing stage to perform image content extraction. + + Parameters + ---------- + c : Config + Morpheus global configuration object + extractor_config : dict + Configuration parameters for pdf content extractor. + task : str + The task name to match for the stage worker function. + task_desc : str + A descriptor to be used in latency tracing. + pe_count : int + Integer for how many process engines to use for pdf content extraction. + + Returns + ------- + MultiProcessingBaseStage + A Morpheus stage with applied worker function. + """ + validated_config = ImageExtractorSchema(**extractor_config) + _wrapped_process_fn = functools.partial(process_image, validated_config=validated_config) + + return MultiProcessingBaseStage( + c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_process_fn, + document_type="regex:^(png|svg|jpeg|jpg|tiff)$" + ) diff --git a/src/nv_ingest/stages/multiprocessing_stage.py b/src/nv_ingest/stages/multiprocessing_stage.py index b3af02d6..c4c5d0bd 100644 --- a/src/nv_ingest/stages/multiprocessing_stage.py +++ b/src/nv_ingest/stages/multiprocessing_stage.py @@ -155,7 +155,7 @@ def __init__( task_desc: str, pe_count: int, process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame], - document_type: str = None, + document_type: typing.Union[typing.List[str],str] = None, filter_properties: dict = None, ): super().__init__(c) diff --git a/src/nv_ingest/stages/transforms/image_caption_extraction.py b/src/nv_ingest/stages/transforms/image_caption_extraction.py index d1b702bb..40b798fe 100644 --- a/src/nv_ingest/stages/transforms/image_caption_extraction.py +++ b/src/nv_ingest/stages/transforms/image_caption_extraction.py @@ -2,452 +2,161 @@ # All rights reserved. 
 # SPDX-License-Identifier: Apache-2.0

-
+import base64
+import io
 import logging
-import traceback
 from functools import partial
-from typing import Any
+from typing import Any, Optional
 from typing import Dict
-from typing import List
 from typing import Tuple

-import numpy as np
 import pandas as pd
-import tritonclient.grpc as grpcclient
+import requests
+from PIL import Image
 from morpheus.config import Config
-from morpheus.utils.module_utils import ModuleLoaderFactory
-from sklearn.neighbors import NearestNeighbors
-from transformers import AutoTokenizer

 from nv_ingest.schemas.image_caption_extraction_schema import ImageCaptionExtractionSchema
 from nv_ingest.schemas.metadata_schema import ContentTypeEnum
 from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
+from nv_ingest.util.image_processing.transforms import scale_image_to_encoding_size
+from nv_ingest.util.tracing.tagging import traceable_func

 logger = logging.getLogger(__name__)

 MODULE_NAME = "image_caption_extraction"
 MODULE_NAMESPACE = "nv_ingest"

-ImageCaptionExtractionLoaderFactory = ModuleLoaderFactory(MODULE_NAME, MODULE_NAMESPACE)
-
-
-def _extract_bboxes_and_content(data: Dict[str, Any]) -> Tuple[List[Tuple[int, int, int, int]], List[str]]:
-    """
-    Extract bounding boxes and associated content from a deeply nested data structure.
-
-    Parameters
-    ----------
-    data : Dict[str, Any]
-        A dictionary containing nested data from which bounding boxes and content are extracted.
-
-    Returns
-    -------
-    Tuple[List[Tuple[int, int, int, int]], List[str]]
-        A tuple containing two lists:
-        - First list of tuples representing bounding boxes (x1, y1, x2, y2).
-        - Second list of strings representing associated content.
-    """
-    nearby_objects = data["content_metadata"]["hierarchy"]["nearby_objects"]["text"]
-    bboxes = nearby_objects["bbox"]
-    content = nearby_objects["content"]
-    return bboxes, content
-
-
-def _calculate_centroids(bboxes: List[Tuple[int, int, int, int]]) -> List[Tuple[float, float]]:
-    """
-    Calculate centroids from bounding boxes.
-
-    Parameters
-    ----------
-    bboxes : List[Tuple[int, int, int, int]]
-        A list of tuples each representing a bounding box as (x1, y1, x2, y2).
-
-    Returns
-    -------
-    List[Tuple[float, float]]
-        A list of tuples each representing the centroid (x, y) of the corresponding bounding box.
-    """
-    return [(bbox[0] + (bbox[2] - bbox[0]) / 2, bbox[1] + (bbox[3] - bbox[1]) / 2) for bbox in bboxes]
-
-
-def _fit_nearest_neighbors(centroids: List[Tuple[float, float]], n_neighbors: int = 5) -> Tuple[NearestNeighbors, int]:
-    """
-    Fit the NearestNeighbors model to the centroids, ensuring the number of neighbors does not exceed available
-    centroids.
-
-    Parameters
-    ----------
-    centroids : List[Tuple[float, float]]
-        A list of tuples each representing the centroid coordinates (x, y) of bounding boxes.
-    n_neighbors : int, optional
-        The number of neighbors to use by default for kneighbors queries.
-
-    Returns
-    -------
-    Tuple[NearestNeighbors, int]
-        A tuple containing:
-        - NearestNeighbors model fitted to the centroids.
-        - The adjusted number of neighbors, which is the minimum of `n_neighbors` and the number of centroids.
- """ - centroids_array = np.array(centroids) - adjusted_n_neighbors = min(n_neighbors, len(centroids_array)) - nbrs = NearestNeighbors(n_neighbors=adjusted_n_neighbors, algorithm="auto", metric="euclidean") - nbrs.fit(centroids_array) - - return nbrs, adjusted_n_neighbors - - -def _find_nearest_neighbors( - nbrs: NearestNeighbors, new_bbox: Tuple[int, int, int, int], content: List[str], n_neighbors: int -) -> Tuple[np.ndarray, np.ndarray, List[str]]: - """ - Find the nearest neighbors for a new bounding box and return associated content. - - Parameters - ---------- - nbrs : NearestNeighbors - The trained NearestNeighbors model. - new_bbox : Tuple[int, int, int, int] - The bounding box for which to find the nearest neighbors, specified as (x1, y1, x2, y2). - content : List[str] - A list of content strings associated with each bounding box. - n_neighbors : int - The number of neighbors to retrieve. - - Returns - ------- - Tuple[np.ndarray, np.ndarray, List[str]] - A tuple containing: - - distances: An array of distances to the nearest neighbors. - - indices: An array of indices for the nearest neighbors. - - A list of content strings corresponding to the nearest neighbors. - """ - new_centroid = np.array( - [(new_bbox[0] + (new_bbox[2] - new_bbox[0]) / 2, new_bbox[1] + (new_bbox[3] - new_bbox[1]) / 2)] - ) - new_centroid_reshaped = new_centroid.reshape(1, -1) # Reshape to ensure 2D - distances, indices = nbrs.kneighbors(new_centroid_reshaped, n_neighbors=n_neighbors) - return distances, indices, [content[i] for i in indices.flatten()] - - -def _sanitize_inputs(inputs: List[List[str]]) -> List[List[str]]: - """ - Replace non-ASCII characters with '?' in inputs. - - Parameters - ---------- - inputs : List[List[str]] - A list of lists where each sub-list contains strings. - - Returns - ------- - List[List[str]] - A list of lists where each string has been sanitized to contain only ASCII characters, - with non-ASCII characters replaced by '?'. - """ - cleaned_inputs = [ - [candidate.encode("ascii", "replace").decode("ascii") for candidate in candidates] for candidates in inputs - ] - return cleaned_inputs - - -TOKENIZER_NAME = "microsoft/deberta-large" -tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME) - - -def _predict_caption(triton_url: str, caption_model: str, inputs: List[List[str]], n_candidates: int = 5) -> List[str]: - """ - Sends a request to a Triton inference server to generate captions based on provided inputs. - - Parameters - ---------- - triton_url : str - The URL of the Triton inference server. - headers : Dict[str, str] - HTTP headers to send with the request. - inputs : List[List[str]] - The input data for which captions are generated. - n_candidates : int, optional - The number of candidates per input data. Default is 5. - - Returns - ------- - List[str] - A list of generated captions, one for each input. 
- """ - - sanitized_inputs = _sanitize_inputs(inputs) - captions = [""] * len(sanitized_inputs) - max_batch_size = 128 - - try: - client = grpcclient.InferenceServerClient(url=triton_url) - - # Process inputs in batches - for batch_start in range(0, len(sanitized_inputs), max_batch_size // n_candidates): - batch_end = min(batch_start + (max_batch_size // n_candidates), len(sanitized_inputs)) - input_batch = sanitized_inputs[batch_start:batch_end] - flattened_sentences = [sentence for batch in input_batch for sentence in batch] - encoded_inputs = tokenizer( - flattened_sentences, max_length=128, padding="max_length", truncation=True, return_tensors="np" - ) - - input_ids = encoded_inputs["input_ids"].astype(np.int64) - attention_mask = encoded_inputs["attention_mask"].astype(np.float32) - infer_inputs = [ - grpcclient.InferInput("input_ids", input_ids.shape, "INT64"), - grpcclient.InferInput("input_mask", attention_mask.shape, "FP32"), - ] - - # Set the data for the input tensors - infer_inputs[0].set_data_from_numpy(input_ids) - infer_inputs[1].set_data_from_numpy(attention_mask) - - outputs = [grpcclient.InferRequestedOutput("output")] - - # Perform inference - response = client.infer(model_name=caption_model, inputs=infer_inputs, outputs=outputs) - - output_data = response.as_numpy("output") - - # Process the output to find the best sentence in each batch - batch_size = n_candidates - for i, batch in enumerate(input_batch): - start_idx = i * batch_size - end_idx = (i + 1) * batch_size - batch_output = output_data[start_idx:end_idx] - max_index = np.argmax(batch_output) - best_sentence = batch[max_index] - best_probability = batch_output[max_index] - if best_probability > 0.5: - captions[batch_start + i] = best_sentence - - except Exception as e: - logging.error(f"An error occurred: {e}") - - return captions - def _prepare_dataframes_mod(df) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]: if df.empty or "document_type" not in df.columns: return df, pd.DataFrame(), pd.Series(dtype=bool) bool_index = df["document_type"] == ContentTypeEnum.IMAGE - df_filtered = df.loc[bool_index] - - return df, df_filtered, bool_index - - -def _process_documents(df_filtered: pd.DataFrame) -> Tuple[List[Any], List[List[str]]]: - """ - Processes documents to extract content and bounding boxes, then finds nearest neighbors. - - Parameters - ---------- - df_filtered : pd.DataFrame - The dataframe filtered to contain only relevant documents. - - Returns - ------- - Tuple[List[Any], List[List[str]]] - A tuple containing metadata for each document and a list of neighbor content. - """ - neighbor_content = [] - metadata_list = [] - for _, row in df_filtered.iterrows(): - metadata = row.metadata - metadata_list.append(metadata) - bboxes, content = _extract_bboxes_and_content(metadata) - _process_content(bboxes, content, metadata, neighbor_content) - - return metadata_list, neighbor_content - - -def _process_content( - bboxes: List[Tuple[int, int, int, int]], - content: List[str], - metadata: Dict[str, Any], - neighbor_content: List[List[str]], - n_neighbors: int = 5, -) -> None: - """ - Process content by finding nearest neighbors and appending the results. - - Parameters - ---------- - bboxes : List[Tuple[int, int, int, int]] - A list of bounding boxes. - content : List[str] - Content associated with each bounding box. - metadata : Dict[str, Any] - Metadata associated with each content piece, containing image metadata. 
-    neighbor_content : List[List[str]]
-        A list that will be appended with the nearest neighbor content.
-    n_neighbors : int, optional
-        The number of nearest neighbors to find (default is 5).
-
-    Returns
-    -------
-    None
-    """
-    if bboxes and content:
-        centroids = _calculate_centroids(bboxes)
-        nn_mod, adj_neighbors = _fit_nearest_neighbors(centroids)
-        image_bbox = metadata["image_metadata"]["image_location"]
-        distances, indices, nearest_content = _find_nearest_neighbors(nn_mod, image_bbox, content, adj_neighbors)
-    else:
-        nearest_content = []
-
-    if len(nearest_content) < n_neighbors:
-        nearest_content.extend([""] * (n_neighbors - len(nearest_content)))
+    df_matched = df.loc[bool_index]

-    neighbor_content.append(nearest_content)
+    return df, df_matched, bool_index


-def _generate_captions(neighbor_content: List[List[str]], config: Any) -> List[str]:
+def _generate_captions(base64_image: str, prompt: str, api_key: str, endpoint_url: str) -> str:
     """
-    Generate captions for provided content using a Triton inference server.
+    Sends a base64-encoded PNG image to the NVIDIA LLaMA model API and retrieves the generated caption.

     Parameters
     ----------
-    neighbor_content : List[List[str]]
-        A list of content batches for which to generate captions.
-    config : Any
-        Configuration object containing endpoint URL, headers, and batch size.
+    base64_image : str
+        Base64-encoded PNG image string.
+    prompt : str
+        Prompt text sent to the captioning model alongside the image.
+    api_key : str
+        API key for authentication with the NVIDIA model endpoint.
+    endpoint_url : str
+        URL of the model endpoint that serves the captioning request.

     Returns
     -------
-    List[str]
-        A list of generated captions.
+    str
+        Generated caption for the image or an error message.
     """
-    captions = []
-    for i in range(0, len(neighbor_content), config.batch_size):
-        batch = neighbor_content[i : i + config.batch_size]  # noqa: E203
-        batch_captions = _predict_caption(config.endpoint_url, config.caption_classifier_model_name, batch)
-        captions.extend(batch_captions)
-
-    return captions
-
+    stream = False  # Set to False for non-streaming response

-def _update_metadata_with_captions(
-    metadata_list: List[Dict[str, Any]], captions: List[str], df_filtered: pd.DataFrame
-) -> List[Dict[str, Any]]:
-    """
-    Update metadata with captions and compile into a list of image document dictionaries.
+    # Ensure the base64 image size is within acceptable limits
+    base64_image = scale_image_to_encoding_size(base64_image)

-    Parameters
-    ----------
-    metadata_list : List[Dict[str, Any]]
-        A list of metadata dictionaries.
-    captions : List[str]
-        A list of captions corresponding to the metadata.
-    df_filtered : pd.DataFrame
-        The filtered DataFrame containing document UUIDs.
+    headers = {
+        "Authorization": f"Bearer {api_key}",
+        "Accept": "application/json"
+    }

-    Returns
-    -------
-    List[Dict[str, Any]]
-        A list of dictionaries each containing updated document metadata and type.
- """ - image_docs = [] - for metadata, caption, (_, row) in zip(metadata_list, captions, df_filtered.iterrows()): - metadata["image_metadata"]["caption"] = caption - image_docs.append( + # Payload for the request + payload = { + "model": 'meta/llama-3.2-90b-vision-instruct', + "messages": [ { - "document_type": ContentTypeEnum.IMAGE.value, - "metadata": metadata, - "uuid": row.get("uuid"), + "role": "user", + "content": f'{prompt} ' } - ) - - return image_docs - - -def _prepare_final_dataframe_mod(df: pd.DataFrame, image_docs: List[Dict[str, Any]], filter_index: pd.Series) -> None: - """ - Prepares the final dataframe by combining original dataframe with new image document data, converting to GPU - dataframe, and updating the message with the new dataframe. - - Parameters - ---------- - df : pd.DataFrame - The original dataframe. - image_docs : List[Dict[str, Any]] - A list of dictionaries containing image document data. - filter_index : pd.Series - A boolean series that filters the dataframe. - - Returns - ------- - None - """ - - image_docs_df = pd.DataFrame(image_docs) - docs_df = pd.concat([df[~filter_index], image_docs_df], axis=0).reset_index(drop=True) - - return docs_df + ], + "max_tokens": 512, + "temperature": 1.00, + "top_p": 1.00, + "stream": stream + } - -def caption_extract_stage(df, task_props, validated_config) -> pd.DataFrame: - """ - Extracts captions from images within the provided dataframe. + try: + response = requests.post(endpoint_url, headers=headers, json=payload) + response.raise_for_status() # Raise an exception for HTTP errors + + if stream: + result = [] + for line in response.iter_lines(): + if line: + result.append(line.decode("utf-8")) + return "\n".join(result) + else: + response_data = response.json() + return response_data.get('choices', [{}])[0].get('message', {}).get('content', 'No caption returned') + except requests.exceptions.RequestException as e: + logger.error(f"Error generating caption: {e}") + raise + + +def caption_extract_stage(df: pd.DataFrame, + task_props: Dict[str, Any], + validated_config: Any, + trace_info: Optional[Dict[str, Any]] = None + ) -> pd.DataFrame: + """ + Extracts captions for image content in the DataFrame using an external NVIDIA API. + Updates the 'metadata' column by adding the generated captions under 'image_metadata.caption'. Parameters ---------- df : pd.DataFrame - The dataframe containing image data. - task_props : dict - Task properties required for processing. - validated_config : ImageCaptionExtractionSchema - Validated configuration for caption extraction. + The input DataFrame containing image data in 'metadata.content'. + validated_config : Any + A configuration schema object containing settings for caption extraction. Returns ------- pd.DataFrame - The dataframe with updated image captions. + The updated DataFrame with generated captions in the 'metadata' column's 'image_metadata.caption' field. Raises ------ - ValueError - If an error occurs during caption extraction. + Exception + If there is an error during the caption extraction process. 
""" - try: - logger.debug("Performing caption extraction") - - # Data preparation and filtering - df, df_filtered, filter_index = _prepare_dataframes_mod(df) - if df_filtered.empty: - return df + logger.debug("Attempting to caption image content") - # Process each image document - metadata_list, neighbor_content = _process_documents(df_filtered) + # Ensure the validated configuration is available for future use + _ = trace_info - # Generate captions - captions = _generate_captions(neighbor_content, validated_config) + api_key = task_props.get("api_key", validated_config.api_key) + prompt = task_props.get("prompt", validated_config.prompt) + endpoint_url = task_props.get("endpoint_url", validated_config.endpoint_url) - # Update metadata with captions - image_docs = _update_metadata_with_captions(metadata_list, captions, df_filtered) + # Create a mask for rows where the document type is IMAGE + df_mask = df['metadata'].apply(lambda meta: meta.get('content_metadata', {}).get('type') == "image") - logger.debug(f"Extracted captions from {len(image_docs)} images") + if not df_mask.any(): + return df - # Final dataframe merge - df_final = _prepare_final_dataframe_mod(df, image_docs, filter_index) + df.loc[df_mask, 'metadata'] = df.loc[df_mask, 'metadata'].apply( + lambda meta: { + **meta, + 'image_metadata': { + **meta.get('image_metadata', {}), + 'caption': _generate_captions(meta['content'], prompt, api_key, endpoint_url) + } + } + ) - if (df_final is None) or df_final.empty: - logger.warning("NO IMAGE DOCUMENTS FOUND IN THE DATAFRAME") - return df - return df_final - except Exception as e: - traceback.print_exc() - raise ValueError(f"Failed to do caption extraction: {e}") + logger.debug("Image content captioning complete") + return df def generate_caption_extraction_stage( - c: Config, - caption_config: Dict[str, Any], - task: str = "caption", - task_desc: str = "caption_extraction", - pe_count: int = 8, + c: Config, + caption_config: Dict[str, Any], + task: str = "caption", + task_desc: str = "caption_extraction", + pe_count: int = 8, ): """ Generates a caption extraction stage with the specified configuration. @@ -483,10 +192,5 @@ def generate_caption_extraction_stage( f"Generating caption extraction stage with {pe_count} processing elements. 
task: {task}, document_type: *" ) return MultiProcessingBaseStage( - c=c, - pe_count=pe_count, - task=task, - task_desc=task_desc, - process_fn=_wrapped_caption_extract, - filter_properties={"content_type": ContentTypeEnum.IMAGE.value}, + c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_caption_extract ) diff --git a/src/nv_ingest/util/converters/type_mappings.py b/src/nv_ingest/util/converters/type_mappings.py index b3834d50..4fbfb0a9 100644 --- a/src/nv_ingest/util/converters/type_mappings.py +++ b/src/nv_ingest/util/converters/type_mappings.py @@ -15,6 +15,7 @@ DocumentTypeEnum.png: ContentTypeEnum.IMAGE, DocumentTypeEnum.pptx: ContentTypeEnum.STRUCTURED, DocumentTypeEnum.svg: ContentTypeEnum.IMAGE, + DocumentTypeEnum.tiff: ContentTypeEnum.IMAGE, DocumentTypeEnum.txt: ContentTypeEnum.TEXT, } diff --git a/src/nv_ingest/util/flow_control/filter_by_task.py b/src/nv_ingest/util/flow_control/filter_by_task.py index 2e973372..11a73a43 100644 --- a/src/nv_ingest/util/flow_control/filter_by_task.py +++ b/src/nv_ingest/util/flow_control/filter_by_task.py @@ -5,6 +5,7 @@ import logging import typing +import re from functools import wraps from morpheus.messages import ControlMessage @@ -50,10 +51,12 @@ def wrapper(*args, **kwargs): continue task_props_list = tasks.get(required_task_name, []) + logger.debug(f"Checking task properties for: {required_task_name}") + logger.debug(f"Required task properties: {required_task_props_list}") for task_props in task_props_list: if all( - _is_subset(task_props, required_task_props) - for required_task_props in required_task_props_list + _is_subset(task_props, required_task_props) + for required_task_props in required_task_props_list ): return func(*args, **kwargs) @@ -75,22 +78,38 @@ def _is_subset(superset, subset): if subset == "*": return True if isinstance(superset, dict) and isinstance(subset, dict): - return all(key in superset and _is_subset(superset[key], val) for key, val in subset.items()) + return all( + key in superset and _is_subset(superset[key], val) + for key, val in subset.items() + ) + if isinstance(subset, str) and subset.startswith('regex:'): + # The subset is a regex pattern + pattern = subset[len('regex:'):] + if isinstance(superset, list): + return any(re.match(pattern, str(sup_item)) for sup_item in superset) + else: + return re.match(pattern, str(superset)) is not None + if isinstance(superset, list) and not isinstance(subset, list): + # Check if the subset value matches any item in the superset + return any(_is_subset(sup_item, subset) for sup_item in superset) if isinstance(superset, list) or isinstance(superset, set): - return all(any(_is_subset(sup_item, sub_item) for sup_item in superset) for sub_item in subset) + return all( + any(_is_subset(sup_item, sub_item) for sup_item in superset) + for sub_item in subset + ) return superset == subset def remove_task_subset(ctrl_msg: ControlMessage, task_type: typing.List, subset: typing.Dict): """ - A helper function to extract a task based on subset matching when the task might be out of order wrt the + A helper function to extract a task based on subset matching when the task might be out of order with respect to the Morpheus pipeline. For example, if a deduplication filter occurs before scale filtering in the pipeline, but the task list includes scale filtering before deduplication. Parameters ---------- ctrl_msg : ControlMessage - A list of task keys to check for in the ControlMessage. + The ControlMessage object containing tasks. 
task_type : list The name of the ControlMessage task to operate on. subset : dict diff --git a/src/nv_ingest/util/image_processing/transforms.py b/src/nv_ingest/util/image_processing/transforms.py index d441db72..15f25757 100644 --- a/src/nv_ingest/util/image_processing/transforms.py +++ b/src/nv_ingest/util/image_processing/transforms.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 import base64 +import io +import logging from io import BytesIO from math import ceil from math import floor @@ -18,13 +20,112 @@ DEFAULT_MAX_WIDTH = 1024 DEFAULT_MAX_HEIGHT = 1280 +logger = logging.getLogger(__name__) + + +def scale_image_to_encoding_size(base64_image: str, max_base64_size: int = 180_000, + initial_reduction: float = 0.9) -> str: + """ + Decodes a base64-encoded image, resizes it if needed, and re-encodes it as base64. + Ensures the final image size is within the specified limit. + + Parameters + ---------- + base64_image : str + Base64-encoded image string. + max_base64_size : int, optional + Maximum allowable size for the base64-encoded image, by default 180,000 characters. + initial_reduction : float, optional + Initial reduction step for resizing, by default 0.9. + + Returns + ------- + str + Base64-encoded PNG image string, resized if necessary. + + Raises + ------ + Exception + If the image cannot be resized below the specified max_base64_size. + """ + try: + # Decode the base64 image and open it as a PIL image + image_data = base64.b64decode(base64_image) + img = Image.open(io.BytesIO(image_data)).convert("RGB") + + # Check initial size + if len(base64_image) <= max_base64_size: + logger.debug("Initial image is within the size limit.") + return base64_image + + # Initial reduction step + reduction_step = initial_reduction + while len(base64_image) > max_base64_size: + width, height = img.size + new_size = (int(width * reduction_step), int(height * reduction_step)) + logger.debug(f"Resizing image to {new_size}") + + img_resized = img.resize(new_size, Image.LANCZOS) + buffered = io.BytesIO() + img_resized.save(buffered, format="PNG") + base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8") + + logger.debug(f"Resized base64 image size: {len(base64_image)} characters.") + + # Adjust the reduction step if necessary + if len(base64_image) > max_base64_size: + reduction_step *= 0.95 # Reduce size further if needed + logger.debug(f"Reducing step size for further resizing: {reduction_step:.3f}") + + # Safety check + if new_size[0] < 1 or new_size[1] < 1: + raise Exception("Image cannot be resized further without becoming too small.") + + return base64_image + + except Exception as e: + logger.error(f"Error resizing the image: {e}") + raise + + +def ensure_base64_is_png(base64_image: str) -> str: + """ + Ensures the given base64-encoded image is in PNG format. Converts to PNG if necessary. + + Parameters + ---------- + base64_image : str + Base64-encoded image string. + + Returns + ------- + str + Base64-encoded PNG image string. 
+ """ + try: + # Decode the base64 string and load the image + image_data = base64.b64decode(base64_image) + image = Image.open(io.BytesIO(image_data)) + + # Check if the image is already in PNG format + if image.format != 'PNG': + # Convert the image to PNG + buffered = io.BytesIO() + image.convert("RGB").save(buffered, format="PNG") + base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8") + + return base64_image + except Exception as e: + logger.error(f"Error ensuring PNG format: {e}") + return None + def pad_image( - array: np.ndarray, - target_width: int = DEFAULT_MAX_WIDTH, - target_height: int = DEFAULT_MAX_HEIGHT, - background_color: int = 255, - dtype=np.uint8, + array: np.ndarray, + target_width: int = DEFAULT_MAX_WIDTH, + target_height: int = DEFAULT_MAX_HEIGHT, + background_color: int = 255, + dtype=np.uint8, ) -> Tuple[np.ndarray, Tuple[int, int]]: """ Pads a NumPy array representing an image to the specified target dimensions. @@ -75,10 +176,11 @@ def pad_image( # Create the canvas and place the original image on it canvas = background_color * np.ones((final_height, final_width, array.shape[2]), dtype=dtype) - canvas[pad_height : pad_height + height, pad_width : pad_width + width] = array # noqa: E203 + canvas[pad_height: pad_height + height, pad_width: pad_width + width] = array # noqa: E203 return canvas, (pad_width, pad_height) + def check_numpy_image_size(image: np.ndarray, min_height: int, min_width: int) -> bool: """ Checks if the height and width of the image are larger than the specified minimum values. @@ -98,8 +200,9 @@ def check_numpy_image_size(image: np.ndarray, min_height: int, min_width: int) - height, width = image.shape[:2] return height >= min_height and width >= min_width + def crop_image( - array: np.array, bbox: Tuple[int, int, int, int], min_width: int = 1, min_height: int = 1 + array: np.array, bbox: Tuple[int, int, int, int], min_width: int = 1, min_height: int = 1 ) -> Optional[np.ndarray]: """ Crops a NumPy array representing an image according to the specified bounding box. @@ -138,13 +241,13 @@ def crop_image( def normalize_image( - array: np.ndarray, - r_mean: float = 0.485, - g_mean: float = 0.456, - b_mean: float = 0.406, - r_std: float = 0.229, - g_std: float = 0.224, - b_std: float = 0.225, + array: np.ndarray, + r_mean: float = 0.485, + g_mean: float = 0.456, + b_mean: float = 0.406, + r_std: float = 0.229, + g_std: float = 0.224, + b_std: float = 0.225, ) -> np.ndarray: """ Normalizes an RGB image by applying a mean and standard deviation to each channel. diff --git a/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py b/src/nv_ingest/util/nim/yolox.py similarity index 93% rename from src/nv_ingest/extraction_workflows/pdf/yolox_utils.py rename to src/nv_ingest/util/nim/yolox.py index e8458ea0..f77460e8 100644 --- a/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -2,6 +2,8 @@ # All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 +from typing import List +import numpy as np import warnings import cv2 @@ -26,7 +28,7 @@ def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thr continue # Get score and class with highest confidence - class_conf, class_pred = torch.max(image_pred[:, 5 : 5 + num_classes], 1, keepdim=True) # noqa: E203 + class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True) # noqa: E203 conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze() # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred) @@ -201,14 +203,14 @@ def expand_chart_bboxes(annotation_dict, labels=None): def weighted_boxes_fusion( - boxes_list, - scores_list, - labels_list, - iou_thr=0.5, - skip_box_thr=0.0, - conf_type="avg", - merge_type="weighted", - class_agnostic=False, + boxes_list, + scores_list, + labels_list, + iou_thr=0.5, + skip_box_thr=0.0, + conf_type="avg", + merge_type="weighted", + class_agnostic=False, ): """ Custom wbf implementation that supports a class_agnostic mode and a biggest box fusion. @@ -581,3 +583,35 @@ def get_weighted_box(boxes, conf_type="avg"): box[3] = -1 # model index field is retained for consistency but is not used. box[4:] /= conf return box + + +def prepare_images_for_inference(images: List[np.ndarray]) -> np.ndarray: + """ + Prepare a list of images for model inference by resizing and reordering axes. + + Parameters + ---------- + images : List[np.ndarray] + A list of image arrays to be prepared for inference. + + Returns + ------- + np.ndarray + A numpy array suitable for model input, with the shape reordered to match the expected input format. + + Notes + ----- + The images are resized to 1024x1024 pixels and the axes are reordered to match the expected input shape for + the model (batch, channels, height, width). + + Examples + -------- + >>> images = [np.random.rand(1536, 1536, 3) for _ in range(2)] + >>> input_array = prepare_images_for_inference(images) + >>> input_array.shape + (2, 3, 1024, 1024) + """ + + resized_images = [resize_image(image, (1024, 1024)) for image in images] + + return np.einsum("bijk->bkij", resized_images).astype(np.float32) diff --git a/src/nv_ingest/util/pdf/metadata_aggregators.py b/src/nv_ingest/util/pdf/metadata_aggregators.py index 851a931b..64b040af 100644 --- a/src/nv_ingest/util/pdf/metadata_aggregators.py +++ b/src/nv_ingest/util/pdf/metadata_aggregators.py @@ -3,16 +3,20 @@ # SPDX-License-Identifier: Apache-2.0 -import uuid from dataclasses import dataclass from datetime import datetime +from PIL import Image from typing import Any from typing import Dict from typing import List from typing import Tuple +import base64 +import io +import uuid import pandas as pd import pypdfium2 as pdfium +from pypdfium2 import PdfImage from nv_ingest.schemas.metadata_schema import ContentSubtypeEnum from nv_ingest.schemas.metadata_schema import ContentTypeEnum @@ -186,8 +190,95 @@ def construct_text_metadata( return [ContentTypeEnum.TEXT, validated_unified_metadata.dict(), str(uuid.uuid4())] -def construct_image_metadata( - image_base64: Base64Image, +def construct_image_metadata_from_base64( + base64_image: str, + page_idx: int, + page_count: int, + source_metadata: Dict[str, Any], + base_unified_metadata: Dict[str, Any], +) -> List[Any]: + """ + Extracts image data from a base64-encoded image string, decodes the image to get + its dimensions and bounding box, and constructs metadata for the image. 
+ + Parameters + ---------- + base64_image : str + A base64-encoded string representing the image. + page_idx : int + The index of the current page being processed. + page_count : int + The total number of pages in the PDF document. + source_metadata : Dict[str, Any] + Metadata related to the source of the PDF document. + base_unified_metadata : Dict[str, Any] + The base unified metadata structure to be updated with the extracted image information. + + Returns + ------- + List[Any] + A list containing the content type, validated metadata dictionary, and a UUID string. + + Raises + ------ + ValueError + If the image cannot be decoded from the base64 string. + """ + # Decode the base64 image + try: + image_data = base64.b64decode(base64_image) + image = Image.open(io.BytesIO(image_data)) + except Exception as e: + raise ValueError(f"Failed to decode image from base64: {e}") + + # Extract image dimensions and bounding box + width, height = image.size + bbox = (0, 0, width, height) # Assuming the full image as the bounding box + + # Construct content metadata + content_metadata: Dict[str, Any] = { + "type": ContentTypeEnum.IMAGE, + "description": StdContentDescEnum.PDF_IMAGE, + "page_number": page_idx, + "hierarchy": { + "page_count": page_count, + "page": page_idx, + "block": -1, + "line": -1, + "span": -1, + "nearby_objects": [], + }, + } + + # Construct image metadata + image_metadata: Dict[str, Any] = { + "image_type": "PNG", # This can be dynamic if needed + "structured_image_type": ImageTypeEnum.image_type_1, + "caption": "", + "text": "", + "image_location": bbox, + "image_location_max_dimensions": (width, height), + "height": height, + } + + # Update the unified metadata with the extracted image information + unified_metadata: Dict[str, Any] = base_unified_metadata.copy() + unified_metadata.update( + { + "content": base64_image, + "source_metadata": source_metadata, + "content_metadata": content_metadata, + "image_metadata": image_metadata, + } + ) + + # Validate and return the unified metadata + validated_unified_metadata = validate_metadata(unified_metadata) + return [ContentTypeEnum.IMAGE, validated_unified_metadata.dict(), str(uuid.uuid4())] + + +def construct_image_metadata_from_pdf_image( + pdf_image: PdfImage, page_idx: int, page_count: int, source_metadata: Dict[str, Any], @@ -219,7 +310,7 @@ def construct_image_metadata( ------ PdfiumError If the image cannot be extracted due to an issue with the PdfImage object. 
-    :param image_base64:
+    :param pdf_image:
     """
     # Define the assumed image type (e.g., PNG)
     image_type: str = "PNG"
@@ -245,16 +336,16 @@ def construct_image_metadata(
         "structured_image_type": ImageTypeEnum.image_type_1,
         "caption": "",
         "text": "",
-        "image_location": image_base64.bbox,
-        "image_location_max_dimensions": (max(image_base64.max_width,0), max(image_base64.max_height,0)),
-        "height": image_base64.height,
+        "image_location": pdf_image.bbox,
+        "image_location_max_dimensions": (max(pdf_image.max_width, 0), max(pdf_image.max_height, 0)),
+        "height": pdf_image.height,
     }

     # Update the unified metadata with the extracted image information
     unified_metadata: Dict[str, Any] = base_unified_metadata.copy()
     unified_metadata.update(
         {
-            "content": image_base64.image,
+            "content": pdf_image.image,
             "source_metadata": source_metadata,
             "content_metadata": content_metadata,
             "image_metadata": image_metadata,
diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py
index ee8ae3fe..a55f9e95 100644
--- a/src/nv_ingest/util/pipeline/stage_builders.py
+++ b/src/nv_ingest/util/pipeline/stage_builders.py
@@ -23,6 +23,7 @@
 from nv_ingest.modules.transforms.embed_extractions import EmbedExtractionsLoaderFactory
 from nv_ingest.modules.transforms.nemo_doc_splitter import NemoDocSplitterLoaderFactory
 from nv_ingest.stages.docx_extractor_stage import generate_docx_extractor_stage
+from nv_ingest.stages.extractors.image_extractor_stage import generate_image_extractor_stage
 from nv_ingest.stages.filters import generate_dedup_stage
 from nv_ingest.stages.filters import generate_image_filter_stage
 from nv_ingest.stages.nim.chart_extraction import generate_chart_extractor_stage
@@ -247,6 +248,29 @@ def add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, def
     return table_extractor_stage


+def add_image_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
+    yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox")
+    image_extractor_config = ingest_config.get("image_extraction_module",
+                                               {
+                                                   "image_extraction_config": {
+                                                       "yolox_endpoints": (yolox_grpc, yolox_http),
+                                                       "yolox_infer_protocol": yolox_protocol,
+                                                       "auth_token": yolox_auth,
+                                                       # All auth tokens are the same for the moment
+                                                   }
+                                               })
+    image_extractor_stage = pipe.add_stage(
+        generate_image_extractor_stage(
+            morpheus_pipeline_config,
+            extractor_config=image_extractor_config,
+            pe_count=8,
+            task="extract",
+            task_desc="image_content_extractor",
+        )
+    )
+    return image_extractor_stage
+
+
 def add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count):
     docx_extractor_stage = pipe.add_stage(
         generate_docx_extractor_stage(
@@ -319,14 +343,25 @@ def add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config):

 def add_image_caption_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
-    endpoint_url, model_name = get_caption_classifier_service()
+    auth_token = os.environ.get(
+        "NVIDIA_BUILD_API_KEY",
+        "",
+    ) or os.environ.get(
+        "NGC_API_KEY",
+        "",
+    )
+
+    endpoint_url = os.environ.get("VLM_CAPTION_ENDPOINT")
+
     image_caption_config = ingest_config.get(
         "image_caption_extraction_module",
         {
-            "caption_classifier_model_name": model_name,
+            "api_key": auth_token,
             "endpoint_url": endpoint_url,
+            "prompt": "Caption the content of this image:",
         },
     )
+
     image_caption_stage = pipe.add_stage(
         generate_caption_extraction_stage(
             morpheus_pipeline_config,
diff --git a/src/pipeline.py b/src/pipeline.py
index 4d8289d7..6397765b 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -49,6 +49,7 @@ def setup_ingestion_pipeline( ## Primitive extraction ######################################################################################################## pdf_extractor_stage = add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + image_extractor_stage = add_image_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) docx_extractor_stage = add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count) pptx_extractor_stage = add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count) ######################################################################################################## @@ -60,6 +61,7 @@ def setup_ingestion_pipeline( image_filter_stage = add_image_filter_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) table_extraction_stage = add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) chart_extraction_stage = add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + image_caption_stage = add_image_caption_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) ######################################################################################################## ######################################################################################################## @@ -73,10 +75,10 @@ def setup_ingestion_pipeline( ## Storage and output ######################################################################################################## image_storage_stage = add_image_storage_stage(pipe, morpheus_pipeline_config) + vdb_task_sink_stage = add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config) sink_stage = add_sink_stage( pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port ) - vdb_task_sink_stage = add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config) ######################################################################################################## ####################################################################################################### @@ -93,14 +95,16 @@ def setup_ingestion_pipeline( pipe.add_edge(source_stage, submitted_job_counter_stage) pipe.add_edge(submitted_job_counter_stage, metadata_injector_stage) pipe.add_edge(metadata_injector_stage, pdf_extractor_stage) - pipe.add_edge(pdf_extractor_stage, docx_extractor_stage) + pipe.add_edge(pdf_extractor_stage, image_extractor_stage) + pipe.add_edge(image_extractor_stage, docx_extractor_stage) pipe.add_edge(docx_extractor_stage, pptx_extractor_stage) pipe.add_edge(pptx_extractor_stage, image_dedup_stage) pipe.add_edge(image_dedup_stage, image_filter_stage) pipe.add_edge(image_filter_stage, table_extraction_stage) pipe.add_edge(table_extraction_stage, chart_extraction_stage) pipe.add_edge(chart_extraction_stage, nemo_splitter_stage) - pipe.add_edge(nemo_splitter_stage, embed_extractions_stage) + pipe.add_edge(nemo_splitter_stage, image_caption_stage) + pipe.add_edge(image_caption_stage, embed_extractions_stage) pipe.add_edge(embed_extractions_stage, image_storage_stage) pipe.add_edge(image_storage_stage, vdb_task_sink_stage) pipe.add_edge(vdb_task_sink_stage, sink_stage) diff --git a/tests/stages/nims/__init__.py b/tests/nv_ingest/extraction_workflows/image/__init__.py similarity index 100% rename from tests/stages/nims/__init__.py rename 
to tests/nv_ingest/extraction_workflows/image/__init__.py
diff --git a/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py b/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py
new file mode 100644
index 00000000..b58c235f
--- /dev/null
+++ b/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py
@@ -0,0 +1,354 @@
+import urllib
+from pyexpat import ExpatError
+from xml.etree.ElementTree import ParseError
+
+from wand.exceptions import WandException
+
+from nv_ingest.extraction_workflows.image.image_handlers import load_and_preprocess_image, convert_svg_to_bitmap, \
+    extract_table_and_chart_images
+from PIL import Image
+import io
+import numpy as np
+from typing import List, Tuple
+
+from nv_ingest.extraction_workflows.image.image_handlers import process_inference_results
+from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
+
+
+def test_load_and_preprocess_image_jpeg():
+    """Test loading and preprocessing a JPEG image."""
+    # Create a small sample image and save it to a BytesIO stream as JPEG
+    image = Image.new("RGB", (10, 10), color="red")
+    image_stream = io.BytesIO()
+    image.save(image_stream, format="JPEG")
+    image_stream.seek(0)
+
+    # Load and preprocess the image
+    result = load_and_preprocess_image(image_stream)
+
+    # Check the output type and shape
+    assert isinstance(result, np.ndarray)
+    assert result.shape == (10, 10, 3)
+    assert result.dtype == np.float32
+    assert np.all(result[:, :, 0] == 254)  # All red pixels (JPEG compression shifts 255 slightly)
+    assert np.all(result[:, :, 1] == 0)  # No green
+    assert np.all(result[:, :, 2] == 0)  # No blue
+
+
+def test_load_and_preprocess_image_png():
+    """Test loading and preprocessing a PNG image."""
+    # Create a small sample image and save it to a BytesIO stream as PNG
+    image = Image.new("RGB", (5, 5), color="blue")
+    image_stream = io.BytesIO()
+    image.save(image_stream, format="PNG")
+    image_stream.seek(0)
+
+    # Load and preprocess the image
+    result = load_and_preprocess_image(image_stream)
+
+    # Check the output type and shape
+    assert isinstance(result, np.ndarray)
+    assert result.shape == (5, 5, 3)
+    assert result.dtype == np.float32
+    assert np.all(result[:, :, 0] == 0)  # No red
+    assert np.all(result[:, :, 1] == 0)  # No green
+    assert np.all(result[:, :, 2] == 255)  # All blue pixels
+
+
+def test_load_and_preprocess_image_invalid_format():
+    """Test that an invalid image format raises an error."""
+    # Create a BytesIO stream with non-image content
+    invalid_stream = io.BytesIO(b"This is not an image file")
+
+    # Expect an OSError when trying to open a non-image stream
+    try:
+        load_and_preprocess_image(invalid_stream)
+    except OSError as e:
+        assert "cannot identify image file" in str(e)
+
+
+def test_load_and_preprocess_image_corrupt_image():
+    """Test that a corrupt image raises an error."""
+    # Create a valid JPEG header but corrupt the rest
+    corrupt_stream = io.BytesIO(b"\xFF\xD8\xFF\xE0" + b"\x00" * 10)
+
+    # Expect an OSError when trying to open a corrupt image stream
+    try:
+        load_and_preprocess_image(corrupt_stream)
+    except OSError as e:
+        assert "cannot identify image file" in str(e)
+
+
+def test_convert_svg_to_bitmap_basic_svg():
+    """Test converting a simple SVG to a bitmap image."""
+    # Sample SVG image data (a small red square)
+    svg_data = b"""<svg xmlns="http://www.w3.org/2000/svg" width="10" height="10">
+        <rect width="10" height="10" fill="red" />
+    </svg>"""
+    image_stream = io.BytesIO(svg_data)
+
+    # Convert SVG to bitmap
+    result = convert_svg_to_bitmap(image_stream)
+
+    # Check the output type, shape, and color values
+    assert isinstance(result, np.ndarray)
+    assert result.shape == (10, 10, 3)
+    assert result.dtype == np.float32
+    assert np.all(result[:, :, 0] == 255)  # Red channel fully on
+    assert np.all(result[:, :, 1] == 0)  # Green channel off
+    assert np.all(result[:, :, 2] == 0)  # Blue channel off
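+
+# Note: these SVG fixtures assume convert_svg_to_bitmap rasterizes via
+# Wand/ImageMagick (hence the WandException import above). A rough sketch of
+# the conversion these tests exercise:
+#
+#     from wand.image import Image as WandImage
+#     with WandImage(blob=svg_data, format="svg") as img:
+#         png_bytes = img.make_blob("png")  # rasterized bitmap bytes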
+
+
+def test_convert_svg_to_bitmap_large_svg():
+    """Test converting a larger SVG to ensure scalability."""
+    # Large SVG image data (blue rectangle 100x100)
+    svg_data = b"""<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100">
+        <rect width="100" height="100" fill="blue" />
+    </svg>"""
+    image_stream = io.BytesIO(svg_data)
+
+    # Convert SVG to bitmap
+    result = convert_svg_to_bitmap(image_stream)
+
+    # Check the output type and shape
+    assert isinstance(result, np.ndarray)
+    assert result.shape == (100, 100, 3)
+    assert result.dtype == np.float32
+    assert np.all(result[:, :, 0] == 0)  # Red channel off
+    assert np.all(result[:, :, 1] == 0)  # Green channel off
+    assert np.all(result[:, :, 2] == 255)  # Blue channel fully on
+
+
+def test_process_inference_results_basic_case():
+    """Test process_inference_results with a typical valid input."""
+
+    # Simulated model output array for a single image with several detections.
+    # Array format is (batch_size, num_detections, 85) - 80 classes + 5 box coordinates
+    # For simplicity, use random values for the boxes and class predictions.
+    output_array = np.zeros((1, 3, 85), dtype=np.float32)
+
+    # Mock bounding box coordinates
+    output_array[0, 0, :4] = [0.5, 0.5, 0.2, 0.2]  # x_center, y_center, width, height
+    output_array[0, 1, :4] = [0.6, 0.6, 0.2, 0.2]
+    output_array[0, 2, :4] = [0.7, 0.7, 0.2, 0.2]
+
+    # Mock object confidence scores
+    output_array[0, :, 4] = [0.8, 0.9, 0.85]
+
+    # Mock class scores (set class 1 with highest confidence for simplicity)
+    output_array[0, 0, 5 + 1] = 0.7
+    output_array[0, 1, 5 + 1] = 0.75
+    output_array[0, 2, 5 + 1] = 0.72
+
+    original_image_shapes = [(640, 640)]  # Original shape of the image before resizing
+
+    # Process inference results with thresholds that should retain all mock detections
+    results = process_inference_results(
+        output_array,
+        original_image_shapes,
+        num_classes=80,
+        conf_thresh=0.5,
+        iou_thresh=0.5,
+        min_score=0.1,
+        final_thresh=0.3,
+    )
+
+    # Check output structure
+    assert isinstance(results, list)
+    assert len(results) == 1
+    assert isinstance(results[0], dict)
+
+    # Validate bounding box scaling and structure
+    assert "chart" in results[0] or "table" in results[0]
+    if "chart" in results[0]:
+        assert isinstance(results[0]["chart"], list)
+        assert len(results[0]["chart"]) > 0
+        # Check bounding box format for each detected "chart" item (5 values per box)
+        for bbox in results[0]["chart"]:
+            assert len(bbox) == 5  # [x1, y1, x2, y2, score]
+            assert bbox[4] >= 0.3  # score meets final threshold
+
+    print("Processed inference results:", results)
+
+
+def test_process_inference_results_multiple_images():
+    """Test with multiple images to verify batch processing."""
+    # Simulate model output with 2 images and 3 detections each
+    output_array = np.zeros((2, 3, 85), dtype=np.float32)
+    # Set bounding boxes and confidence for the mock detections
+    output_array[0, 0, :5] = [0.5, 0.5, 0.2, 0.2, 0.8]
+    output_array[0, 1, :5] = [0.6, 0.6, 0.2, 0.2, 0.7]
+    output_array[1, 0, :5] = [0.4, 0.4, 0.1, 0.1, 0.9]
+    # Assign class confidences for classes 0 and 1
+    output_array[0, 0, 5 + 1] = 0.75
+    output_array[0, 1, 5 + 1] = 0.65
+    output_array[1, 0, 5 + 0] = 0.8
+
+    original_image_shapes = [(640, 640), (800, 800)]
+
+    results = process_inference_results(
+        output_array,
+        original_image_shapes,
+        num_classes=80,
+        conf_thresh=0.5,
+        iou_thresh=0.5,
+        min_score=0.1,
final_thresh=0.3, + ) + + assert isinstance(results, list) + assert len(results) == 2 + for result in results: + assert isinstance(result, dict) + if "chart" in result: + assert all(len(bbox) == 5 and bbox[4] >= 0.3 for bbox in result["chart"]) + + +def test_process_inference_results_high_confidence_threshold(): + """Test with a high confidence threshold to verify filtering.""" + output_array = np.zeros((1, 5, 85), dtype=np.float32) + # Set low confidence scores below the threshold + output_array[0, :, 4] = [0.2, 0.3, 0.4, 0.4, 0.2] + output_array[0, :, 5] = [0.5] * 5 # Class confidence + + original_image_shapes = [(640, 640)] + + results = process_inference_results( + output_array, + original_image_shapes, + num_classes=80, + conf_thresh=0.9, # High confidence threshold + iou_thresh=0.5, + min_score=0.1, + final_thresh=0.3, + ) + + assert isinstance(results, list) + assert len(results) == 1 + assert results[0] == {} # No detections should pass the high confidence threshold + + +def test_process_inference_results_varied_num_classes(): + """Test compatibility with different model class counts.""" + output_array = np.zeros((1, 3, 25), dtype=np.float32) # 20 classes + 5 box coords + # Assign box, object confidence, and class scores + output_array[0, 0, :5] = [0.5, 0.5, 0.2, 0.2, 0.8] + output_array[0, 1, :5] = [0.6, 0.6, 0.3, 0.3, 0.7] + output_array[0, 0, 5 + 1] = 0.9 # Assign highest confidence to class 1 + + original_image_shapes = [(640, 640)] + + results = process_inference_results( + output_array, + original_image_shapes, + num_classes=20, # Different class count + conf_thresh=0.5, + iou_thresh=0.5, + min_score=0.1, + final_thresh=0.3, + ) + + assert isinstance(results, list) + assert len(results) == 1 + assert isinstance(results[0], dict) + assert "chart" in results[0] + assert len(results[0]["chart"]) > 0 # Verify detections processed correctly with 20 classes + + +def crop_image(image: np.ndarray, bbox: Tuple[int, int, int, int]) -> np.ndarray: + """Mock function to simulate cropping an image.""" + h1, w1, h2, w2 = bbox + return image[int(h1):int(h2), int(w1):int(w2)] + + +def test_extract_table_and_chart_images_empty_annotations(): + """Test when annotation_dict has no objects to extract.""" + annotation_dict = {"table": [], "chart": []} + original_image = np.random.rand(640, 640, 3) + tables_and_charts = [] + + extract_table_and_chart_images(annotation_dict, original_image, 0, tables_and_charts) + + # Expect no entries added to tables_and_charts since there are no objects + assert tables_and_charts == [] + + +def test_extract_table_and_chart_images_single_table(): + """Test extraction with a single table bounding box.""" + annotation_dict = {"table": [[0.1, 0.1, 0.3, 0.3, 0.8]], "chart": []} + original_image = np.random.rand(640, 640, 3) + tables_and_charts = [] + + extract_table_and_chart_images(annotation_dict, original_image, 0, tables_and_charts) + + # Expect one entry in tables_and_charts for the table + assert len(tables_and_charts) == 1 + page_idx, cropped_image_data = tables_and_charts[0] + assert page_idx == 0 + assert isinstance(cropped_image_data, CroppedImageWithContent) + + # Verify attribute values + assert cropped_image_data.content == "" + assert cropped_image_data.type_string == "table" + assert cropped_image_data.bbox == (64, 64, 192, 192) # Scaled bounding box from (0.1, 0.1, 0.3, 0.3) + assert cropped_image_data.max_width == 640 + assert cropped_image_data.max_height == 640 + assert isinstance(cropped_image_data.image, str) # Assuming the image is base64-encoded 
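+
+# The expected pixel boxes in the test above assume normalized (x1, y1, x2, y2)
+# coordinates scaled by the 640x640 test image. A minimal sketch of that
+# arithmetic for the table fixture:
+#
+#     x1, y1, x2, y2 = 0.1, 0.1, 0.3, 0.3
+#     scaled = (int(x1 * 640), int(y1 * 640), int(x2 * 640), int(y2 * 640))
+#     assert scaled == (64, 64, 192, 192)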
+ + +def test_extract_table_and_chart_images_single_chart(): + """Test extraction with a single chart bounding box.""" + annotation_dict = {"table": [], "chart": [[0.4, 0.4, 0.6, 0.6, 0.9]]} + original_image = np.random.rand(640, 640, 3) + tables_and_charts = [] + + extract_table_and_chart_images(annotation_dict, original_image, 1, tables_and_charts) + + # Expect one entry in tables_and_charts for the chart + assert len(tables_and_charts) == 1 + page_idx, cropped_image_data = tables_and_charts[0] + assert page_idx == 1 + assert isinstance(cropped_image_data, CroppedImageWithContent) + assert cropped_image_data.type_string == "chart" + assert cropped_image_data.bbox == (256, 256, 384, 384) # Scaled bounding box + + +def test_extract_table_and_chart_images_multiple_objects(): + """Test extraction with multiple table and chart objects.""" + annotation_dict = { + "table": [[0.1, 0.1, 0.3, 0.3, 0.8], [0.5, 0.5, 0.7, 0.7, 0.85]], + "chart": [[0.2, 0.2, 0.4, 0.4, 0.9]] + } + original_image = np.random.rand(640, 640, 3) + tables_and_charts = [] + + extract_table_and_chart_images(annotation_dict, original_image, 2, tables_and_charts) + + # Expect three entries in tables_and_charts: two tables and one chart + assert len(tables_and_charts) == 3 + for page_idx, cropped_image_data in tables_and_charts: + assert page_idx == 2 + assert isinstance(cropped_image_data, CroppedImageWithContent) + assert cropped_image_data.type_string in ["table", "chart"] + assert cropped_image_data.bbox is not None # Bounding box should be defined + + +def test_extract_table_and_chart_images_invalid_bounding_box(): + """Test with an invalid bounding box to check handling of incorrect coordinates.""" + annotation_dict = {"table": [[1.1, 1.1, 1.5, 1.5, 0.9]], "chart": []} # Out of bounds + original_image = np.random.rand(640, 640, 3) + tables_and_charts = [] + + extract_table_and_chart_images(annotation_dict, original_image, 3, tables_and_charts) + + # Verify that the function processes the bounding box as is + assert len(tables_and_charts) == 1 + page_idx, cropped_image_data = tables_and_charts[0] + assert page_idx == 3 + assert isinstance(cropped_image_data, CroppedImageWithContent) + assert cropped_image_data.type_string == "table" + assert cropped_image_data.bbox == (704, 704, 960, 960) # Scaled bounding box with out-of-bounds values diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_pdfium_helper.py b/tests/nv_ingest/extraction_workflows/pdf/test_pdfium_helper.py index 3455271d..b9999603 100644 --- a/tests/nv_ingest/extraction_workflows/pdf/test_pdfium_helper.py +++ b/tests/nv_ingest/extraction_workflows/pdf/test_pdfium_helper.py @@ -9,7 +9,7 @@ import pandas as pd import pytest -from nv_ingest.extraction_workflows.pdf.pdfium_helper import pdfium +from nv_ingest.extraction_workflows.pdf.pdfium_helper import pdfium_extractor from nv_ingest.schemas.metadata_schema import TextTypeEnum @@ -38,8 +38,8 @@ def pdf_stream_embedded_tables_pdf(): @pytest.mark.xfail(reason="PDFium conversion required") -def test_pdfium_basic(pdf_stream_test_pdf, document_df): - extracted_data = pdfium( +def test_pdfium_extractor_basic(pdf_stream_test_pdf, document_df): + extracted_data = pdfium_extractor( pdf_stream_test_pdf, extract_text=True, extract_images=False, @@ -63,8 +63,8 @@ def test_pdfium_basic(pdf_stream_test_pdf, document_df): "text_depth", ["span", TextTypeEnum.SPAN, "line", TextTypeEnum.LINE, "block", TextTypeEnum.BLOCK], ) -def test_pdfium_text_depth_line(pdf_stream_test_pdf, document_df, text_depth): - extracted_data = 
pdfium( +def test_pdfium_extractor_text_depth_line(pdf_stream_test_pdf, document_df, text_depth): + extracted_data = pdfium_extractor( pdf_stream_test_pdf, extract_text=True, extract_images=False, @@ -89,8 +89,8 @@ def test_pdfium_text_depth_line(pdf_stream_test_pdf, document_df, text_depth): "text_depth", ["page", TextTypeEnum.PAGE, "document", TextTypeEnum.DOCUMENT], ) -def test_pdfium_text_depth_page(pdf_stream_test_pdf, document_df, text_depth): - extracted_data = pdfium( +def test_pdfium_extractor_text_depth_page(pdf_stream_test_pdf, document_df, text_depth): + extracted_data = pdfium_extractor( pdf_stream_test_pdf, extract_text=True, extract_images=False, @@ -111,8 +111,8 @@ def test_pdfium_text_depth_page(pdf_stream_test_pdf, document_df, text_depth): @pytest.mark.xfail(reason="PDFium conversion required") -def test_pdfium_extract_image(pdf_stream_test_pdf, document_df): - extracted_data = pdfium( +def test_pdfium_extractor_extract_image(pdf_stream_test_pdf, document_df): + extracted_data = pdfium_extractor( pdf_stream_test_pdf, extract_text=True, extract_images=True, @@ -147,8 +147,8 @@ def read_markdown_table(table_str: str) -> pd.DataFrame: @pytest.mark.xfail(reason="PDFium conversion required") -def test_pdfium_table_extraction_on_pdf_with_no_tables(pdf_stream_test_pdf, document_df): - extracted_data = pdfium( +def test_pdfium_extractor_table_extraction_on_pdf_with_no_tables(pdf_stream_test_pdf, document_df): + extracted_data = pdfium_extractor( pdf_stream_test_pdf, extract_text=False, extract_images=False, @@ -162,11 +162,11 @@ def test_pdfium_table_extraction_on_pdf_with_no_tables(pdf_stream_test_pdf, docu @pytest.mark.xfail(reason="PDFium conversion required") -def test_pdfium_table_extraction_on_pdf_with_tables(pdf_stream_embedded_tables_pdf, document_df): +def test_pdfium_extractor_table_extraction_on_pdf_with_tables(pdf_stream_embedded_tables_pdf, document_df): """ Test to ensure pdfium's table extraction is able to extract easy-to-read tables from a PDF. 
""" - extracted_data = pdfium( + extracted_data = pdfium_extractor( pdf_stream_embedded_tables_pdf, extract_text=False, extract_images=False, diff --git a/tests/nv_ingest/schemas/test_image_caption_extraction_schema.py b/tests/nv_ingest/schemas/test_image_caption_extraction_schema.py new file mode 100644 index 00000000..c059a253 --- /dev/null +++ b/tests/nv_ingest/schemas/test_image_caption_extraction_schema.py @@ -0,0 +1,70 @@ +import pytest +from pydantic import ValidationError + +from nv_ingest.schemas import ImageCaptionExtractionSchema + + +def test_valid_schema(): + # Test with all required fields and optional defaults + valid_data = { + "api_key": "your-api-key-here", + } + schema = ImageCaptionExtractionSchema(**valid_data) + assert schema.api_key == "your-api-key-here" + assert schema.endpoint_url == "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions" + assert schema.prompt == "Caption the content of this image:" + assert schema.raise_on_failure is False + + +def test_valid_schema_with_custom_values(): + # Test with all fields including custom values for optional fields + valid_data = { + "api_key": "your-api-key-here", + "endpoint_url": "https://custom.api.endpoint", + "prompt": "Describe the image:", + "raise_on_failure": True, + } + schema = ImageCaptionExtractionSchema(**valid_data) + assert schema.api_key == "your-api-key-here" + assert schema.endpoint_url == "https://custom.api.endpoint" + assert schema.prompt == "Describe the image:" + assert schema.raise_on_failure is True + + +def test_missing_api_key(): + # Test with missing required field `api_key` + with pytest.raises(ValidationError) as exc_info: + ImageCaptionExtractionSchema() + assert "field required" in str(exc_info.value) + + +def test_invalid_extra_field(): + # Test with an additional field that should be forbidden + data_with_extra_field = { + "api_key": "your-api-key-here", + "extra_field": "should_not_be_allowed" + } + with pytest.raises(ValidationError) as exc_info: + ImageCaptionExtractionSchema(**data_with_extra_field) + assert "extra fields not permitted" in str(exc_info.value) + + +def test_invalid_field_types(): + # Test with wrong types for optional fields + invalid_data = { + "api_key": "your-api-key-here", + "endpoint_url": 12345, # invalid type + "prompt": 123, # invalid type + "raise_on_failure": "not_boolean" # invalid type + } + with pytest.raises(ValidationError) as exc_info: + ImageCaptionExtractionSchema(**invalid_data) + + +def test_default_values(): + # Test that default values are correctly assigned when not provided + data = {"api_key": "your-api-key-here"} + schema = ImageCaptionExtractionSchema(**data) + assert schema.endpoint_url == "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions" + assert schema.prompt == "Caption the content of this image:" + assert schema.raise_on_failure is False diff --git a/tests/nv_ingest/schemas/test_image_extrator_schema.py b/tests/nv_ingest/schemas/test_image_extrator_schema.py new file mode 100644 index 00000000..bf1ea2d6 --- /dev/null +++ b/tests/nv_ingest/schemas/test_image_extrator_schema.py @@ -0,0 +1,119 @@ +import pytest +from pydantic import ValidationError + +from nv_ingest.schemas.image_extractor_schema import ImageConfigSchema, ImageExtractorSchema + + +def test_image_config_schema_valid(): + # Test valid data with both gRPC and HTTP endpoints + config = ImageConfigSchema( + auth_token="token123", + yolox_endpoints=("grpc_service_url", "http_service_url"), + 
yolox_infer_protocol="http" + ) + assert config.auth_token == "token123" + assert config.yolox_endpoints == ("grpc_service_url", "http_service_url") + assert config.yolox_infer_protocol == "http" + + +def test_image_config_schema_valid_single_service(): + # Test valid data with only gRPC service + config = ImageConfigSchema( + yolox_endpoints=("grpc_service_url", None), + ) + assert config.yolox_endpoints == ("grpc_service_url", None) + assert config.yolox_infer_protocol == "grpc" + + # Test valid data with only HTTP service + config = ImageConfigSchema( + yolox_endpoints=(None, "http_service_url"), + ) + assert config.yolox_endpoints == (None, "http_service_url") + assert config.yolox_infer_protocol == "http" + + +def test_image_config_schema_invalid_both_services_empty(): + # Test invalid data with both gRPC and HTTP services empty + with pytest.raises(ValidationError) as exc_info: + ImageConfigSchema(yolox_endpoints=(None, None)) + errors = exc_info.value.errors() + assert any("Both gRPC and HTTP services cannot be empty" in error['msg'] for error in errors) + + +def test_image_config_schema_empty_service_strings(): + # Test services that are empty strings or whitespace + config = ImageConfigSchema( + yolox_endpoints=(" ", "http_service_url") + ) + assert config.yolox_endpoints == (None, "http_service_url") # Cleaned empty strings are None + + +def test_image_config_schema_missing_infer_protocol(): + # Test infer_protocol default setting based on available service + config = ImageConfigSchema( + yolox_endpoints=("grpc_service_url", None) + ) + assert config.yolox_infer_protocol == "grpc" + + +def test_image_config_schema_extra_field(): + # Test extra fields raise a validation error + with pytest.raises(ValidationError): + ImageConfigSchema( + auth_token="token123", + yolox_endpoints=("grpc_service_url", "http_service_url"), + extra_field="should_not_be_allowed" + ) + + +def test_image_extractor_schema_valid(): + # Test valid data for ImageExtractorSchema with nested ImageConfigSchema + config = ImageExtractorSchema( + max_queue_size=10, + n_workers=4, + raise_on_failure=True, + image_extraction_config=ImageConfigSchema( + auth_token="token123", + yolox_endpoints=("grpc_service_url", "http_service_url"), + yolox_infer_protocol="http" + ) + ) + assert config.max_queue_size == 10 + assert config.n_workers == 4 + assert config.raise_on_failure is True + assert config.image_extraction_config.auth_token == "token123" + + +def test_image_extractor_schema_defaults(): + # Test default values for optional fields + config = ImageExtractorSchema() + assert config.max_queue_size == 1 + assert config.n_workers == 16 + assert config.raise_on_failure is False + assert config.image_extraction_config is None + + +def test_image_extractor_schema_invalid_max_queue_size(): + # Test invalid type for max_queue_size + with pytest.raises(ValidationError) as exc_info: + ImageExtractorSchema(max_queue_size="invalid_type") + errors = exc_info.value.errors() + assert any(error['loc'] == ('max_queue_size',) and error['type'] == 'type_error.integer' for error in errors) + + +def test_image_extractor_schema_invalid_n_workers(): + # Test invalid type for n_workers + with pytest.raises(ValidationError) as exc_info: + ImageExtractorSchema(n_workers="invalid_type") + errors = exc_info.value.errors() + assert any(error['loc'] == ('n_workers',) and error['type'] == 'type_error.integer' for error in errors) + + +def test_image_extractor_schema_invalid_nested_config(): + # Test invalid nested image_extraction_config + with 
pytest.raises(ValidationError) as exc_info: + ImageExtractorSchema( + image_extraction_config={"auth_token": "token123", "yolox_endpoints": (None, None)} + ) + errors = exc_info.value.errors() + assert any("Both gRPC and HTTP services cannot be empty" in error['msg'] for error in errors) diff --git a/tests/nv_ingest/stages/__init__.py b/tests/nv_ingest/stages/__init__.py new file mode 100644 index 00000000..6a35633c --- /dev/null +++ b/tests/nv_ingest/stages/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + diff --git a/tests/nv_ingest/stages/nims/__init__.py b/tests/nv_ingest/stages/nims/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/stages/nims/test_chart_extraction.py b/tests/nv_ingest/stages/nims/test_chart_extraction.py similarity index 100% rename from tests/stages/nims/test_chart_extraction.py rename to tests/nv_ingest/stages/nims/test_chart_extraction.py diff --git a/tests/stages/nims/test_table_extraction.py b/tests/nv_ingest/stages/nims/test_table_extraction.py similarity index 100% rename from tests/stages/nims/test_table_extraction.py rename to tests/nv_ingest/stages/nims/test_table_extraction.py diff --git a/tests/nv_ingest/stages/test_image_extractor_stage.py b/tests/nv_ingest/stages/test_image_extractor_stage.py new file mode 100644 index 00000000..6965ec48 --- /dev/null +++ b/tests/nv_ingest/stages/test_image_extractor_stage.py @@ -0,0 +1,229 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +import base64 + +import pytest +from unittest.mock import patch, MagicMock +import pandas as pd + +from nv_ingest.stages.extractors.image_extractor_stage import process_image +from nv_ingest.stages.extractors.image_extractor_stage import decode_and_extract + +MODULE_UNDER_TEST = 'nv_ingest.stages.extractors.image_extractor_stage' + + +# Define the test function using pytest +@patch(f'{MODULE_UNDER_TEST}.decode_and_extract') +def test_process_image_single_row(mock_decode_and_extract): + mock_decode_and_extract.return_value = [ + {"document_type": "type1", "metadata": {"key": "value"}, "uuid": "1234"} + ] + + input_df = pd.DataFrame({ + "source_id": [1], + "content": ["base64encodedstring"] + }) + + task_props = {"method": "some_method"} + validated_config = MagicMock() + trace_info = {} + + processed_df, trace_info_output = process_image(input_df, task_props, validated_config, trace_info) + + assert len(processed_df) == 1 + assert "document_type" in processed_df.columns + assert "metadata" in processed_df.columns + assert "uuid" in processed_df.columns + assert processed_df.iloc[0]["document_type"] == "type1" + assert processed_df.iloc[0]["metadata"] == {"key": "value"} + assert processed_df.iloc[0]["uuid"] == "1234" + assert trace_info_output["trace_info"] == trace_info + + +@patch(f'{MODULE_UNDER_TEST}.decode_and_extract') +def test_process_image_empty_dataframe(mock_decode_and_extract): + mock_decode_and_extract.return_value = [] + + input_df = pd.DataFrame(columns=["source_id", "content"]) + task_props = {"method": "some_method"} + validated_config = MagicMock() + trace_info = {} + + processed_df, trace_info_output = process_image(input_df, task_props, validated_config, trace_info) + + assert processed_df.empty + assert "document_type" in processed_df.columns + assert "metadata" in processed_df.columns + assert "uuid" in 
processed_df.columns + assert trace_info_output["trace_info"] == trace_info + + +@patch(f'{MODULE_UNDER_TEST}.decode_and_extract') +def test_process_image_multiple_rows(mock_decode_and_extract): + mock_decode_and_extract.side_effect = [ + [{"document_type": "type1", "metadata": {"key": "value1"}, "uuid": "1234"}], + [{"document_type": "type2", "metadata": {"key": "value2"}, "uuid": "5678"}] + ] + + input_df = pd.DataFrame({ + "source_id": [1, 2], + "content": ["base64encodedstring1", "base64encodedstring2"] + }) + + task_props = {"method": "some_method"} + validated_config = MagicMock() + trace_info = {} + + processed_df, trace_info_output = process_image(input_df, task_props, validated_config, trace_info) + + assert len(processed_df) == 2 + assert processed_df.iloc[0]["document_type"] == "type1" + assert processed_df.iloc[0]["metadata"] == {"key": "value1"} + assert processed_df.iloc[0]["uuid"] == "1234" + assert processed_df.iloc[1]["document_type"] == "type2" + assert processed_df.iloc[1]["metadata"] == {"key": "value2"} + assert processed_df.iloc[1]["uuid"] == "5678" + assert trace_info_output["trace_info"] == trace_info + + +@patch(f'{MODULE_UNDER_TEST}.decode_and_extract') +def test_process_image_with_exception(mock_decode_and_extract): + mock_decode_and_extract.side_effect = Exception("Decoding error") + + input_df = pd.DataFrame({ + "source_id": [1], + "content": ["base64encodedstring"] + }) + + task_props = {"method": "some_method"} + validated_config = MagicMock() + trace_info = {} + + with pytest.raises(Exception) as excinfo: + process_image(input_df, task_props, validated_config, trace_info) + + assert "Decoding error" in str(excinfo.value) + + +@patch(f'{MODULE_UNDER_TEST}.image_helpers') +def test_decode_and_extract_valid_method(mock_image_helpers): + # Mock the extraction function inside image_helpers + mock_func = MagicMock(return_value="extracted_data") + mock_image_helpers.image = mock_func # Default extraction method + + # Sample inputs as a pandas Series (row) + base64_content = base64.b64encode(b"dummy_image_data").decode('utf-8') + base64_row = pd.Series({ + "content": base64_content, + "document_type": "image", + "source_id": 1 + }) + task_props = {"method": "image", "params": {}} + validated_config = MagicMock() + trace_info = [] + + # Call the function + result = decode_and_extract(base64_row, task_props, validated_config, default="image", trace_info=trace_info) + + # Assert that the mocked function was called correctly + assert result == "extracted_data" + mock_func.assert_called_once() + + +@patch(f'{MODULE_UNDER_TEST}.image_helpers') +def test_decode_and_extract_missing_content_key(mock_image_helpers): + # Sample inputs with missing 'content' key as a pandas Series (row) + base64_row = pd.Series({ + "document_type": "image", + "source_id": 1 + }) + task_props = {"method": "image", "params": {}} + validated_config = MagicMock() + trace_info = [] + + # Expecting a KeyError + with pytest.raises(KeyError): + decode_and_extract(base64_row, task_props, validated_config, trace_info=trace_info) + + +@patch(f'{MODULE_UNDER_TEST}.image_helpers') +def test_decode_and_extract_fallback_to_default_method(mock_image_helpers): + # Mock only the default method; other methods will appear as non-existent + mock_default_func = MagicMock(return_value="default_extracted_data") + setattr(mock_image_helpers, 'default', mock_default_func) + + # Ensure that non_existing_method does not exist on mock_image_helpers + if hasattr(mock_image_helpers, "non_existing_method"): + 
delattr(mock_image_helpers, "non_existing_method") + + # Input with a non-existing extraction method as a pandas Series (row) + base64_content = base64.b64encode(b"dummy_image_data").decode('utf-8') + base64_row = pd.Series({ + "content": base64_content, + "document_type": "image", + "source_id": 1 + }) + task_props = {"method": "non_existing_method", "params": {}} + validated_config = MagicMock() + trace_info = [] + + # Call the function + result = decode_and_extract(base64_row, task_props, validated_config, default="default", trace_info=trace_info) + + # Assert that the default function was called instead of the missing one + assert result == "default_extracted_data" + mock_default_func.assert_called_once() + + +@patch(f'{MODULE_UNDER_TEST}.image_helpers') +def test_decode_and_extract_with_trace_info(mock_image_helpers): + # Mock the extraction function with trace_info usage + mock_func = MagicMock(return_value="extracted_data_with_trace") + mock_image_helpers.image = mock_func # Default extraction method + + # Sample inputs with trace_info as a pandas Series (row) + base64_content = base64.b64encode(b"dummy_image_data").decode('utf-8') + base64_row = pd.Series({ + "content": base64_content, + "document_type": "image", + "source_id": 1 + }) + task_props = {"method": "image", "params": {}} + validated_config = MagicMock() + trace_info = [{"some": "trace_info"}] + + # Call the function + result = decode_and_extract(base64_row, task_props, validated_config, trace_info=trace_info) + + # Assert that the mocked function was called with trace_info in params + assert result == "extracted_data_with_trace" + mock_func.assert_called_once() + _, _, kwargs = mock_func.mock_calls[0] + assert "trace_info" in kwargs + assert kwargs["trace_info"] == trace_info + + +@patch(f'{MODULE_UNDER_TEST}.image_helpers') +def test_decode_and_extract_handles_exception_in_extraction(mock_image_helpers): + # Mock the extraction function (using a valid method) to raise an exception + mock_func = MagicMock(side_effect=Exception("Extraction error")) + mock_image_helpers.image = mock_func # Use the default method or a valid method + + # Sample inputs as a pandas Series (row) + base64_content = base64.b64encode(b"dummy_image_data").decode('utf-8') + base64_row = pd.Series({ + "content": base64_content, + "document_type": "image", + "source_id": 1 + }) + task_props = {"method": "image", "params": {}} # Use a valid method name + validated_config = MagicMock() + trace_info = [] + + # Expecting an exception during extraction + with pytest.raises(Exception) as excinfo: + decode_and_extract(base64_row, task_props, validated_config, trace_info=trace_info) + + # Verify the exception message + assert str(excinfo.value) == "Extraction error" diff --git a/tests/nv_ingest/stages/transforms/__init__.py b/tests/nv_ingest/stages/transforms/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nv_ingest/stages/transforms/test_image_caption_extraction.py b/tests/nv_ingest/stages/transforms/test_image_caption_extraction.py new file mode 100644 index 00000000..5b946566 --- /dev/null +++ b/tests/nv_ingest/stages/transforms/test_image_caption_extraction.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +import base64 +import io + +import requests +from PIL import Image +from unittest.mock import MagicMock, patch + +import pytest + +MODULE_UNDER_TEST = 'nv_ingest.stages.transforms.image_caption_extraction' + +import pandas as pd + +from nv_ingest.schemas.metadata_schema import ContentTypeEnum +from nv_ingest.stages.transforms.image_caption_extraction import _prepare_dataframes_mod +from nv_ingest.stages.transforms.image_caption_extraction import _generate_captions +from nv_ingest.stages.transforms.image_caption_extraction import caption_extract_stage + + +def generate_base64_png_image() -> str: + """Helper function to generate a base64-encoded PNG image.""" + img = Image.new("RGB", (10, 10), color="blue") # Create a simple blue image + buffered = io.BytesIO() + img.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + +def test_prepare_dataframes_empty_dataframe(): + # Test with an empty DataFrame + df = pd.DataFrame() + + df_out, df_matched, bool_index = _prepare_dataframes_mod(df) + + assert df_out.equals(df) + assert df_matched.empty + assert bool_index.empty + assert bool_index.dtype == bool + + +def test_prepare_dataframes_missing_document_type_column(): + # Test with a DataFrame missing the 'document_type' column + df = pd.DataFrame({ + "other_column": [1, 2, 3] + }) + + df_out, df_matched, bool_index = _prepare_dataframes_mod(df) + + assert df_out.equals(df) + assert df_matched.empty + assert bool_index.empty + assert bool_index.dtype == bool + + +def test_prepare_dataframes_no_matches(): + # Test with a DataFrame where no 'document_type' matches ContentTypeEnum.IMAGE + df = pd.DataFrame({ + "document_type": [ContentTypeEnum.TEXT, ContentTypeEnum.STRUCTURED, ContentTypeEnum.UNSTRUCTURED] + }) + + df_out, df_matched, bool_index = _prepare_dataframes_mod(df) + + assert df_out.equals(df) + assert df_matched.empty + assert bool_index.equals(pd.Series([False, False, False])) + assert bool_index.dtype == bool + + +def test_prepare_dataframes_partial_matches(): + # Test with a DataFrame where some rows match ContentTypeEnum.IMAGE + df = pd.DataFrame({ + "document_type": [ContentTypeEnum.IMAGE, ContentTypeEnum.TEXT, ContentTypeEnum.IMAGE] + }) + + df_out, df_matched, bool_index = _prepare_dataframes_mod(df) + + assert df_out.equals(df) + assert not df_matched.empty + assert df_matched.equals(df[df["document_type"] == ContentTypeEnum.IMAGE]) + assert bool_index.equals(pd.Series([True, False, True])) + assert bool_index.dtype == bool + + +def test_prepare_dataframes_all_matches(): + # Test with a DataFrame where all rows match ContentTypeEnum.IMAGE + df = pd.DataFrame({ + "document_type": [ContentTypeEnum.IMAGE, ContentTypeEnum.IMAGE, ContentTypeEnum.IMAGE] + }) + + df_out, df_matched, bool_index = _prepare_dataframes_mod(df) + + assert df_out.equals(df) + assert df_matched.equals(df) + assert bool_index.equals(pd.Series([True, True, True])) + assert bool_index.dtype == bool + + +@patch(f'{MODULE_UNDER_TEST}._generate_captions') +def test_caption_extract_no_image_content(mock_generate_captions): + # DataFrame with no image content + df = pd.DataFrame({ + "metadata": [{"content_metadata": {"type": "text"}}, {"content_metadata": {"type": "pdf"}}] + }) + task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"} + validated_config = MagicMock() + trace_info = {} + + # Call the function + result_df = caption_extract_stage(df, task_props, validated_config, 
trace_info) + + # Check that _generate_captions was not called and df is unchanged + mock_generate_captions.assert_not_called() + assert result_df.equals(df) + + +@patch(f'{MODULE_UNDER_TEST}._generate_captions') +def test_caption_extract_with_image_content(mock_generate_captions): + # Mock caption generation + mock_generate_captions.return_value = "A description of the image." + + # DataFrame with image content + df = pd.DataFrame({ + "metadata": [{"content_metadata": {"type": "image"}, "content": "base64_encoded_image_data"}] + }) + task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"} + validated_config = MagicMock() + trace_info = {} + + # Call the function + result_df = caption_extract_stage(df, task_props, validated_config, trace_info) + + # Check that _generate_captions was called once + mock_generate_captions.assert_called_once_with("base64_encoded_image_data", "Describe the image", "test_api_key", + "https://api.example.com") + + # Verify that the caption was added to image_metadata + assert result_df.loc[0, "metadata"]["image_metadata"]["caption"] == "A description of the image." + + +@patch(f'{MODULE_UNDER_TEST}._generate_captions') +def test_caption_extract_mixed_content(mock_generate_captions): + # Mock caption generation + mock_generate_captions.return_value = "A description of the image." + + # DataFrame with mixed content types + df = pd.DataFrame({ + "metadata": [ + {"content_metadata": {"type": "image"}, "content": "image_data_1"}, + {"content_metadata": {"type": "text"}, "content": "text_data"}, + {"content_metadata": {"type": "image"}, "content": "image_data_2"} + ] + }) + task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"} + validated_config = MagicMock() + trace_info = {} + + # Call the function + result_df = caption_extract_stage(df, task_props, validated_config, trace_info) + + # Check that _generate_captions was called twice for images only + assert mock_generate_captions.call_count == 2 + mock_generate_captions.assert_any_call("image_data_1", "Describe the image", "test_api_key", + "https://api.example.com") + mock_generate_captions.assert_any_call("image_data_2", "Describe the image", "test_api_key", + "https://api.example.com") + + # Verify that captions were added only for image rows + assert result_df.loc[0, "metadata"]["image_metadata"]["caption"] == "A description of the image." + assert "caption" not in result_df.loc[1, "metadata"].get("image_metadata", {}) + assert result_df.loc[2, "metadata"]["image_metadata"]["caption"] == "A description of the image." + + +@patch(f'{MODULE_UNDER_TEST}._generate_captions') +def test_caption_extract_empty_dataframe(mock_generate_captions): + # Empty DataFrame + df = pd.DataFrame(columns=["metadata"]) + task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"} + validated_config = MagicMock() + trace_info = {} + + # Call the function + result_df = caption_extract_stage(df, task_props, validated_config, trace_info) + + # Check that _generate_captions was not called and df is still empty + mock_generate_captions.assert_not_called() + assert result_df.empty + + +@patch(f'{MODULE_UNDER_TEST}._generate_captions') +def test_caption_extract_malformed_metadata(mock_generate_captions): + # Mock caption generation + mock_generate_captions.return_value = "A description of the image." 
+ + # DataFrame with malformed metadata (missing 'content' key in one row) + df = pd.DataFrame({ + "metadata": [{"unexpected_key": "value"}, {"content_metadata": {"type": "image"}}] + }) + task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"} + validated_config = MagicMock() + trace_info = {} + + # Expecting KeyError for missing 'content' in the second row + with pytest.raises(KeyError, match="'content'"): + caption_extract_stage(df, task_props, validated_config, trace_info) + + +@patch(f"{MODULE_UNDER_TEST}.requests.post") +def test_generate_captions_successful(mock_post): + # Mock the successful API response + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = { + "choices": [ + {"message": {"content": "A beautiful sunset over the mountains."}} + ] + } + mock_post.return_value = mock_response + + # Parameters + base64_image = generate_base64_png_image() + prompt = "Describe the image" + api_key = "test_api_key" + endpoint_url = "https://api.example.com" + + # Call the function + result = _generate_captions(base64_image, prompt, api_key, endpoint_url) + + # Verify that the correct caption was returned + assert result == "A beautiful sunset over the mountains." + mock_post.assert_called_once_with( + endpoint_url, + headers={"Authorization": f"Bearer {api_key}", "Accept": "application/json"}, + json={ + "model": 'meta/llama-3.2-90b-vision-instruct', + "messages": [ + { + "role": "user", + "content": f'{prompt} <img src="data:image/png;base64,{base64_image}" />' + } + ], + "max_tokens": 512, + "temperature": 1.00, + "top_p": 1.00, + "stream": False + } + ) + + +@patch(f"{MODULE_UNDER_TEST}.requests.post") +def test_generate_captions_api_error(mock_post): + # Mock a 500 Internal Server Error response + mock_response = MagicMock() + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("500 Server Error") + mock_post.return_value = mock_response + + # Parameters + base64_image = generate_base64_png_image() + prompt = "Describe the image" + api_key = "test_api_key" + endpoint_url = "https://api.example.com" + + # Expect an exception due to the server error + with pytest.raises(requests.exceptions.RequestException, match="500 Server Error"): + _generate_captions(base64_image, prompt, api_key, endpoint_url) + + +@patch(f"{MODULE_UNDER_TEST}.requests.post") +def test_generate_captions_malformed_json(mock_post): + # Mock a response with an unexpected JSON structure + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = {"unexpected_key": "unexpected_value"} + mock_post.return_value = mock_response + + # Parameters + base64_image = generate_base64_png_image() + prompt = "Describe the image" + api_key = "test_api_key" + endpoint_url = "https://api.example.com" + + # Call the function + result = _generate_captions(base64_image, prompt, api_key, endpoint_url) + + # Verify fallback response when JSON is malformed + assert result == "No caption returned" + + +@patch(f"{MODULE_UNDER_TEST}.requests.post") +def test_generate_captions_empty_caption_content(mock_post): + # Mock a response with empty caption content + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = { + "choices": [ + {"message": {"content": ""}} + ] + } + mock_post.return_value = mock_response + + # Parameters + base64_image = generate_base64_png_image() + prompt = "Describe the image" + api_key = "test_api_key" + endpoint_url =
"https://api.example.com" + + # Call the function + result = _generate_captions(base64_image, prompt, api_key, endpoint_url) + + # Verify that the fallback response is returned + assert result == "" diff --git a/tests/nv_ingest/util/image_processing/test_transforms.py b/tests/nv_ingest/util/image_processing/test_transforms.py index ad86c191..55b51585 100644 --- a/tests/nv_ingest/util/image_processing/test_transforms.py +++ b/tests/nv_ingest/util/image_processing/test_transforms.py @@ -2,14 +2,18 @@ # All rights reserved. # SPDX-License-Identifier: Apache-2.0 -import pytest -import numpy as np -from PIL import Image import base64 +import io +import numpy as np +import pytest + from io import BytesIO +from PIL import Image +from typing import Tuple from unittest import mock -from nv_ingest.util.image_processing.transforms import numpy_to_base64, base64_to_numpy, check_numpy_image_size +from nv_ingest.util.image_processing.transforms import numpy_to_base64, base64_to_numpy, check_numpy_image_size, \ + scale_image_to_encoding_size, ensure_base64_is_png # Helper function to create a base64-encoded string from an image @@ -107,3 +111,125 @@ def test_check_numpy_image_size_invalid_dimensions(): img = np.zeros((100,), dtype=np.uint8) # 1D array with pytest.raises(ValueError, match="The input array does not have sufficient dimensions for an image."): check_numpy_image_size(img, 50, 50) + + +def generate_base64_image(size: Tuple[int, int]) -> str: + """Helper function to generate a base64-encoded PNG image of a specific size.""" + img = Image.new("RGB", size, color="blue") # Create a simple blue image + buffered = io.BytesIO() + img.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + +def generate_base64_image_with_format(format: str = 'PNG', size: Tuple[int, int] = (100, 100)) -> str: + """Helper function to generate a base64-encoded image of a specified format and size.""" + img = Image.new("RGB", size, color="blue") # Simple blue image + buffered = io.BytesIO() + img.save(buffered, format=format) + return base64.b64encode(buffered.getvalue()).decode("utf-8") + + +def test_resize_image_within_size_limit(): + # Generate a base64 image within the size limit + base64_image = generate_base64_image((100, 100)) # Small image + max_base64_size = len(base64_image) + 10 # Set limit slightly above image size + + result = scale_image_to_encoding_size(base64_image, max_base64_size) + assert result == base64_image # Should return unchanged + + +def test_resize_image_one_resize_needed(): + # Generate a large base64 image that requires resizing + base64_image = generate_base64_image((500, 500)) + max_base64_size = len(base64_image) - 1000 # Set limit slightly below current size + + result = scale_image_to_encoding_size(base64_image, max_base64_size) + assert len(result) <= max_base64_size # Should be resized within limit + + +def test_resize_image_multiple_resizes_needed(): + # Generate a very large base64 image that will require multiple reductions + base64_image = generate_base64_image((1000, 1000)) + max_base64_size = len(base64_image) // 2 # Set limit well below current size + + result = scale_image_to_encoding_size(base64_image, max_base64_size) + assert len(result) <= max_base64_size # Final size should be within limit + + +def test_resize_image_cannot_be_resized_below_limit(): + # Generate a small base64 image where further resizing would be impractical + base64_image = generate_base64_image((10, 10)) + max_base64_size = 1 # Unreachable size limit + + with 
pytest.raises(ValueError, match="height and width must be > 0"): + scale_image_to_encoding_size(base64_image, max_base64_size) + + +def test_resize_image_edge_case_minimal_reduction(): + # Generate an image just above the size limit + base64_image = generate_base64_image((500, 500)) + max_base64_size = len(base64_image) - 50 # Just a slight reduction needed + + result = scale_image_to_encoding_size(base64_image, max_base64_size) + assert len(result) <= max_base64_size # Should achieve minimal reduction within limit + + +def test_resize_image_with_invalid_input(): + # Provide non-image data as input + non_image_base64 = base64.b64encode(b"This is not an image").decode("utf-8") + + with pytest.raises(Exception): + scale_image_to_encoding_size(non_image_base64) + + +def test_ensure_base64_is_png_already_png(): + # Generate a base64-encoded PNG image + base64_image = generate_base64_image_with_format("PNG") + + result = ensure_base64_is_png(base64_image) + assert result == base64_image # Should be unchanged + + +def test_ensure_base64_is_png_jpeg_to_png_conversion(): + # Generate a base64-encoded JPEG image + base64_image = generate_base64_image_with_format("JPEG") + + result = ensure_base64_is_png(base64_image) + + # Decode the result and check format + image_data = base64.b64decode(result) + image = Image.open(io.BytesIO(image_data)) + assert image.format == "PNG" # Should be converted to PNG + + +def test_ensure_base64_is_png_invalid_base64(): + # Provide invalid base64 input + invalid_base64 = "This is not base64 encoded data" + + result = ensure_base64_is_png(invalid_base64) + assert result is None # Should return None for invalid input + + +def test_ensure_base64_is_png_non_image_base64_data(): + # Provide valid base64 data that isn’t an image + non_image_base64 = base64.b64encode(b"This is not an image").decode("utf-8") + + result = ensure_base64_is_png(non_image_base64) + assert result is None # Should return None for non-image data + + +def test_ensure_base64_is_png_unsupported_format(): + # Generate an image in a rare format and base64 encode it + img = Image.new("RGB", (100, 100), color="blue") + buffered = io.BytesIO() + img.save(buffered, format="BMP") # Use an uncommon format like BMP + base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8") + + result = ensure_base64_is_png(base64_image) + # Decode result to verify conversion + if result: + image_data = base64.b64decode(result) + image = Image.open(io.BytesIO(image_data)) + assert image.format == "PNG" # Should be converted to PNG if supported + else: + assert result is None # If unsupported, result should be None diff --git a/tests/nv_ingest_client/primitives/tasks/test_caption.py b/tests/nv_ingest_client/primitives/tasks/test_caption.py new file mode 100644 index 00000000..f9af7469 --- /dev/null +++ b/tests/nv_ingest_client/primitives/tasks/test_caption.py @@ -0,0 +1,139 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from pydantic import ValidationError +from nv_ingest_client.primitives.tasks.caption import CaptionTaskSchema, CaptionTask + + +# Testing CaptionTaskSchema + +def test_valid_schema_initialization(): + """Test valid initialization of CaptionTaskSchema with all fields.""" + schema = CaptionTaskSchema(api_key="test_key", endpoint_url="http://example.com", prompt="Generate a caption") + assert schema.api_key == "test_key" + assert schema.endpoint_url == "http://example.com" + assert schema.prompt == "Generate a caption" + + +def test_partial_schema_initialization(): + """Test valid initialization of CaptionTaskSchema with some fields omitted.""" + schema = CaptionTaskSchema(api_key="test_key") + assert schema.api_key == "test_key" + assert schema.endpoint_url is None + assert schema.prompt is None + + +def test_empty_schema_initialization(): + """Test valid initialization of CaptionTaskSchema with no fields.""" + schema = CaptionTaskSchema() + assert schema.api_key is None + assert schema.endpoint_url is None + assert schema.prompt is None + + +def test_schema_invalid_extra_field(): + """Test that CaptionTaskSchema raises an error with extra fields.""" + try: + CaptionTaskSchema(api_key="test_key", extra_field="invalid") + raise AssertionError("CaptionTaskSchema accepted an unexpected extra field") + except ValidationError as e: + assert "extra_field" in str(e) + + +# Testing CaptionTask + +def test_caption_task_initialization(): + """Test initializing CaptionTask with all fields.""" + task = CaptionTask(api_key="test_key", endpoint_url="http://example.com", prompt="Generate a caption") + assert task._api_key == "test_key" + assert task._endpoint_url == "http://example.com" + assert task._prompt == "Generate a caption" + + +def test_caption_task_partial_initialization(): + """Test initializing CaptionTask with some fields omitted.""" + task = CaptionTask(api_key="test_key") + assert task._api_key == "test_key" + assert task._endpoint_url is None + assert task._prompt is None + + +def test_caption_task_empty_initialization(): + """Test initializing CaptionTask with no fields.""" + task = CaptionTask() + assert task._api_key is None + assert task._endpoint_url is None + assert task._prompt is None + + +def test_caption_task_str_representation_all_fields(): + """Test string representation of CaptionTask with all fields.""" + task = CaptionTask(api_key="test_key", endpoint_url="http://example.com", prompt="Generate a caption") + task_str = str(task) + assert "Image Caption Task:" in task_str + assert "api_key: [redacted]" in task_str + assert "endpoint_url: http://example.com" in task_str + assert "prompt: Generate a caption" in task_str + + +def test_caption_task_str_representation_partial_fields(): + """Test string representation of CaptionTask with partial fields.""" + task = CaptionTask(api_key="test_key") + task_str = str(task) + assert "Image Caption Task:" in task_str + assert "api_key: [redacted]" in task_str + assert "endpoint_url" not in task_str + assert "prompt" not in task_str + + +def test_caption_task_to_dict_all_fields(): + """Test to_dict method of CaptionTask with all fields.""" + task = CaptionTask(api_key="test_key", endpoint_url="http://example.com", prompt="Generate a caption") + task_dict = task.to_dict() + assert task_dict == { + "type": "caption", + "task_properties": { + "api_key": "test_key", + "endpoint_url": "http://example.com", + "prompt": "Generate a caption" + } + } + + +def test_caption_task_to_dict_partial_fields(): + """Test to_dict method of CaptionTask with partial fields.""" + task =
CaptionTask(api_key="test_key") + task_dict = task.to_dict() + assert task_dict == { + "type": "caption", + "task_properties": { + "api_key": "test_key" + } + } + + +def test_caption_task_to_dict_empty_fields(): + """Test to_dict method of CaptionTask with no fields.""" + task = CaptionTask() + task_dict = task.to_dict() + assert task_dict == { + "type": "caption", + "task_properties": {} + } + + +# Execute tests +if __name__ == "__main__": + test_valid_schema_initialization() + test_partial_schema_initialization() + test_empty_schema_initialization() + test_schema_invalid_extra_field() + test_caption_task_initialization() + test_caption_task_partial_initialization() + test_caption_task_empty_initialization() + test_caption_task_str_representation_all_fields() + test_caption_task_str_representation_partial_fields() + test_caption_task_to_dict_all_fields() + test_caption_task_to_dict_partial_fields() + test_caption_task_to_dict_empty_fields() + print("All tests passed.")