diff --git a/Dockerfile b/Dockerfile
index a9663eab..c4405279 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -49,6 +49,7 @@ RUN source activate nv_ingest \
&& mamba install -y \
nvidia/label/dev::morpheus-core \
nvidia/label/dev::morpheus-llm \
+ imagemagick \
# pin to earlier version of cuda-python until __pyx_capi__ fix is upstreamed.
cuda-python=12.6.0 \
-c rapidsai -c pytorch -c nvidia -c conda-forge
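Note: imagemagick is added here to back the new Wand-based SVG rasterization in image_handlers.py below. A minimal sanity check, assuming it runs inside the built container (not part of the build itself):

```python
# Quick check that Wand can see the ImageMagick install; names come from the
# wand package, the SVG payload is a hypothetical example.
from wand.image import Image as WandImage
from wand.version import MAGICK_VERSION

print(MAGICK_VERSION)  # e.g. "ImageMagick 7.x ..." when the install is visible

svg = b'<svg xmlns="http://www.w3.org/2000/svg" width="8" height="8"/>'
with WandImage(blob=svg, format="svg") as img:
    img.format = "png"
    assert img.make_blob()  # non-empty PNG bytes confirms the SVG delegate works
```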
diff --git a/client/src/nv_ingest_client/nv_ingest_cli.py b/client/src/nv_ingest_client/nv_ingest_cli.py
index 8b38c10b..d4705273 100644
--- a/client/src/nv_ingest_client/nv_ingest_cli.py
+++ b/client/src/nv_ingest_client/nv_ingest_cli.py
@@ -116,57 +116,61 @@
\b
Tasks and Options:
-- split: Divides documents according to specified criteria.
+- caption: Attempts to extract captions for unstructured images extracted from documents.
Options:
- - split_by (str): Criteria ('page', 'size', 'word', 'sentence'). No default.
- - split_length (int): Segment length. No default.
- - split_overlap (int): Segment overlap. No default.
- - max_character_length (int): Maximum segment character count. No default.
- - sentence_window_size (int): Sentence window size. No default.
+ - api_key (str): API key for captioning service.
+ Default: value of the NVIDIA_BUILD_API_KEY environment variable.
+ - endpoint_url (str): Endpoint URL for captioning service.
+ Default: 'https://build.nvidia.com/meta/llama-3.2-90b-vision-instruct'.
+ - prompt (str): Prompt for captioning service.
+ Default: 'Caption the content of this image:'.
+\b
+- dedup: Identifies and optionally filters duplicate images in extraction.
+ Options:
+ - content_type (str): Content type to deduplicate ('image').
+ - filter (bool): When set to True, duplicates will be filtered; otherwise, an info message will be added.
+\b
+- embed: Computes embeddings on multimodal extractions.
+ Options:
+ - filter_errors (bool): Flag to filter embedding errors. Optional.
+ - tables (bool): Flag to create embeddings for table extractions. Optional.
+ - text (bool): Flag to create embeddings for text extractions. Optional.
\b
- extract: Extracts content from documents, customizable per document type.
Can be specified multiple times for different 'document_type' values.
Options:
- document_type (str): Document format ('pdf', 'docx', 'pptx', 'html', 'xml', 'excel', 'csv', 'parquet'). Required.
- - extract_method (str): Extraction technique. Defaults are smartly chosen based on 'document_type'.
- - extract_text (bool): Enables text extraction. Default: False.
+ - extract_charts (bool): Enables chart extraction. Default: False.
- extract_images (bool): Enables image extraction. Default: False.
+ - extract_method (str): Extraction technique. Defaults are smartly chosen based on 'document_type'.
- extract_tables (bool): Enables table extraction. Default: False.
- - extract_charts (bool): Enables chart extraction. Default: False.
- - text_depth (str): Text extraction granularity ('document', 'page'). Default: 'document'.
+ - extract_text (bool): Enables text extraction. Default: False.
+ - text_depth (str): Text extraction granularity ('document', 'page'). Default: 'document'.
Note: this will affect the granularity of text extraction, and the associated metadata. ie. 'page' will extract
text per page and you will get page-level metadata, 'document' will extract text for the entire document so
elements like page numbers will not be associated with individual text elements.
\b
-- store: Stores any images extracted from documents.
- Options:
- - structured (bool): Flag to write extracted charts and tables to object store.
- - images (bool): Flag to write extracted images to object store.
- - store_method (str): Storage type ('minio', ). Required.
-\b
-- caption: Attempts to extract captions for images extracted from documents. Note: this is not generative, but rather a
- simple extraction.
- Options:
- N/A
-\b
-- dedup: Identifies and optionally filters duplicate images in extraction.
+- filter: Identifies and optionally filters images above or below scale thresholds.
Options:
- - content_type (str): Content type to deduplicate ('image')
- - filter (bool): When set to True, duplicates will be filtered, otherwise, an info message will be added.
+ - content_type (str): Content type to filter ('image').
+ - filter (bool): When set to True, filtered images will be excluded; otherwise, an info message will be added.
+ - max_aspect_ratio (Union[float, int]): Maximum allowable aspect ratio of extracted image.
+ - min_aspect_ratio (Union[float, int]): Minimum allowable aspect ratio of extracted image.
+ - min_size (int): Minimum allowable size of extracted image.
\b
-- filter: Identifies and optionally filters images above or below scale thresholds.
+- split: Divides documents according to specified criteria.
Options:
- - content_type (str): Content type to deduplicate ('image')
- - min_size: (Union[float, int]): Minimum allowable size of extracted image.
- - max_aspect_ratio: (Union[float, int]): Maximum allowable aspect ratio of extracted image.
- - min_aspect_ratio: (Union[float, int]): Minimum allowable aspect ratio of extracted image.
- - filter (bool): When set to True, duplicates will be filtered, otherwise, an info message will be added.
+ - max_character_length (int): Maximum segment character count. No default.
+ - sentence_window_size (int): Sentence window size. No default.
+ - split_by (str): Criteria ('page', 'size', 'word', 'sentence'). No default.
+ - split_length (int): Segment length. No default.
+ - split_overlap (int): Segment overlap. No default.
\b
-- embed: Computes embeddings on multimodal extractions.
+- store: Stores any images extracted from documents.
Options:
- - text (bool): Flag to create embeddings for text extractions. Optional.
- - tables (bool): Flag to creae embeddings for table extractions. Optional.
- - filter_errors (bool): Flag to filter embedding errors. Optional.
+ - images (bool): Flag to write extracted images to object store.
+ - structured (bool): Flag to write extracted charts and tables to object store.
+ - store_method (str): Storage type ('minio'). Required.
\b
- vdb_upload: Uploads extraction embeddings to vector database.
\b
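For reference, the caption task's wire format mirrors the options listed above; a sketch of the payload the client ends up submitting (values hypothetical, shape taken from CaptionTask.to_dict below):

```python
import os

caption_task = {
    "type": "caption",
    "task_properties": {
        # Each property is optional; omitted keys fall back to server-side defaults.
        "api_key": os.environ.get("NVIDIA_BUILD_API_KEY"),
        "endpoint_url": "https://build.nvidia.com/meta/llama-3.2-90b-vision-instruct",
        "prompt": "Caption the content of this image:",
    },
}
```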
diff --git a/client/src/nv_ingest_client/primitives/tasks/caption.py b/client/src/nv_ingest_client/primitives/tasks/caption.py
index 43db01bd..b2883193 100644
--- a/client/src/nv_ingest_client/primitives/tasks/caption.py
+++ b/client/src/nv_ingest_client/primitives/tasks/caption.py
@@ -7,7 +7,7 @@
# pylint: disable=too-many-arguments
import logging
-from typing import Dict
+from typing import Dict, Optional
from pydantic import BaseModel
@@ -17,29 +17,56 @@
class CaptionTaskSchema(BaseModel):
+ api_key: Optional[str] = None
+ endpoint_url: Optional[str] = None
+ prompt: Optional[str] = None
+
class Config:
extra = "forbid"
class CaptionTask(Task):
def __init__(
- self,
+ self,
+ api_key: Optional[str] = None,
+ endpoint_url: Optional[str] = None,
+ prompt: Optional[str] = None,
) -> None:
super().__init__()
+ self._api_key = api_key
+ self._endpoint_url = endpoint_url
+ self._prompt = prompt
+
def __str__(self) -> str:
"""
Returns a string with the object's config and run time state
"""
info = ""
+ info += "Image Caption Task:\n"
+
+ if self._api_key:
+ info += " api_key: [redacted]\n"
+ if self._endpoint_url:
+ info += f" endpoint_url: {self._endpoint_url}\n"
+ if self._prompt:
+ info += f" prompt: {self._prompt}\n"
+
return info
def to_dict(self) -> Dict:
"""
Convert to a dict for submission to redis
"""
- task_properties = {
- "content_type": "image",
- }
+ task_properties = {}
+
+ if self._api_key:
+ task_properties["api_key"] = self._api_key
+
+ if self._endpoint_url:
+ task_properties["endpoint_url"] = self._endpoint_url
+
+ if self._prompt:
+ task_properties["prompt"] = self._prompt
return {"type": "caption", "task_properties": task_properties}
diff --git a/client/src/nv_ingest_client/primitives/tasks/extract.py b/client/src/nv_ingest_client/primitives/tasks/extract.py
index 36f43d04..9070b6fe 100644
--- a/client/src/nv_ingest_client/primitives/tasks/extract.py
+++ b/client/src/nv_ingest_client/primitives/tasks/extract.py
@@ -34,34 +34,46 @@
ADOBE_CLIENT_SECRET = os.environ.get("ADOBE_CLIENT_SECRET", None)
_DEFAULT_EXTRACTOR_MAP = {
- "pdf": "pdfium",
+ "csv": "pandas",
"docx": "python_docx",
- "pptx": "python_pptx",
- "html": "beautifulsoup",
- "xml": "lxml",
"excel": "openpyxl",
- "csv": "pandas",
+ "html": "beautifulsoup",
+ "jpeg": "image",
+ "jpg": "image",
"parquet": "pandas",
+ "pdf": "pdfium",
+ "png": "image",
+ "pptx": "python_pptx",
+ "svg": "image",
+ "tiff": "image",
+ "xml": "lxml",
}
_Type_Extract_Method_PDF = Literal[
- "pdfium",
+ "adobe",
"doughnut",
"haystack",
+ "llama_parse",
+ "pdfium",
"tika",
"unstructured_io",
- "llama_parse",
- "adobe",
]
_Type_Extract_Method_DOCX = Literal["python_docx", "haystack", "unstructured_local", "unstructured_service"]
_Type_Extract_Method_PPTX = Literal["python_pptx", "haystack", "unstructured_local", "unstructured_service"]
+_Type_Extract_Method_Image = Literal["image"]
+
_Type_Extract_Method_Map = {
- "pdf": get_args(_Type_Extract_Method_PDF),
"docx": get_args(_Type_Extract_Method_DOCX),
+ "jpeg": get_args(_Type_Extract_Method_Image),
+ "jpg": get_args(_Type_Extract_Method_Image),
+ "pdf": get_args(_Type_Extract_Method_PDF),
+ "png": get_args(_Type_Extract_Method_Image),
"pptx": get_args(_Type_Extract_Method_PPTX),
+ "svg": get_args(_Type_Extract_Method_Image),
+ "tiff": get_args(_Type_Extract_Method_Image),
}
_Type_Extract_Tables_Method_PDF = Literal["yolox", "pdfium"]
@@ -77,12 +89,14 @@
}
+
+
class ExtractTaskSchema(BaseModel):
document_type: str
extract_method: str = None # Initially allow None to set a smart default
- extract_text: bool = (True,)
- extract_images: bool = (True,)
- extract_tables: bool = False
+ extract_text: bool = True
+ extract_images: bool = True
+ extract_tables: bool = True
extract_tables_method: str = "yolox"
extract_charts: Optional[bool] = None # Initially allow None to set a smart default
text_depth: str = "document"
diff --git a/client/src/nv_ingest_client/util/file_processing/extract.py b/client/src/nv_ingest_client/util/file_processing/extract.py
index 09d20584..97851481 100644
--- a/client/src/nv_ingest_client/util/file_processing/extract.py
+++ b/client/src/nv_ingest_client/util/file_processing/extract.py
@@ -21,16 +21,17 @@
# Enums
class DocumentTypeEnum(str, Enum):
- pdf = "pdf"
- txt = "text"
+ bmp = "bmp"
docx = "docx"
- pptx = "pptx"
+ html = "html"
jpeg = "jpeg"
- bmp = "bmp"
+ md = "md"
+ pdf = "pdf"
png = "png"
+ pptx = "pptx"
svg = "svg"
- html = "html"
- md = "md"
+ tiff = "tiff"
+ txt = "text"
# Maps MIME types to DocumentTypeEnum
@@ -49,19 +50,20 @@ class DocumentTypeEnum(str, Enum):
# Maps file extensions to DocumentTypeEnum
EXTENSION_TO_DOCUMENT_TYPE = {
- "pdf": DocumentTypeEnum.pdf,
- "txt": DocumentTypeEnum.txt,
- "docx": DocumentTypeEnum.docx,
- "pptx": DocumentTypeEnum.pptx,
- "jpg": DocumentTypeEnum.jpeg,
- "jpeg": DocumentTypeEnum.jpeg,
"bmp": DocumentTypeEnum.bmp,
- "png": DocumentTypeEnum.png,
- "svg": DocumentTypeEnum.svg,
+ "docx": DocumentTypeEnum.docx,
"html": DocumentTypeEnum.html,
+ "jpeg": DocumentTypeEnum.jpeg,
+ "jpg": DocumentTypeEnum.jpeg,
+ "json": DocumentTypeEnum.txt,
"md": DocumentTypeEnum.txt,
+ "pdf": DocumentTypeEnum.pdf,
+ "png": DocumentTypeEnum.png,
+ "pptx": DocumentTypeEnum.pptx,
"sh": DocumentTypeEnum.txt,
- "json": DocumentTypeEnum.txt,
+ "svg": DocumentTypeEnum.svg,
+ "tiff": DocumentTypeEnum.tiff,
+ "txt": DocumentTypeEnum.txt,
# Add more as needed
}
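The extension map is what routes client-side files to a DocumentTypeEnum; e.g. the new tiff entry:

```python
from pathlib import Path

from nv_ingest_client.util.file_processing.extract import (
    EXTENSION_TO_DOCUMENT_TYPE,
    DocumentTypeEnum,
)

ext = Path("scan.tiff").suffix.lstrip(".").lower()
assert EXTENSION_TO_DOCUMENT_TYPE[ext] is DocumentTypeEnum.tiff
```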
diff --git a/docker-compose.yaml b/docker-compose.yaml
index a8e2476e..29506958 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -117,7 +117,7 @@ services:
runtime: nvidia
nv-ingest-ms-runtime:
- image: nvcr.io/ohlfw0olaadg/ea-participants/nv-ingest:24.10
+ image: nvcr.io/ohlfw0olaadg/ea-participants/nv-ingest:24.10.1
build:
context: ${NV_INGEST_ROOT:-.}
dockerfile: "./Dockerfile"
@@ -157,6 +157,7 @@ services:
- YOLOX_GRPC_ENDPOINT=yolox:8001
- YOLOX_HTTP_ENDPOINT=http://yolox:8000/v1/infer
- YOLOX_INFER_PROTOCOL=grpc
+ - VLM_CAPTION_ENDPOINT=https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions
healthcheck:
test: curl --fail http://nv-ingest-ms-runtime:7670/v1/health/ready || exit 1
interval: 10s
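The new VLM_CAPTION_ENDPOINT is an OpenAI-style chat-completions URL. A rough sketch of a request against it; the payload shape and inline-image convention are assumptions, not taken from this diff:

```python
import base64
import os

import requests

endpoint = os.environ["VLM_CAPTION_ENDPOINT"]
api_key = os.environ["NVIDIA_BUILD_API_KEY"]  # assumed credential variable

with open("figure.png", "rb") as f:  # hypothetical input image
    b64 = base64.b64encode(f.read()).decode()

# Assumed OpenAI-style chat-completions payload with an inline base64 image.
payload = {
    "model": "meta/llama-3.2-90b-vision-instruct",
    "messages": [
        {
            "role": "user",
            "content": f'Caption the content of this image: <img src="data:image/png;base64,{b64}" />',
        }
    ],
    "max_tokens": 512,
}
resp = requests.post(endpoint, json=payload, headers={"Authorization": f"Bearer {api_key}"})
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```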
diff --git a/requirements.txt b/requirements.txt
index 43c4a6ce..5362d5d5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
aiohttp==3.9.4
charset-normalizer
click
+opencv-python
dataclasses
farm-haystack[ocr,inference,pdf,preprocessing,file-conversion]
fastapi==0.109.1
@@ -35,3 +36,4 @@ tabulate
torchvision==0.18.0
unstructured-client==0.23.3
uvicorn==0.24.0-post.1
+Wand==0.6.13
diff --git a/src/nv_ingest/extraction_workflows/image/__init__.py b/src/nv_ingest/extraction_workflows/image/__init__.py
new file mode 100644
index 00000000..7186f69d
--- /dev/null
+++ b/src/nv_ingest/extraction_workflows/image/__init__.py
@@ -0,0 +1,3 @@
+from .image_handlers import image_data_extractor as image
+
+__all__ = ["image"]
diff --git a/src/nv_ingest/extraction_workflows/image/image_handlers.py b/src/nv_ingest/extraction_workflows/image/image_handlers.py
new file mode 100644
index 00000000..810276f9
--- /dev/null
+++ b/src/nv_ingest/extraction_workflows/image/image_handlers.py
@@ -0,0 +1,447 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import io
+import logging
+import traceback
+from datetime import datetime
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+
+from nv_ingest.extraction_workflows.pdf.doughnut_utils import crop_image
+import nv_ingest.util.nim.yolox as yolox_utils
+from nv_ingest.schemas.image_extractor_schema import ImageExtractorSchema
+from nv_ingest.schemas.metadata_schema import AccessLevelEnum
+from nv_ingest.util.image_processing.transforms import numpy_to_base64
+from nv_ingest.util.nim.helpers import create_inference_client
+from nv_ingest.util.nim.helpers import perform_model_inference
+from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
+from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_base64
+from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata
+
+logger = logging.getLogger(__name__)
+
+YOLOX_MAX_BATCH_SIZE = 8
+YOLOX_MAX_WIDTH = 1536
+YOLOX_MAX_HEIGHT = 1536
+YOLOX_NUM_CLASSES = 3
+YOLOX_CONF_THRESHOLD = 0.01
+YOLOX_IOU_THRESHOLD = 0.5
+YOLOX_MIN_SCORE = 0.1
+YOLOX_FINAL_SCORE = 0.48
+
+RAW_FILE_FORMATS = ["jpeg", "jpg", "png", "tiff"]
+PREPROC_FILE_FORMATS = ["svg"]
+
+SUPPORTED_FILE_TYPES = RAW_FILE_FORMATS + PREPROC_FILE_FORMATS
+
+
+def load_and_preprocess_image(image_stream: io.BytesIO) -> np.ndarray:
+ """
+ Loads and preprocesses a JPEG, PNG, or TIFF image from a bytestream.
+
+ Parameters
+ ----------
+ image_stream : io.BytesIO
+ A bytestream of the image file.
+
+ Returns
+ -------
+ np.ndarray
+ Preprocessed image as a numpy array.
+ """
+ # Load image from the byte stream
+ processed_image = Image.open(image_stream).convert("RGB")
+
+ # Convert the image to a float32 numpy array (pixel values are left unscaled)
+ image_array = np.asarray(processed_image, dtype=np.float32)
+
+ return image_array
+
+
+def convert_svg_to_bitmap(image_stream: io.BytesIO) -> np.ndarray:
+ """
+ Converts an SVG image from a bytestream to a bitmap format.
+
+ Parameters
+ ----------
+ image_stream : io.BytesIO
+ A bytestream of the SVG file.
+
+ Returns
+ -------
+ np.ndarray
+ Preprocessed image as a numpy array in bitmap format.
+ """
+ # Convert SVG to PNG using Wand (ImageMagick)
+ with WandImage(blob=image_stream.read(), format="svg") as img:
+ img.format = "png"
+ png_data = img.make_blob()
+
+ # Reload the PNG as a PIL Image
+ processed_image = Image.open(io.BytesIO(png_data)).convert("RGB")
+
+ # Convert the image to a float32 numpy array (pixel values are left unscaled)
+ image_array = np.asarray(processed_image, dtype=np.float32)
+
+ return image_array
+
+
+# TODO(Devin): Move to common file
+def process_inference_results(
+ output_array: np.ndarray,
+ original_image_shapes: List[Tuple[int, int]],
+ num_classes: int,
+ conf_thresh: float,
+ iou_thresh: float,
+ min_score: float,
+ final_thresh: float,
+):
+ """
+ Process the model output to generate detection results and expand bounding boxes.
+
+ Parameters
+ ----------
+ output_array : np.ndarray
+ The raw output from the model inference.
+ original_image_shapes : List[Tuple[int, int]]
+ The shapes of the original images before resizing, used for scaling bounding boxes.
+ num_classes : int
+ The number of classes the model can detect.
+ conf_thresh : float
+ The confidence threshold for detecting objects.
+ iou_thresh : float
+ The Intersection Over Union (IoU) threshold for non-maximum suppression.
+ min_score : float
+ The minimum score for keeping a detection.
+ final_thresh : float
+ Threshold for keeping a bounding box, applied after postprocessing.
+
+ Returns
+ -------
+ List[dict]
+ A list of dictionaries, each containing processed detection results including expanded bounding boxes.
+
+ Notes
+ -----
+ This function applies non-maximum suppression to the model's output and scales the bounding boxes back to the
+ original image size.
+
+ Examples
+ --------
+ >>> output_array = np.random.rand(2, 100, 85)
+ >>> original_image_shapes = [(1536, 1536), (1536, 1536)]
+ >>> results = process_inference_results(output_array, original_image_shapes, 80, 0.5, 0.5, 0.1, 0.48)
+ >>> len(results)
+ 2
+ """
+ pred = yolox_utils.postprocess_model_prediction(
+ output_array, num_classes, conf_thresh, iou_thresh, class_agnostic=True
+ )
+ results = yolox_utils.postprocess_results(pred, original_image_shapes, min_score=min_score)
+ logger.debug(f"Number of results: {len(results)}")
+ logger.debug(f"Results: {results}")
+
+ annotation_dicts = [yolox_utils.expand_chart_bboxes(annotation_dict) for annotation_dict in results]
+ inference_results = []
+
+ # Filter out bounding boxes below the final threshold
+ for annotation_dict in annotation_dicts:
+ new_dict = {}
+ if "table" in annotation_dict:
+ new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= final_thresh]
+ if "chart" in annotation_dict:
+ new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= final_thresh]
+ if "title" in annotation_dict:
+ new_dict["title"] = annotation_dict["title"]
+ inference_results.append(new_dict)
+
+ return inference_results
+
+
+def extract_table_and_chart_images(
+ annotation_dict: Dict[str, List[List[float]]],
+ original_image: np.ndarray,
+ page_idx: int,
+ tables_and_charts: List[Tuple[int, "CroppedImageWithContent"]],
+) -> None:
+ """
+ Handle the extraction of tables and charts from the inference results and run additional model inference.
+
+ Parameters
+ ----------
+ annotation_dict : dict of {str : list of list of float}
+ A dictionary containing detected objects and their bounding boxes. Keys should include "table" and "chart",
+ and each key's value should be a list of bounding boxes, with each bounding box represented as a list of floats.
+ original_image : np.ndarray
+ The original image from which objects were detected, expected to be in RGB format with shape (H, W, 3).
+ page_idx : int
+ The index of the current page being processed.
+ tables_and_charts : list of tuple of (int, CroppedImageWithContent)
+ A list to which extracted tables and charts will be appended. Each item in the list is a tuple where the first
+ element is the page index, and the second is an instance of CroppedImageWithContent representing a cropped image
+ and associated metadata.
+
+ Returns
+ -------
+ None
+
+ Notes
+ -----
+ This function iterates over detected objects labeled as "table" or "chart". For each object, it crops the original
+ image according to the bounding box coordinates, then creates an instance of `CroppedImageWithContent` containing
+ the cropped image and metadata, and appends it to `tables_and_charts`.
+
+ Examples
+ --------
+ >>> annotation_dict = {"table": [[0.1, 0.1, 0.5, 0.5, 0.8]], "chart": [[0.6, 0.6, 0.9, 0.9, 0.9]]}
+ >>> original_image = np.random.rand(1536, 1536, 3)
+ >>> tables_and_charts = []
+ >>> extract_table_and_chart_images(annotation_dict, original_image, 0, tables_and_charts)
+ >>> len(tables_and_charts)
+ 2
+ """
+
+ width, height, *_ = original_image.shape
+ for label in ["table", "chart"]:
+ if not annotation_dict or label not in annotation_dict:
+ continue
+
+ objects = annotation_dict[label]
+ for idx, bboxes in enumerate(objects):
+ *bbox, _ = bboxes
+ h1, w1, h2, w2 = np.array(bbox) * np.array([height, width, height, width])
+
+ base64_img = crop_image(original_image, (int(h1), int(w1), int(h2), int(w2)))
+
+ table_data = CroppedImageWithContent(
+ content="",
+ image=base64_img,
+ bbox=(int(w1), int(h1), int(w2), int(h2)),
+ max_width=width,
+ max_height=height,
+ type_string=label,
+ )
+ tables_and_charts.append((page_idx, table_data))
+
+
+def extract_tables_and_charts_from_image(
+ image: np.ndarray,
+ config: ImageExtractorSchema,
+ num_classes: int = YOLOX_NUM_CLASSES,
+ conf_thresh: float = YOLOX_CONF_THRESHOLD,
+ iou_thresh: float = YOLOX_IOU_THRESHOLD,
+ min_score: float = YOLOX_MIN_SCORE,
+ final_thresh: float = YOLOX_FINAL_SCORE,
+ trace_info: Optional[List] = None,
+) -> List[CroppedImageWithContent]:
+ """
+ Extract tables and charts from a single image using an ensemble of image-based models.
+
+ This function processes a single image to detect and extract tables and charts.
+ It uses a sequence of models hosted on different inference servers to achieve this.
+
+ Parameters
+ ----------
+ image : np.ndarray
+ A preprocessed image array for table and chart detection.
+ config : ImageExtractorSchema
+ Configuration for the inference client, including endpoint URLs and authentication.
+ num_classes : int, optional
+ The number of classes the model is trained to detect (default is 3).
+ conf_thresh : float, optional
+ The confidence threshold for detection (default is 0.01).
+ iou_thresh : float, optional
+ The Intersection Over Union (IoU) threshold for non-maximum suppression (default is 0.5).
+ min_score : float, optional
+ The minimum score threshold for considering a detection valid (default is 0.1).
+ final_thresh : float, optional
+ Threshold for keeping a bounding box, applied after postprocessing (default is 0.48).
+ trace_info : Optional[List], optional
+ Tracing information for logging or debugging purposes.
+
+ Returns
+ -------
+ List[CroppedImageWithContent]
+ A list of `CroppedImageWithContent` objects representing detected tables or charts,
+ each containing metadata about the detected region.
+ """
+ tables_and_charts = []
+
+ yolox_client = None
+ try:
+ yolox_client = create_inference_client(config.yolox_endpoints, config.auth_token)
+
+ input_image = yolox_utils.prepare_images_for_inference([image])
+ image_shape = image.shape
+
+ output_array = perform_model_inference(yolox_client, "yolox", input_image, trace_info=trace_info)
+
+ yolox_annotated_detections = process_inference_results(
+ output_array, [image_shape], num_classes, conf_thresh, iou_thresh, min_score, final_thresh
+ )
+
+ for annotation_dict in yolox_annotated_detections:
+ extract_table_and_chart_images(
+ annotation_dict,
+ image,
+ page_idx=0, # Single image treated as one page
+ tables_and_charts=tables_and_charts,
+ )
+
+ except Exception as e:
+ logger.error(f"Error during table/chart extraction from image: {e}")
+ traceback.print_exc()
+ raise
+ finally:
+ if isinstance(yolox_client, grpcclient.InferenceServerClient):
+ logger.debug("Closing YOLOX inference client.")
+ yolox_client.close()
+
+ logger.debug(f"Extracted {len(tables_and_charts)} tables and charts from image.")
+
+ return tables_and_charts
+
+
+def image_data_extractor(image_stream,
+ document_type: str,
+ extract_text: bool,
+ extract_images: bool,
+ extract_tables: bool,
+ extract_charts: bool,
+ trace_info: dict = None,
+ **kwargs):
+ """
+ Helper function to extract text, images, tables, and charts from an image bytestream.
+
+ Parameters
+ ----------
+ image_stream : io.BytesIO
+ A bytestream for the image file.
+ document_type : str
+ Specifies the type of the image document ('png', 'jpeg', 'jpg', 'svg', 'tiff').
+ extract_text : bool
+ Specifies whether to extract text.
+ extract_images : bool
+ Specifies whether to extract images.
+ extract_tables : bool
+ Specifies whether to extract tables.
+ extract_charts : bool
+ Specifies whether to extract charts.
+ trace_info : dict, optional
+ Tracing information passed through to table/chart extraction.
+ **kwargs
+ Additional extraction parameters.
+
+ Returns
+ -------
+ list
+ A list of extracted data items.
+ """
+ logger.debug(f"Extracting {document_type.upper()} image with image extractor.")
+
+ if document_type not in SUPPORTED_FILE_TYPES:
+ raise ValueError(f"Unsupported document type: {document_type}")
+
+ row_data = kwargs.get("row_data")
+ source_id = row_data.get("source_id", "unknown_source")
+
+ # Metadata extraction setup
+ base_unified_metadata = row_data.get(kwargs.get("metadata_column", "metadata"), {})
+ current_iso_datetime = datetime.now().isoformat()
+ source_metadata = {
+ "source_name": f"{source_id}_{document_type}",
+ "source_id": source_id,
+ "source_location": row_data.get("source_location", ""),
+ "source_type": document_type,
+ "collection_id": row_data.get("collection_id", ""),
+ "date_created": row_data.get("date_created", current_iso_datetime),
+ "last_modified": row_data.get("last_modified", current_iso_datetime),
+ "summary": f"Raw {document_type} image extracted from source {source_id}",
+ "partition_id": row_data.get("partition_id", -1),
+ "access_level": row_data.get("access_level", AccessLevelEnum.LEVEL_1),
+ }
+
+ # Prepare for extraction
+ extracted_data = []
+ logger.debug(f"Extract text: {extract_text} (not supported yet for raw images)")
+ logger.debug(f"Extract images: {extract_images} (not supported yet for raw images)")
+ logger.debug(f"Extract tables: {extract_tables}")
+ logger.debug(f"Extract charts: {extract_charts}")
+
+ # Preprocess based on image type
+ if document_type in RAW_FILE_FORMATS:
+ logger.debug(f"Loading and preprocessing {document_type} image.")
+ image_array = load_and_preprocess_image(image_stream)
+ elif document_type in PREPROC_FILE_FORMATS:
+ logger.debug(f"Converting {document_type} to bitmap.")
+ image_array = convert_svg_to_bitmap(image_stream)
+ else:
+ raise ValueError(f"Unsupported document type: {document_type}")
+
+ # Text extraction stub
+ if extract_text:
+ # Future function for text extraction based on document_type
+ logger.warning("Text extraction is not supported for raw images.")
+
+ # Image extraction stub
+ if extract_images:
+ # Emit the raw bitmap as a single extracted image item
+ extracted_data.append(
+ construct_image_metadata_from_base64(
+ numpy_to_base64(image_array),
+ page_idx=0, # Single image treated as one page
+ page_count=1,
+ source_metadata=source_metadata,
+ base_unified_metadata=base_unified_metadata,
+ )
+ )
+
+ # Table and chart extraction
+ if extract_tables or extract_charts:
+ try:
+ tables_and_charts = extract_tables_and_charts_from_image(
+ image_array,
+ config=kwargs.get("image_extraction_config"),
+ trace_info=trace_info,
+ )
+ logger.debug(f"Extracted table/chart data from image")
+ for _, table_chart_data in tables_and_charts:
+ extracted_data.append(
+ construct_table_and_chart_metadata(
+ table_chart_data,
+ page_idx=0, # Single image treated as one page
+ page_count=1,
+ source_metadata=source_metadata,
+ base_unified_metadata=base_unified_metadata,
+ )
+ )
+ except Exception as e:
+ logger.error(f"Error extracting tables/charts from image: {e}")
+
+ logger.debug(f"Extracted {len(extracted_data)} items from the image.")
+
+ return extracted_data
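A standalone usage sketch of the new extractor; the row payload is hypothetical, and table/chart extraction is left off so no YOLOX endpoint is needed:

```python
import io

from PIL import Image

from nv_ingest.extraction_workflows.image.image_handlers import image_data_extractor

# In-memory PNG standing in for decoded job content.
buf = io.BytesIO()
Image.new("RGB", (64, 64), color="white").save(buf, format="PNG")
buf.seek(0)

items = image_data_extractor(
    buf,
    document_type="png",
    extract_text=False,
    extract_images=True,   # emits one item carrying the raw bitmap as base64
    extract_tables=False,
    extract_charts=False,
    row_data={"source_id": "example-1"},  # hypothetical row fields
)
print(len(items))  # 1
```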
diff --git a/src/nv_ingest/extraction_workflows/pdf/__init__.py b/src/nv_ingest/extraction_workflows/pdf/__init__.py
index 51ecd39d..ff752ef5 100644
--- a/src/nv_ingest/extraction_workflows/pdf/__init__.py
+++ b/src/nv_ingest/extraction_workflows/pdf/__init__.py
@@ -6,7 +6,7 @@
from nv_ingest.extraction_workflows.pdf.adobe_helper import adobe
from nv_ingest.extraction_workflows.pdf.doughnut_helper import doughnut
from nv_ingest.extraction_workflows.pdf.llama_parse_helper import llama_parse
-from nv_ingest.extraction_workflows.pdf.pdfium_helper import pdfium
+from nv_ingest.extraction_workflows.pdf.pdfium_helper import pdfium_extractor as pdfium
from nv_ingest.extraction_workflows.pdf.tika_helper import tika
from nv_ingest.extraction_workflows.pdf.unstructured_io_helper import unstructured_io
diff --git a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
index b0239e93..bed95c97 100644
--- a/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/doughnut_helper.py
@@ -41,7 +41,7 @@
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.pdf.metadata_aggregators import Base64Image
from nv_ingest.util.pdf.metadata_aggregators import LatexTable
-from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata
+from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
from nv_ingest.util.pdf.metadata_aggregators import construct_text_metadata
from nv_ingest.util.pdf.metadata_aggregators import extract_pdf_metadata
from nv_ingest.util.pdf.pdfium import pdfium_pages_to_numpy
@@ -221,7 +221,7 @@ def doughnut(pdf_stream, extract_text: bool, extract_images: bool, extract_table
if extract_images:
for image in accumulated_images:
extracted_data.append(
- construct_image_metadata(
+ construct_image_metadata_from_pdf_image(
image,
page_idx,
pdf_metadata.page_count,
diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
index 7a1de0f1..9216c12f 100644
--- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
+++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
@@ -27,8 +27,8 @@
import numpy as np
import pypdfium2 as libpdfium
import tritonclient.grpc as grpcclient
+import nv_ingest.util.nim.yolox as yolox_utils
-from nv_ingest.extraction_workflows.pdf import yolox_utils
from nv_ingest.schemas.metadata_schema import AccessLevelEnum
from nv_ingest.schemas.metadata_schema import TextTypeEnum
from nv_ingest.schemas.pdf_extractor_schema import PDFiumConfigSchema
@@ -36,11 +36,12 @@
from nv_ingest.util.image_processing.transforms import numpy_to_base64
from nv_ingest.util.nim.helpers import create_inference_client
from nv_ingest.util.nim.helpers import perform_model_inference
+from nv_ingest.util.nim.yolox import prepare_images_for_inference
from nv_ingest.util.pdf.metadata_aggregators import Base64Image
-from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
-from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata
+from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata
from nv_ingest.util.pdf.metadata_aggregators import construct_text_metadata
+from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
from nv_ingest.util.pdf.metadata_aggregators import extract_pdf_metadata
from nv_ingest.util.pdf.pdfium import PDFIUM_PAGEOBJ_MAPPING
from nv_ingest.util.pdf.pdfium import pdfium_pages_to_numpy
@@ -179,37 +180,6 @@ def extract_tables_and_charts_using_image_ensemble(
return tables_and_charts
-def prepare_images_for_inference(images: List[np.ndarray]) -> np.ndarray:
- """
- Prepare a list of images for model inference by resizing and reordering axes.
-
- Parameters
- ----------
- images : List[np.ndarray]
- A list of image arrays to be prepared for inference.
-
- Returns
- -------
- np.ndarray
- A numpy array suitable for model input, with the shape reordered to match the expected input format.
-
- Notes
- -----
- The images are resized to 1024x1024 pixels and the axes are reordered to match the expected input shape for
- the model (batch, channels, height, width).
-
- Examples
- --------
- >>> images = [np.random.rand(1536, 1536, 3) for _ in range(2)]
- >>> input_array = prepare_images_for_inference(images)
- >>> input_array.shape
- (2, 3, 1024, 1024)
- """
-
- resized_images = [yolox_utils.resize_image(image, (1024, 1024)) for image in images]
- return np.einsum("bijk->bkij", resized_images).astype(np.float32)
-
-
def process_inference_results(
output_array: np.ndarray,
original_image_shapes: List[Tuple[int, int]],
@@ -292,7 +262,7 @@ def extract_table_and_chart_images(
Parameters
----------
- annotation_dict : dict
+ annotation_dict : dict
A dictionary containing detected objects and their bounding boxes.
original_image : np.ndarray
The original image from which objects were detected.
@@ -337,7 +307,7 @@ def extract_table_and_chart_images(
# Define a helper function to use unstructured-io to extract text from a base64
# encoded bytestream PDF
-def pdfium(
+def pdfium_extractor(
pdf_stream,
extract_text: bool,
extract_images: bool,
@@ -453,7 +423,7 @@ def pdfium(
if obj_type == "IMAGE":
try:
# Attempt to retrieve the image bitmap
- image_numpy: np.ndarray = pdfium_try_get_bitmap_as_numpy(obj) # noqa
+ image_numpy: np.ndarray = pdfium_try_get_bitmap_as_numpy(obj) # noqa
image_base64: str = numpy_to_base64(image_numpy)
image_bbox = obj.get_pos()
image_size = obj.get_size()
@@ -462,7 +432,7 @@ def pdfium(
max_width=page_width, max_height=page_height
)
- extracted_image_data = construct_image_metadata(
+ extracted_image_data = construct_image_metadata_from_pdf_image(
image_data,
page_idx,
pdf_metadata.page_count,
diff --git a/src/nv_ingest/schemas/image_caption_extraction_schema.py b/src/nv_ingest/schemas/image_caption_extraction_schema.py
index 10584eed..4f6cdb3f 100644
--- a/src/nv_ingest/schemas/image_caption_extraction_schema.py
+++ b/src/nv_ingest/schemas/image_caption_extraction_schema.py
@@ -7,9 +7,9 @@
class ImageCaptionExtractionSchema(BaseModel):
- batch_size: int = 8
- caption_classifier_model_name: str = "deberta_large"
- endpoint_url: str = "triton:8001"
+ api_key: str
+ endpoint_url: str = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions"
+ prompt: str = "Caption the content of this image:"
raise_on_failure: bool = False
class Config:
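Only the key is required now; a quick sketch of the defaults:

```python
from nv_ingest.schemas.image_caption_extraction_schema import ImageCaptionExtractionSchema

config = ImageCaptionExtractionSchema(api_key="nvapi-...")  # hypothetical key
assert config.prompt == "Caption the content of this image:"
assert config.endpoint_url.endswith("/chat/completions")
```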
diff --git a/src/nv_ingest/schemas/image_extractor_schema.py b/src/nv_ingest/schemas/image_extractor_schema.py
new file mode 100644
index 00000000..21e6b4c6
--- /dev/null
+++ b/src/nv_ingest/schemas/image_extractor_schema.py
@@ -0,0 +1,127 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import logging
+from typing import Optional
+from typing import Tuple
+
+from pydantic import BaseModel
+from pydantic import root_validator
+
+logger = logging.getLogger(__name__)
+
+
+class ImageConfigSchema(BaseModel):
+ """
+ Configuration schema for image extraction endpoints and options.
+
+ Parameters
+ ----------
+ auth_token : Optional[str], default=None
+ Authentication token required for secure services.
+
+ yolox_endpoints : Tuple[str, str]
+ A tuple containing the gRPC and HTTP services for the yolox endpoint.
+ Either the gRPC or HTTP service can be empty, but not both.
+
+ Methods
+ -------
+ validate_endpoints(values)
+ Validates that at least one of the gRPC or HTTP services is provided for each endpoint.
+
+ Raises
+ ------
+ ValueError
+ If both gRPC and HTTP services are empty for any endpoint.
+
+ Config
+ ------
+ extra : str
+ Pydantic config option to forbid extra fields.
+ """
+
+ auth_token: Optional[str] = None
+
+ yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
+ yolox_infer_protocol: str = ""
+
+ @root_validator(pre=True)
+ def validate_endpoints(cls, values):
+ """
+ Validates the gRPC and HTTP services for all endpoints.
+
+ Parameters
+ ----------
+ values : dict
+ Dictionary containing the values of the attributes for the class.
+
+ Returns
+ -------
+ dict
+ The validated dictionary of values.
+
+ Raises
+ ------
+ ValueError
+ If both gRPC and HTTP services are empty for any endpoint.
+ """
+
+ def clean_service(service):
+ """Set service to None if it's an empty string or contains only spaces or quotes."""
+ if service is None or not service.strip() or service.strip(" \"'") == "":
+ return None
+ return service
+
+ for model_name in ["yolox"]:
+ endpoint_name = f"{model_name}_endpoints"
+ grpc_service, http_service = values.get(endpoint_name)
+ grpc_service = clean_service(grpc_service)
+ http_service = clean_service(http_service)
+
+ if not grpc_service and not http_service:
+ raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.")
+
+ values[endpoint_name] = (grpc_service, http_service)
+
+ protocol_name = f"{model_name}_infer_protocol"
+ protocol_value = values.get(protocol_name)
+ if not protocol_value:
+ protocol_value = "http" if http_service else "grpc" if grpc_service else ""
+ protocol_value = protocol_value.lower()
+ values[protocol_name] = protocol_value
+
+ return values
+
+ class Config:
+ extra = "forbid"
+
+
+class ImageExtractorSchema(BaseModel):
+ """
+ Configuration schema for the image extractor settings.
+
+ Parameters
+ ----------
+ max_queue_size : int, default=1
+ The maximum number of items allowed in the processing queue.
+
+ n_workers : int, default=16
+ The number of worker threads to use for processing.
+
+ raise_on_failure : bool, default=False
+ A flag indicating whether to raise an exception on processing failure.
+
+ image_extraction_config: Optional[ImageConfigSchema], default=None
+ Configuration schema for the image extraction stage.
+ """
+
+ max_queue_size: int = 1
+ n_workers: int = 16
+ raise_on_failure: bool = False
+
+ image_extraction_config: Optional[ImageConfigSchema] = None
+
+ class Config:
+ extra = "forbid"
diff --git a/src/nv_ingest/schemas/ingest_job_schema.py b/src/nv_ingest/schemas/ingest_job_schema.py
index 8b2dc7ef..9d001202 100644
--- a/src/nv_ingest/schemas/ingest_job_schema.py
+++ b/src/nv_ingest/schemas/ingest_job_schema.py
@@ -24,15 +24,16 @@
# Enums
class DocumentTypeEnum(str, Enum):
- pdf = "pdf"
- txt = "text"
+ bmp = "bmp"
docx = "docx"
- pptx = "pptx"
+ html = "html"
jpeg = "jpeg"
- bmp = "bmp"
+ pdf = "pdf"
png = "png"
+ pptx = "pptx"
svg = "svg"
- html = "html"
+ tiff = "tiff"
+ txt = "text"
class TaskTypeEnum(str, Enum):
@@ -94,9 +95,11 @@ class IngestTaskStoreSchema(BaseModelNoExt):
params: dict
+# All optional, the captioning stage requires default parameters, each of these are just overrides.
class IngestTaskCaptionSchema(BaseModelNoExt):
- content_type: str = "image"
- n_neighbors: int = 5
+ api_key: Optional[str] = None
+ endpoint_url: Optional[str] = None
+ prompt: Optional[str] = None
class IngestTaskFilterParamsSchema(BaseModelNoExt):
@@ -133,9 +136,11 @@ class IngestTaskVdbUploadSchema(BaseModelNoExt):
class IngestTaskTableExtraction(BaseModelNoExt):
params: Dict = {}
+
class IngestChartTableExtraction(BaseModelNoExt):
params: Dict = {}
+
class IngestTaskSchema(BaseModelNoExt):
type: TaskTypeEnum
task_properties: Union[
diff --git a/src/nv_ingest/schemas/metadata_schema.py b/src/nv_ingest/schemas/metadata_schema.py
index 6a51c83d..97a64ae6 100644
--- a/src/nv_ingest/schemas/metadata_schema.py
+++ b/src/nv_ingest/schemas/metadata_schema.py
@@ -40,32 +40,33 @@ class ContentTypeEnum(str, Enum):
TEXT = "text"
IMAGE = "image"
STRUCTURED = "structured"
+ UNSTRUCTURED = "unstructured"
INFO_MSG = "info_message"
class StdContentDescEnum(str, Enum):
- PDF_TEXT = "Unstructured text from PDF document."
- PDF_IMAGE = "Image extracted from PDF document."
- PDF_TABLE = "Structured table extracted from PDF document."
- PDF_CHART = "Structured chart extracted from PDF document."
- DOCX_TEXT = "Unstructured text from DOCX document."
DOCX_IMAGE = "Image extracted from DOCX document."
DOCX_TABLE = "Structured table extracted from DOCX document."
- PPTX_TEXT = "Unstructured text from PPTX presentation."
+ DOCX_TEXT = "Unstructured text from DOCX document."
+ PDF_CHART = "Structured chart extracted from PDF document."
+ PDF_IMAGE = "Image extracted from PDF document."
+ PDF_TABLE = "Structured table extracted from PDF document."
+ PDF_TEXT = "Unstructured text from PDF document."
PPTX_IMAGE = "Image extracted from PPTX presentation."
PPTX_TABLE = "Structured table extracted from PPTX presentation."
+ PPTX_TEXT = "Unstructured text from PPTX presentation."
class TextTypeEnum(str, Enum):
- HEADER = "header"
- BODY = "body"
- SPAN = "span"
- LINE = "line"
BLOCK = "block"
- PAGE = "page"
+ BODY = "body"
DOCUMENT = "document"
+ HEADER = "header"
+ LINE = "line"
NEARBY_BLOCK = "nearby_block"
OTHER = "other"
+ PAGE = "page"
+ SPAN = "span"
class LanguageEnum(str, Enum):
@@ -132,10 +133,10 @@ def has_value(cls, value):
class ImageTypeEnum(str, Enum):
- JPEG = "jpeg"
- PNG = "png"
BMP = "bmp"
GIF = "gif"
+ JPEG = "jpeg"
+ PNG = "png"
TIFF = "tiff"
image_type_1 = "image_type_1" # until classifier developed
@@ -148,9 +149,9 @@ def has_value(cls, value):
class TableFormatEnum(str, Enum):
HTML = "html"
- MARKDOWN = "markdown"
- LATEX = "latex"
IMAGE = "image"
+ LATEX = "latex"
+ MARKDOWN = "markdown"
class TaskTypeEnum(str, Enum):
@@ -249,13 +250,6 @@ class TextMetadataSchema(BaseModelNoExt):
text_location: tuple = (0, 0, 0, 0)
-import logging
-from pydantic import validator
-
-# Set up logging
-logger = logging.getLogger(__name__)
-
-
class ImageMetadataSchema(BaseModelNoExt):
image_type: Union[ImageTypeEnum, str]
structured_image_type: ImageTypeEnum = ImageTypeEnum.image_type_1
@@ -317,6 +311,7 @@ class InfoMessageMetadataSchema(BaseModelNoExt):
# Main metadata schema
class MetadataSchema(BaseModelNoExt):
content: str = ""
+ content_url: str = ""
embedding: Optional[List[float]] = None
source_metadata: Optional[SourceMetadataSchema] = None
content_metadata: Optional[ContentMetadataSchema] = None
diff --git a/tests/stages/__init__.py b/src/nv_ingest/stages/extractors/__init__.py
similarity index 100%
rename from tests/stages/__init__.py
rename to src/nv_ingest/stages/extractors/__init__.py
diff --git a/src/nv_ingest/stages/extractors/image_extractor_stage.py b/src/nv_ingest/stages/extractors/image_extractor_stage.py
new file mode 100644
index 00000000..85afa4c2
--- /dev/null
+++ b/src/nv_ingest/stages/extractors/image_extractor_stage.py
@@ -0,0 +1,210 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import base64
+import functools
+import io
+import logging
+import traceback
+from typing import Any
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import pandas as pd
+from morpheus.config import Config
+
+import nv_ingest.extraction_workflows.image as image_helpers
+from nv_ingest.schemas.image_extractor_schema import ImageExtractorSchema
+from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
+
+logger = logging.getLogger(f"morpheus.{__name__}")
+
+
+def decode_and_extract(
+ base64_row: pd.Series,
+ task_props: Dict[str, Any],
+ validated_config: Any,
+ default: str = "image",
+ trace_info: Optional[List] = None,
+) -> Any:
+ """
+ Decodes base64 content from a row and extracts data from it using the specified extraction method.
+
+ Parameters
+ ----------
+ base64_row : dict
+ A dictionary containing the base64-encoded content and other relevant data.
+ The key "content" should contain the base64 string, and the key "source_id" is optional.
+ task_props : dict
+ A dictionary containing task properties. It should have the keys:
+ - "method" (str): The extraction method to use (e.g., "image").
+ - "params" (dict): Parameters to pass to the extraction function.
+ validated_config : Any
+ Configuration object that contains `image_config`. Used if the `image` method is selected.
+ default : str, optional
+ The default extraction method to use if the specified method in `task_props` is not available (default is "image").
+
+ Returns
+ -------
+ Any
+ The extracted data from the decoded content. The exact return type depends on the extraction method used.
+
+ Raises
+ ------
+ KeyError
+ If the "content" key is missing from `base64_row`.
+ Exception
+ For any other unhandled exceptions during extraction, an error is logged, and the exception is re-raised.
+ """
+
+ document_type = base64_row["document_type"]
+ source_id = None
+ try:
+ base64_content = base64_row["content"]
+ except KeyError:
+ log_error_message = f"Unhandled error processing row, no content was found:\n{base64_row}"
+ logger.error(log_error_message)
+ raise
+
+ try:
+ # Row data to include in extraction
+ bool_index = base64_row.index.isin(("content",))
+ row_data = base64_row[~bool_index]
+ task_props["params"]["row_data"] = row_data
+
+ # Get source_id
+ source_id = base64_row["source_id"] if "source_id" in base64_row.index else None
+ # Decode the base64 content
+ image_bytes = base64.b64decode(base64_content)
+
+ # Load the decoded image bytes into a stream
+ image_stream = io.BytesIO(image_bytes)
+
+ # Type of extraction method to use
+ extract_method = task_props.get("method", "image")
+ extract_params = task_props.get("params", {})
+
+ logger.debug(
+ f"Extracting image content, image_extraction_config: {validated_config.image_extraction_config}")
+ if validated_config.image_extraction_config is not None:
+ extract_params["image_extraction_config"] = validated_config.image_extraction_config
+
+ if trace_info is not None:
+ extract_params["trace_info"] = trace_info
+
+ if not hasattr(image_helpers, extract_method):
+ extract_method = default
+
+ func = getattr(image_helpers, extract_method)
+ logger.debug("Running extraction method: %s", extract_method)
+ extracted_data = func(image_stream, document_type, **extract_params)
+
+ return extracted_data
+
+ except Exception as e:
+ traceback.print_exc()
+ err_msg = f"Unhandled exception in decode_and_extract for '{source_id}':\n{e}"
+ logger.error(err_msg)
+
+ raise
+
+
+def process_image(
+ df: pd.DataFrame,
+ task_props: Dict[str, Any],
+ validated_config: Any,
+ trace_info: Optional[Dict[str, Any]] = None
+) -> Tuple[pd.DataFrame, Dict[str, Any]]:
+ """
+ Processes a pandas DataFrame containing image files in base64 encoding.
+ Each image's content is replaced with its extracted components.
+
+ Parameters
+ ----------
+ df : pd.DataFrame
+ The input DataFrame with columns 'source_id' and 'content' (base64-encoded image data).
+ task_props : dict
+ Dictionary containing instructions and parameters for the image processing task.
+ validated_config : Any
+ Configuration object validated for processing images.
+ trace_info : dict, optional
+ Dictionary for tracing and logging additional information during processing (default is None).
+
+ Returns
+ -------
+ Tuple[pd.DataFrame, Dict[str, Any]]
+ A tuple containing:
+ - A pandas DataFrame with the processed image content, including columns 'document_type', 'metadata', and 'uuid'.
+ - A dictionary with trace information collected during processing.
+
+ Raises
+ ------
+ Exception
+ If an error occurs during the image processing stage.
+ """
+ logger.debug("Processing image content")
+ if trace_info is None:
+ trace_info = {}
+
+ try:
+ # Apply the helper function to each row in the 'content' column
+ _decode_and_extract = functools.partial(
+ decode_and_extract, task_props=task_props, validated_config=validated_config, trace_info=trace_info
+ )
+ logger.debug(f"Processing method: {task_props.get('method', None)}")
+ sr_extraction = df.apply(_decode_and_extract, axis=1)
+ sr_extraction = sr_extraction.explode().dropna()
+
+ if not sr_extraction.empty:
+ extracted_df = pd.DataFrame(sr_extraction.to_list(), columns=["document_type", "metadata", "uuid"])
+ else:
+ extracted_df = pd.DataFrame({"document_type": [], "metadata": [], "uuid": []})
+
+ return extracted_df, {"trace_info": trace_info}
+
+ except Exception as e:
+ err_msg = f"Unhandled exception in image extractor stage's process_image: {e}"
+ logger.error(err_msg)
+ raise
+
+
+def generate_image_extractor_stage(
+ c: Config,
+ extractor_config: Dict[str, Any],
+ task: str = "extract",
+ task_desc: str = "image_content_extractor",
+ pe_count: int = 24,
+):
+ """
+ Helper function to generate a multiprocessing stage to perform image content extraction.
+
+ Parameters
+ ----------
+ c : Config
+ Morpheus global configuration object
+ extractor_config : dict
+ Configuration parameters for the image content extractor.
+ task : str
+ The task name to match for the stage worker function.
+ task_desc : str
+ A descriptor to be used in latency tracing.
+ pe_count : int
+ Integer for how many process engines to use for image content extraction.
+
+ Returns
+ -------
+ MultiProcessingBaseStage
+ A Morpheus stage with applied worker function.
+ """
+ validated_config = ImageExtractorSchema(**extractor_config)
+ _wrapped_process_fn = functools.partial(process_image, validated_config=validated_config)
+
+ return MultiProcessingBaseStage(
+ c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_process_fn,
+ document_type="regex:^(png|svg|jpeg|jpg|tiff)$"
+ )
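Driving process_image directly looks roughly like this; the DataFrame layout follows the docstring, and the input file name and content are hypothetical:

```python
import base64

import pandas as pd

from nv_ingest.schemas.image_extractor_schema import ImageExtractorSchema
from nv_ingest.stages.extractors.image_extractor_stage import process_image

with open("figure.png", "rb") as f:  # hypothetical input
    content_b64 = base64.b64encode(f.read()).decode()

df = pd.DataFrame(
    {"source_id": ["img-0"], "document_type": ["png"], "content": [content_b64]}
)
task_props = {
    "method": "image",
    "params": {
        "extract_text": False,
        "extract_images": True,
        "extract_tables": False,  # would require image_extraction_config + YOLOX
        "extract_charts": False,
    },
}
extracted_df, trace = process_image(df, task_props, ImageExtractorSchema())
print(extracted_df[["document_type", "uuid"]])
```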
diff --git a/src/nv_ingest/stages/multiprocessing_stage.py b/src/nv_ingest/stages/multiprocessing_stage.py
index b3af02d6..c4c5d0bd 100644
--- a/src/nv_ingest/stages/multiprocessing_stage.py
+++ b/src/nv_ingest/stages/multiprocessing_stage.py
@@ -155,7 +155,7 @@ def __init__(
task_desc: str,
pe_count: int,
process_fn: typing.Callable[[pd.DataFrame, dict], pd.DataFrame],
- document_type: str = None,
+ document_type: typing.Optional[typing.Union[typing.List[str], str]] = None,
filter_properties: dict = None,
):
super().__init__(c)
diff --git a/src/nv_ingest/stages/transforms/image_caption_extraction.py b/src/nv_ingest/stages/transforms/image_caption_extraction.py
index d1b702bb..40b798fe 100644
--- a/src/nv_ingest/stages/transforms/image_caption_extraction.py
+++ b/src/nv_ingest/stages/transforms/image_caption_extraction.py
@@ -2,452 +2,161 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
-
+import base64
+import io
import logging
-import traceback
from functools import partial
-from typing import Any
+from typing import Any, Optional
from typing import Dict
-from typing import List
from typing import Tuple
-import numpy as np
import pandas as pd
-import tritonclient.grpc as grpcclient
+import requests
+from PIL import Image
from morpheus.config import Config
-from morpheus.utils.module_utils import ModuleLoaderFactory
-from sklearn.neighbors import NearestNeighbors
-from transformers import AutoTokenizer
from nv_ingest.schemas.image_caption_extraction_schema import ImageCaptionExtractionSchema
from nv_ingest.schemas.metadata_schema import ContentTypeEnum
from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
+from nv_ingest.util.image_processing.transforms import scale_image_to_encoding_size
+from nv_ingest.util.tracing.tagging import traceable_func
logger = logging.getLogger(__name__)
MODULE_NAME = "image_caption_extraction"
MODULE_NAMESPACE = "nv_ingest"
-ImageCaptionExtractionLoaderFactory = ModuleLoaderFactory(MODULE_NAME, MODULE_NAMESPACE)
-
-
-def _extract_bboxes_and_content(data: Dict[str, Any]) -> Tuple[List[Tuple[int, int, int, int]], List[str]]:
- """
- Extract bounding boxes and associated content from a deeply nested data structure.
-
- Parameters
- ----------
- data : Dict[str, Any]
- A dictionary containing nested data from which bounding boxes and content are extracted.
-
- Returns
- -------
- Tuple[List[Tuple[int, int, int, int]], List[str]]
- A tuple containing two lists:
- - First list of tuples representing bounding boxes (x1, y1, x2, y2).
- - Second list of strings representing associated content.
- """
- nearby_objects = data["content_metadata"]["hierarchy"]["nearby_objects"]["text"]
- bboxes = nearby_objects["bbox"]
- content = nearby_objects["content"]
- return bboxes, content
-
-
-def _calculate_centroids(bboxes: List[Tuple[int, int, int, int]]) -> List[Tuple[float, float]]:
- """
- Calculate centroids from bounding boxes.
-
- Parameters
- ----------
- bboxes : List[Tuple[int, int, int, int]]
- A list of tuples each representing a bounding box as (x1, y1, x2, y2).
-
- Returns
- -------
- List[Tuple[float, float]]
- A list of tuples each representing the centroid (x, y) of the corresponding bounding box.
- """
- return [(bbox[0] + (bbox[2] - bbox[0]) / 2, bbox[1] + (bbox[3] - bbox[1]) / 2) for bbox in bboxes]
-
-
-def _fit_nearest_neighbors(centroids: List[Tuple[float, float]], n_neighbors: int = 5) -> Tuple[NearestNeighbors, int]:
- """
- Fit the NearestNeighbors model to the centroids, ensuring the number of neighbors does not exceed available
- centroids.
-
- Parameters
- ----------
- centroids : List[Tuple[float, float]]
- A list of tuples each representing the centroid coordinates (x, y) of bounding boxes.
- n_neighbors : int, optional
- The number of neighbors to use by default for kneighbors queries.
-
- Returns
- -------
- Tuple[NearestNeighbors, int]
- A tuple containing:
- - NearestNeighbors model fitted to the centroids.
- - The adjusted number of neighbors, which is the minimum of `n_neighbors` and the number of centroids.
- """
- centroids_array = np.array(centroids)
- adjusted_n_neighbors = min(n_neighbors, len(centroids_array))
- nbrs = NearestNeighbors(n_neighbors=adjusted_n_neighbors, algorithm="auto", metric="euclidean")
- nbrs.fit(centroids_array)
-
- return nbrs, adjusted_n_neighbors
-
-
-def _find_nearest_neighbors(
- nbrs: NearestNeighbors, new_bbox: Tuple[int, int, int, int], content: List[str], n_neighbors: int
-) -> Tuple[np.ndarray, np.ndarray, List[str]]:
- """
- Find the nearest neighbors for a new bounding box and return associated content.
-
- Parameters
- ----------
- nbrs : NearestNeighbors
- The trained NearestNeighbors model.
- new_bbox : Tuple[int, int, int, int]
- The bounding box for which to find the nearest neighbors, specified as (x1, y1, x2, y2).
- content : List[str]
- A list of content strings associated with each bounding box.
- n_neighbors : int
- The number of neighbors to retrieve.
-
- Returns
- -------
- Tuple[np.ndarray, np.ndarray, List[str]]
- A tuple containing:
- - distances: An array of distances to the nearest neighbors.
- - indices: An array of indices for the nearest neighbors.
- - A list of content strings corresponding to the nearest neighbors.
- """
- new_centroid = np.array(
- [(new_bbox[0] + (new_bbox[2] - new_bbox[0]) / 2, new_bbox[1] + (new_bbox[3] - new_bbox[1]) / 2)]
- )
- new_centroid_reshaped = new_centroid.reshape(1, -1) # Reshape to ensure 2D
- distances, indices = nbrs.kneighbors(new_centroid_reshaped, n_neighbors=n_neighbors)
- return distances, indices, [content[i] for i in indices.flatten()]
-
-
-def _sanitize_inputs(inputs: List[List[str]]) -> List[List[str]]:
- """
- Replace non-ASCII characters with '?' in inputs.
-
- Parameters
- ----------
- inputs : List[List[str]]
- A list of lists where each sub-list contains strings.
-
- Returns
- -------
- List[List[str]]
- A list of lists where each string has been sanitized to contain only ASCII characters,
- with non-ASCII characters replaced by '?'.
- """
- cleaned_inputs = [
- [candidate.encode("ascii", "replace").decode("ascii") for candidate in candidates] for candidates in inputs
- ]
- return cleaned_inputs
-
-
-TOKENIZER_NAME = "microsoft/deberta-large"
-tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
-
-
-def _predict_caption(triton_url: str, caption_model: str, inputs: List[List[str]], n_candidates: int = 5) -> List[str]:
- """
- Sends a request to a Triton inference server to generate captions based on provided inputs.
-
- Parameters
- ----------
- triton_url : str
- The URL of the Triton inference server.
- headers : Dict[str, str]
- HTTP headers to send with the request.
- inputs : List[List[str]]
- The input data for which captions are generated.
- n_candidates : int, optional
- The number of candidates per input data. Default is 5.
-
- Returns
- -------
- List[str]
- A list of generated captions, one for each input.
- """
-
- sanitized_inputs = _sanitize_inputs(inputs)
- captions = [""] * len(sanitized_inputs)
- max_batch_size = 128
-
- try:
- client = grpcclient.InferenceServerClient(url=triton_url)
-
- # Process inputs in batches
- for batch_start in range(0, len(sanitized_inputs), max_batch_size // n_candidates):
- batch_end = min(batch_start + (max_batch_size // n_candidates), len(sanitized_inputs))
- input_batch = sanitized_inputs[batch_start:batch_end]
- flattened_sentences = [sentence for batch in input_batch for sentence in batch]
- encoded_inputs = tokenizer(
- flattened_sentences, max_length=128, padding="max_length", truncation=True, return_tensors="np"
- )
-
- input_ids = encoded_inputs["input_ids"].astype(np.int64)
- attention_mask = encoded_inputs["attention_mask"].astype(np.float32)
- infer_inputs = [
- grpcclient.InferInput("input_ids", input_ids.shape, "INT64"),
- grpcclient.InferInput("input_mask", attention_mask.shape, "FP32"),
- ]
-
- # Set the data for the input tensors
- infer_inputs[0].set_data_from_numpy(input_ids)
- infer_inputs[1].set_data_from_numpy(attention_mask)
-
- outputs = [grpcclient.InferRequestedOutput("output")]
-
- # Perform inference
- response = client.infer(model_name=caption_model, inputs=infer_inputs, outputs=outputs)
-
- output_data = response.as_numpy("output")
-
- # Process the output to find the best sentence in each batch
- batch_size = n_candidates
- for i, batch in enumerate(input_batch):
- start_idx = i * batch_size
- end_idx = (i + 1) * batch_size
- batch_output = output_data[start_idx:end_idx]
- max_index = np.argmax(batch_output)
- best_sentence = batch[max_index]
- best_probability = batch_output[max_index]
- if best_probability > 0.5:
- captions[batch_start + i] = best_sentence
-
- except Exception as e:
- logging.error(f"An error occurred: {e}")
-
- return captions
-
def _prepare_dataframes_mod(df) -> Tuple[pd.DataFrame, pd.DataFrame, pd.Series]:
if df.empty or "document_type" not in df.columns:
return df, pd.DataFrame(), pd.Series(dtype=bool)
bool_index = df["document_type"] == ContentTypeEnum.IMAGE
- df_filtered = df.loc[bool_index]
-
- return df, df_filtered, bool_index
-
-
-def _process_documents(df_filtered: pd.DataFrame) -> Tuple[List[Any], List[List[str]]]:
- """
- Processes documents to extract content and bounding boxes, then finds nearest neighbors.
-
- Parameters
- ----------
- df_filtered : pd.DataFrame
- The dataframe filtered to contain only relevant documents.
-
- Returns
- -------
- Tuple[List[Any], List[List[str]]]
- A tuple containing metadata for each document and a list of neighbor content.
- """
- neighbor_content = []
- metadata_list = []
- for _, row in df_filtered.iterrows():
- metadata = row.metadata
- metadata_list.append(metadata)
- bboxes, content = _extract_bboxes_and_content(metadata)
- _process_content(bboxes, content, metadata, neighbor_content)
-
- return metadata_list, neighbor_content
-
-
-def _process_content(
- bboxes: List[Tuple[int, int, int, int]],
- content: List[str],
- metadata: Dict[str, Any],
- neighbor_content: List[List[str]],
- n_neighbors: int = 5,
-) -> None:
- """
- Process content by finding nearest neighbors and appending the results.
-
- Parameters
- ----------
- bboxes : List[Tuple[int, int, int, int]]
- A list of bounding boxes.
- content : List[str]
- Content associated with each bounding box.
- metadata : Dict[str, Any]
- Metadata associated with each content piece, containing image metadata.
- neighbor_content : List[List[str]]
- A list that will be appended with the nearest neighbor content.
- n_neighbors : int, optional
- The number of nearest neighbors to find (default is 5).
-
- Returns
- -------
- None
- """
- if bboxes and content:
- centroids = _calculate_centroids(bboxes)
- nn_mod, adj_neighbors = _fit_nearest_neighbors(centroids)
- image_bbox = metadata["image_metadata"]["image_location"]
- distances, indices, nearest_content = _find_nearest_neighbors(nn_mod, image_bbox, content, adj_neighbors)
- else:
- nearest_content = []
-
- if len(nearest_content) < n_neighbors:
- nearest_content.extend([""] * (n_neighbors - len(nearest_content)))
- neighbor_content.append(nearest_content)
+ df_matched = df.loc[bool_index]
+ return df, df_matched, bool_index
-def _generate_captions(neighbor_content: List[List[str]], config: Any) -> List[str]:
+def _generate_captions(base64_image: str, prompt: str, api_key: str, endpoint_url: str) -> str:
"""
- Generate captions for provided content using a Triton inference server.
+ Sends a base64-encoded PNG image to the NVIDIA Llama 3.2 Vision model API and retrieves the generated caption.
Parameters
----------
- neighbor_content : List[List[str]]
- A list of content batches for which to generate captions.
- config : Any
- Configuration object containing endpoint URL, headers, and batch size.
+ base64_image : str
+ Base64-encoded PNG image string.
+ prompt : str
+ Text prompt that precedes the image in the captioning request.
+ api_key : str
+ API key for authentication with the NVIDIA model endpoint.
+ endpoint_url : str
+ URL of the chat-completions endpoint that serves the captioning model.
Returns
-------
- List[str]
- A list of generated captions.
+ str
+ Generated caption for the image or an error message.
"""
- captions = []
- for i in range(0, len(neighbor_content), config.batch_size):
- batch = neighbor_content[i : i + config.batch_size] # noqa: E203
- batch_captions = _predict_caption(config.endpoint_url, config.caption_classifier_model_name, batch)
- captions.extend(batch_captions)
-
- return captions
-
+ stream = False # Request a single, non-streaming response
-def _update_metadata_with_captions(
- metadata_list: List[Dict[str, Any]], captions: List[str], df_filtered: pd.DataFrame
-) -> List[Dict[str, Any]]:
- """
- Update metadata with captions and compile into a list of image document dictionaries.
+ # Ensure the base64 image size is within acceptable limits
+ base64_image = scale_image_to_encoding_size(base64_image)
- Parameters
- ----------
- metadata_list : List[Dict[str, Any]]
- A list of metadata dictionaries.
- captions : List[str]
- A list of captions corresponding to the metadata.
- df_filtered : pd.DataFrame
- The filtered DataFrame containing document UUIDs.
+ headers = {
+ "Authorization": f"Bearer {api_key}",
+ "Accept": "application/json"
+ }
- Returns
- -------
- List[Dict[str, Any]]
- A list of dictionaries each containing updated document metadata and type.
- """
- image_docs = []
- for metadata, caption, (_, row) in zip(metadata_list, captions, df_filtered.iterrows()):
- metadata["image_metadata"]["caption"] = caption
- image_docs.append(
+ # Payload for the request
+ payload = {
+ "model": 'meta/llama-3.2-90b-vision-instruct',
+ "messages": [
{
- "document_type": ContentTypeEnum.IMAGE.value,
- "metadata": metadata,
- "uuid": row.get("uuid"),
+ "role": "user",
+ "content": f'{prompt}
'
}
- )
-
- return image_docs
-
-
-def _prepare_final_dataframe_mod(df: pd.DataFrame, image_docs: List[Dict[str, Any]], filter_index: pd.Series) -> None:
- """
- Prepares the final dataframe by combining original dataframe with new image document data, converting to GPU
- dataframe, and updating the message with the new dataframe.
-
- Parameters
- ----------
- df : pd.DataFrame
- The original dataframe.
- image_docs : List[Dict[str, Any]]
- A list of dictionaries containing image document data.
- filter_index : pd.Series
- A boolean series that filters the dataframe.
-
- Returns
- -------
- None
- """
-
- image_docs_df = pd.DataFrame(image_docs)
- docs_df = pd.concat([df[~filter_index], image_docs_df], axis=0).reset_index(drop=True)
-
- return docs_df
+ ],
+ "max_tokens": 512,
+ "temperature": 1.00,
+ "top_p": 1.00,
+ "stream": stream
+ }
-
-def caption_extract_stage(df, task_props, validated_config) -> pd.DataFrame:
- """
- Extracts captions from images within the provided dataframe.
+ try:
+ response = requests.post(endpoint_url, headers=headers, json=payload)
+ response.raise_for_status() # Raise an exception for HTTP errors
+
+ if stream:
+ result = []
+ for line in response.iter_lines():
+ if line:
+ result.append(line.decode("utf-8"))
+ return "\n".join(result)
+ else:
+ response_data = response.json()
+ return response_data.get('choices', [{}])[0].get('message', {}).get('content', 'No caption returned')
+ except requests.exceptions.RequestException as e:
+ logger.error(f"Error generating caption: {e}")
+ raise
+
+
+def caption_extract_stage(df: pd.DataFrame,
+ task_props: Dict[str, Any],
+ validated_config: Any,
+ trace_info: Optional[Dict[str, Any]] = None
+ ) -> pd.DataFrame:
+ """
+ Extracts captions for image content in the DataFrame using an external NVIDIA API.
+ Updates the 'metadata' column by adding the generated captions under 'image_metadata.caption'.
Parameters
----------
df : pd.DataFrame
- The dataframe containing image data.
- task_props : dict
- Task properties required for processing.
- validated_config : ImageCaptionExtractionSchema
- Validated configuration for caption extraction.
+ The input DataFrame containing image data in 'metadata.content'.
+ task_props : Dict[str, Any]
+ Task properties that may override the validated configuration (api_key, prompt, endpoint_url).
+ validated_config : Any
+ A configuration schema object containing default settings for caption extraction.
+ trace_info : Optional[Dict[str, Any]], optional
+ Optional tracing information; currently unused.
Returns
-------
pd.DataFrame
- The dataframe with updated image captions.
+ The updated DataFrame with generated captions in the 'metadata' column's 'image_metadata.caption' field.
Raises
------
- ValueError
- If an error occurs during caption extraction.
+ Exception
+ If there is an error during the caption extraction process.
"""
- try:
- logger.debug("Performing caption extraction")
-
- # Data preparation and filtering
- df, df_filtered, filter_index = _prepare_dataframes_mod(df)
- if df_filtered.empty:
- return df
+ logger.debug("Attempting to caption image content")
- # Process each image document
- metadata_list, neighbor_content = _process_documents(df_filtered)
+ # trace_info is accepted for interface consistency but is currently unused
+ _ = trace_info
- # Generate captions
- captions = _generate_captions(neighbor_content, validated_config)
+ api_key = task_props.get("api_key", validated_config.api_key)
+ prompt = task_props.get("prompt", validated_config.prompt)
+ endpoint_url = task_props.get("endpoint_url", validated_config.endpoint_url)
- # Update metadata with captions
- image_docs = _update_metadata_with_captions(metadata_list, captions, df_filtered)
+ # Create a mask for rows where the document type is IMAGE
+ df_mask = df['metadata'].apply(lambda meta: meta.get('content_metadata', {}).get('type') == "image")
- logger.debug(f"Extracted captions from {len(image_docs)} images")
+ if not df_mask.any():
+ return df
- # Final dataframe merge
- df_final = _prepare_final_dataframe_mod(df, image_docs, filter_index)
+ df.loc[df_mask, 'metadata'] = df.loc[df_mask, 'metadata'].apply(
+ lambda meta: {
+ **meta,
+ 'image_metadata': {
+ **meta.get('image_metadata', {}),
+ 'caption': _generate_captions(meta['content'], prompt, api_key, endpoint_url)
+ }
+ }
+ )
- if (df_final is None) or df_final.empty:
- logger.warning("NO IMAGE DOCUMENTS FOUND IN THE DATAFRAME")
- return df
- return df_final
- except Exception as e:
- traceback.print_exc()
- raise ValueError(f"Failed to do caption extraction: {e}")
+ logger.debug("Image content captioning complete")
+ return df
def generate_caption_extraction_stage(
- c: Config,
- caption_config: Dict[str, Any],
- task: str = "caption",
- task_desc: str = "caption_extraction",
- pe_count: int = 8,
+ c: Config,
+ caption_config: Dict[str, Any],
+ task: str = "caption",
+ task_desc: str = "caption_extraction",
+ pe_count: int = 8,
):
"""
Generates a caption extraction stage with the specified configuration.
@@ -483,10 +192,5 @@ def generate_caption_extraction_stage(
f"Generating caption extraction stage with {pe_count} processing elements. task: {task}, document_type: *"
)
return MultiProcessingBaseStage(
- c=c,
- pe_count=pe_count,
- task=task,
- task_desc=task_desc,
- process_fn=_wrapped_caption_extract,
- filter_properties={"content_type": ContentTypeEnum.IMAGE.value},
+ c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_caption_extract
)
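The new captioning path reduces to a single chat-completions request per image. The following standalone sketch reproduces that request outside the pipeline; it assumes only that `requests` and Pillow are installed and that `NVIDIA_BUILD_API_KEY` is set, and the endpoint and model name mirror the defaults used elsewhere in this patch.

```python
# Standalone sketch of the request _generate_captions issues.
import base64
import io
import os

import requests
from PIL import Image

ENDPOINT = "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions"


def caption_png(base64_image: str, prompt: str = "Caption the content of this image:") -> str:
    headers = {
        "Authorization": f"Bearer {os.environ['NVIDIA_BUILD_API_KEY']}",
        "Accept": "application/json",
    }
    payload = {
        "model": "meta/llama-3.2-90b-vision-instruct",
        "messages": [{
            "role": "user",
            "content": f'{prompt} <img src="data:image/png;base64,{base64_image}" />',
        }],
        "max_tokens": 512,
    }
    response = requests.post(ENDPOINT, headers=headers, json=payload, timeout=60)
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]


if __name__ == "__main__":
    # Tiny in-memory PNG standing in for a real extracted image.
    buf = io.BytesIO()
    Image.new("RGB", (64, 64), "red").save(buf, format="PNG")
    print(caption_png(base64.b64encode(buf.getvalue()).decode("utf-8")))
```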
diff --git a/src/nv_ingest/util/converters/type_mappings.py b/src/nv_ingest/util/converters/type_mappings.py
index b3834d50..4fbfb0a9 100644
--- a/src/nv_ingest/util/converters/type_mappings.py
+++ b/src/nv_ingest/util/converters/type_mappings.py
@@ -15,6 +15,7 @@
DocumentTypeEnum.png: ContentTypeEnum.IMAGE,
DocumentTypeEnum.pptx: ContentTypeEnum.STRUCTURED,
DocumentTypeEnum.svg: ContentTypeEnum.IMAGE,
+ DocumentTypeEnum.tiff: ContentTypeEnum.IMAGE,
DocumentTypeEnum.txt: ContentTypeEnum.TEXT,
}
diff --git a/src/nv_ingest/util/flow_control/filter_by_task.py b/src/nv_ingest/util/flow_control/filter_by_task.py
index 2e973372..11a73a43 100644
--- a/src/nv_ingest/util/flow_control/filter_by_task.py
+++ b/src/nv_ingest/util/flow_control/filter_by_task.py
@@ -5,6 +5,7 @@
import logging
import typing
+import re
from functools import wraps
from morpheus.messages import ControlMessage
@@ -50,10 +51,12 @@ def wrapper(*args, **kwargs):
continue
task_props_list = tasks.get(required_task_name, [])
+ logger.debug(f"Checking task properties for: {required_task_name}")
+ logger.debug(f"Required task properties: {required_task_props_list}")
for task_props in task_props_list:
if all(
- _is_subset(task_props, required_task_props)
- for required_task_props in required_task_props_list
+ _is_subset(task_props, required_task_props)
+ for required_task_props in required_task_props_list
):
return func(*args, **kwargs)
@@ -75,22 +78,38 @@ def _is_subset(superset, subset):
if subset == "*":
return True
if isinstance(superset, dict) and isinstance(subset, dict):
- return all(key in superset and _is_subset(superset[key], val) for key, val in subset.items())
+ return all(
+ key in superset and _is_subset(superset[key], val)
+ for key, val in subset.items()
+ )
+ if isinstance(subset, str) and subset.startswith('regex:'):
+ # The subset is a regex pattern
+ pattern = subset[len('regex:'):]
+ if isinstance(superset, list):
+ return any(re.match(pattern, str(sup_item)) for sup_item in superset)
+ else:
+ return re.match(pattern, str(superset)) is not None
+ if isinstance(superset, list) and not isinstance(subset, list):
+ # Check if the subset value matches any item in the superset
+ return any(_is_subset(sup_item, subset) for sup_item in superset)
if isinstance(superset, list) or isinstance(superset, set):
- return all(any(_is_subset(sup_item, sub_item) for sup_item in superset) for sub_item in subset)
+ return all(
+ any(_is_subset(sup_item, sub_item) for sup_item in superset)
+ for sub_item in subset
+ )
return superset == subset
def remove_task_subset(ctrl_msg: ControlMessage, task_type: typing.List, subset: typing.Dict):
"""
- A helper function to extract a task based on subset matching when the task might be out of order wrt the
+ A helper function to extract a task based on subset matching when the task might be out of order with respect to the
Morpheus pipeline. For example, if a deduplication filter occurs before scale filtering in the pipeline, but
the task list includes scale filtering before deduplication.
Parameters
----------
ctrl_msg : ControlMessage
- A list of task keys to check for in the ControlMessage.
+ The ControlMessage object containing tasks.
task_type : list
The name of the ControlMessage task to operate on.
subset : dict
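The `regex:` prefix introduced in `_is_subset` lets task filters match property values by pattern instead of exact equality. Below is a self-contained sketch of the matching semantics; the function body mirrors the patch so the behavior can be exercised without a Morpheus installation.

```python
# Illustrative copy of the regex-aware subset matching added in this patch.
import re
from typing import Any


def _is_subset(superset: Any, subset: Any) -> bool:
    if subset == "*":
        return True
    if isinstance(superset, dict) and isinstance(subset, dict):
        return all(key in superset and _is_subset(superset[key], val)
                   for key, val in subset.items())
    if isinstance(subset, str) and subset.startswith("regex:"):
        # The subset is a regex pattern applied to the superset value(s).
        pattern = subset[len("regex:"):]
        if isinstance(superset, list):
            return any(re.match(pattern, str(item)) for item in superset)
        return re.match(pattern, str(superset)) is not None
    if isinstance(superset, list) and not isinstance(subset, list):
        return any(_is_subset(item, subset) for item in superset)
    if isinstance(superset, (list, set)):
        return all(any(_is_subset(sup, sub) for sup in superset) for sub in subset)
    return superset == subset


# A task declaring a regex document_type now matches any of the listed types:
assert _is_subset({"document_type": "tiff"}, {"document_type": "regex:^(png|jpeg|tiff)$"})
assert not _is_subset({"document_type": "pdf"}, {"document_type": "regex:^(png|jpeg|tiff)$"})
```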
diff --git a/src/nv_ingest/util/image_processing/transforms.py b/src/nv_ingest/util/image_processing/transforms.py
index d441db72..15f25757 100644
--- a/src/nv_ingest/util/image_processing/transforms.py
+++ b/src/nv_ingest/util/image_processing/transforms.py
@@ -3,6 +3,8 @@
# SPDX-License-Identifier: Apache-2.0
import base64
+import io
+import logging
from io import BytesIO
from math import ceil
from math import floor
@@ -18,13 +20,112 @@
DEFAULT_MAX_WIDTH = 1024
DEFAULT_MAX_HEIGHT = 1280
+logger = logging.getLogger(__name__)
+
+
+def scale_image_to_encoding_size(base64_image: str, max_base64_size: int = 180_000,
+ initial_reduction: float = 0.9) -> str:
+ """
+ Decodes a base64-encoded image, resizes it if needed, and re-encodes it as base64.
+ Ensures the final image size is within the specified limit.
+
+ Parameters
+ ----------
+ base64_image : str
+ Base64-encoded image string.
+ max_base64_size : int, optional
+ Maximum allowable size for the base64-encoded image, by default 180,000 characters.
+ initial_reduction : float, optional
+ Initial reduction step for resizing, by default 0.9.
+
+ Returns
+ -------
+ str
+ Base64-encoded PNG image string, resized if necessary.
+
+ Raises
+ ------
+ Exception
+ If the image cannot be resized below the specified max_base64_size.
+ """
+ try:
+ # Decode the base64 image and open it as a PIL image
+ image_data = base64.b64decode(base64_image)
+ img = Image.open(io.BytesIO(image_data)).convert("RGB")
+
+ # Check initial size
+ if len(base64_image) <= max_base64_size:
+ logger.debug("Initial image is within the size limit.")
+ return base64_image
+
+ # Initial reduction step
+ reduction_step = initial_reduction
+ while len(base64_image) > max_base64_size:
+ width, height = img.size
+ new_size = (int(width * reduction_step), int(height * reduction_step))
+ logger.debug(f"Resizing image to {new_size}")
+
+ img_resized = img.resize(new_size, Image.LANCZOS)
+ buffered = io.BytesIO()
+ img_resized.save(buffered, format="PNG")
+ base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+ logger.debug(f"Resized base64 image size: {len(base64_image)} characters.")
+
+ # Adjust the reduction step if necessary
+ if len(base64_image) > max_base64_size:
+ reduction_step *= 0.95 # Reduce size further if needed
+ logger.debug(f"Reducing step size for further resizing: {reduction_step:.3f}")
+
+ # Safety check
+ if new_size[0] < 1 or new_size[1] < 1:
+ raise Exception("Image cannot be resized further without becoming too small.")
+
+ return base64_image
+
+ except Exception as e:
+ logger.error(f"Error resizing the image: {e}")
+ raise
+
+
+def ensure_base64_is_png(base64_image: str) -> str:
+ """
+ Ensures the given base64-encoded image is in PNG format. Converts to PNG if necessary.
+
+ Parameters
+ ----------
+ base64_image : str
+ Base64-encoded image string.
+
+ Returns
+ -------
+ str or None
+ Base64-encoded PNG image string, or None if the conversion fails.
+ """
+ try:
+ # Decode the base64 string and load the image
+ image_data = base64.b64decode(base64_image)
+ image = Image.open(io.BytesIO(image_data))
+
+ # Check if the image is already in PNG format
+ if image.format != 'PNG':
+ # Convert the image to PNG
+ buffered = io.BytesIO()
+ image.convert("RGB").save(buffered, format="PNG")
+ base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+ return base64_image
+ except Exception as e:
+ logger.error(f"Error ensuring PNG format: {e}")
+ return None
+
def pad_image(
- array: np.ndarray,
- target_width: int = DEFAULT_MAX_WIDTH,
- target_height: int = DEFAULT_MAX_HEIGHT,
- background_color: int = 255,
- dtype=np.uint8,
+ array: np.ndarray,
+ target_width: int = DEFAULT_MAX_WIDTH,
+ target_height: int = DEFAULT_MAX_HEIGHT,
+ background_color: int = 255,
+ dtype=np.uint8,
) -> Tuple[np.ndarray, Tuple[int, int]]:
"""
Pads a NumPy array representing an image to the specified target dimensions.
@@ -75,10 +176,11 @@ def pad_image(
# Create the canvas and place the original image on it
canvas = background_color * np.ones((final_height, final_width, array.shape[2]), dtype=dtype)
- canvas[pad_height : pad_height + height, pad_width : pad_width + width] = array # noqa: E203
+ canvas[pad_height: pad_height + height, pad_width: pad_width + width] = array # noqa: E203
return canvas, (pad_width, pad_height)
+
def check_numpy_image_size(image: np.ndarray, min_height: int, min_width: int) -> bool:
"""
Checks if the height and width of the image are larger than the specified minimum values.
@@ -98,8 +200,9 @@ def check_numpy_image_size(image: np.ndarray, min_height: int, min_width: int) -
height, width = image.shape[:2]
return height >= min_height and width >= min_width
+
def crop_image(
- array: np.array, bbox: Tuple[int, int, int, int], min_width: int = 1, min_height: int = 1
+ array: np.array, bbox: Tuple[int, int, int, int], min_width: int = 1, min_height: int = 1
) -> Optional[np.ndarray]:
"""
Crops a NumPy array representing an image according to the specified bounding box.
@@ -138,13 +241,13 @@ def crop_image(
def normalize_image(
- array: np.ndarray,
- r_mean: float = 0.485,
- g_mean: float = 0.456,
- b_mean: float = 0.406,
- r_std: float = 0.229,
- g_std: float = 0.224,
- b_std: float = 0.225,
+ array: np.ndarray,
+ r_mean: float = 0.485,
+ g_mean: float = 0.456,
+ b_mean: float = 0.406,
+ r_std: float = 0.229,
+ g_std: float = 0.224,
+ b_std: float = 0.225,
) -> np.ndarray:
"""
Normalizes an RGB image by applying a mean and standard deviation to each channel.
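Together, the two new transforms let callers coerce any base64 payload into a size-bounded PNG before it is sent to the captioning endpoint. A usage sketch, assuming the imports resolve as laid out in this patch:

```python
import base64
import io

from PIL import Image

from nv_ingest.util.image_processing.transforms import (
    ensure_base64_is_png,
    scale_image_to_encoding_size,
)

# Encode a JPEG so ensure_base64_is_png has a conversion to perform.
buf = io.BytesIO()
Image.new("RGB", (2048, 2048), "green").save(buf, format="JPEG")
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

b64_png = ensure_base64_is_png(b64)                  # JPEG -> PNG re-encode
b64_bounded = scale_image_to_encoding_size(b64_png)  # shrink until <= 180,000 chars
assert len(b64_bounded) <= 180_000
```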
diff --git a/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py b/src/nv_ingest/util/nim/yolox.py
similarity index 93%
rename from src/nv_ingest/extraction_workflows/pdf/yolox_utils.py
rename to src/nv_ingest/util/nim/yolox.py
index e8458ea0..f77460e8 100644
--- a/src/nv_ingest/extraction_workflows/pdf/yolox_utils.py
+++ b/src/nv_ingest/util/nim/yolox.py
@@ -2,6 +2,8 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
+from typing import List
+import numpy as np
import warnings
import cv2
@@ -26,7 +28,7 @@ def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thr
continue
# Get score and class with highest confidence
- class_conf, class_pred = torch.max(image_pred[:, 5 : 5 + num_classes], 1, keepdim=True) # noqa: E203
+ class_conf, class_pred = torch.max(image_pred[:, 5: 5 + num_classes], 1, keepdim=True) # noqa: E203
conf_mask = (image_pred[:, 4] * class_conf.squeeze() >= conf_thre).squeeze()
# Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
@@ -201,14 +203,14 @@ def expand_chart_bboxes(annotation_dict, labels=None):
def weighted_boxes_fusion(
- boxes_list,
- scores_list,
- labels_list,
- iou_thr=0.5,
- skip_box_thr=0.0,
- conf_type="avg",
- merge_type="weighted",
- class_agnostic=False,
+ boxes_list,
+ scores_list,
+ labels_list,
+ iou_thr=0.5,
+ skip_box_thr=0.0,
+ conf_type="avg",
+ merge_type="weighted",
+ class_agnostic=False,
):
"""
Custom wbf implementation that supports a class_agnostic mode and a biggest box fusion.
@@ -581,3 +583,35 @@ def get_weighted_box(boxes, conf_type="avg"):
box[3] = -1 # model index field is retained for consistency but is not used.
box[4:] /= conf
return box
+
+
+def prepare_images_for_inference(images: List[np.ndarray]) -> np.ndarray:
+ """
+ Prepare a list of images for model inference by resizing and reordering axes.
+
+ Parameters
+ ----------
+ images : List[np.ndarray]
+ A list of image arrays to be prepared for inference.
+
+ Returns
+ -------
+ np.ndarray
+ A numpy array suitable for model input, with the shape reordered to match the expected input format.
+
+ Notes
+ -----
+ The images are resized to 1024x1024 pixels and the axes are reordered to match the expected input shape for
+ the model (batch, channels, height, width).
+
+ Examples
+ --------
+ >>> images = [np.random.rand(1536, 1536, 3) for _ in range(2)]
+ >>> input_array = prepare_images_for_inference(images)
+ >>> input_array.shape
+ (2, 3, 1024, 1024)
+ """
+
+ resized_images = [resize_image(image, (1024, 1024)) for image in images]
+
+ return np.einsum("bijk->bkij", resized_images).astype(np.float32)
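The einsum expression `bijk->bkij` is what moves the channel axis ahead of the spatial axes for model input. A dependency-free check of that reordering, with the module's `resize_image` stubbed by a simple crop (the real helper performs an actual resize):

```python
import numpy as np


def fake_resize(image: np.ndarray, shape=(1024, 1024)) -> np.ndarray:
    # Stand-in for resize_image; a crop keeps the example self-contained
    # while producing the same output shape.
    return image[: shape[0], : shape[1], :]


images = [np.random.rand(1536, 1536, 3) for _ in range(2)]
resized = [fake_resize(img) for img in images]

# "bijk->bkij": (batch, height, width, channels) -> (batch, channels, height, width)
batched = np.einsum("bijk->bkij", resized).astype(np.float32)
assert batched.shape == (2, 3, 1024, 1024)
```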
diff --git a/src/nv_ingest/util/pdf/metadata_aggregators.py b/src/nv_ingest/util/pdf/metadata_aggregators.py
index 851a931b..64b040af 100644
--- a/src/nv_ingest/util/pdf/metadata_aggregators.py
+++ b/src/nv_ingest/util/pdf/metadata_aggregators.py
@@ -3,16 +3,20 @@
# SPDX-License-Identifier: Apache-2.0
-import uuid
from dataclasses import dataclass
from datetime import datetime
+from PIL import Image
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
+import base64
+import io
+import uuid
import pandas as pd
import pypdfium2 as pdfium
+from pypdfium2 import PdfImage
from nv_ingest.schemas.metadata_schema import ContentSubtypeEnum
from nv_ingest.schemas.metadata_schema import ContentTypeEnum
@@ -186,8 +190,95 @@ def construct_text_metadata(
return [ContentTypeEnum.TEXT, validated_unified_metadata.dict(), str(uuid.uuid4())]
-def construct_image_metadata(
- image_base64: Base64Image,
+def construct_image_metadata_from_base64(
+ base64_image: str,
+ page_idx: int,
+ page_count: int,
+ source_metadata: Dict[str, Any],
+ base_unified_metadata: Dict[str, Any],
+) -> List[Any]:
+ """
+ Extracts image data from a base64-encoded image string, decodes the image to get
+ its dimensions and bounding box, and constructs metadata for the image.
+
+ Parameters
+ ----------
+ base64_image : str
+ A base64-encoded string representing the image.
+ page_idx : int
+ The index of the current page being processed.
+ page_count : int
+ The total number of pages in the PDF document.
+ source_metadata : Dict[str, Any]
+ Metadata related to the source of the PDF document.
+ base_unified_metadata : Dict[str, Any]
+ The base unified metadata structure to be updated with the extracted image information.
+
+ Returns
+ -------
+ List[Any]
+ A list containing the content type, validated metadata dictionary, and a UUID string.
+
+ Raises
+ ------
+ ValueError
+ If the image cannot be decoded from the base64 string.
+ """
+ # Decode the base64 image
+ try:
+ image_data = base64.b64decode(base64_image)
+ image = Image.open(io.BytesIO(image_data))
+ except Exception as e:
+ raise ValueError(f"Failed to decode image from base64: {e}")
+
+ # Extract image dimensions and bounding box
+ width, height = image.size
+ bbox = (0, 0, width, height) # Assuming the full image as the bounding box
+
+ # Construct content metadata
+ content_metadata: Dict[str, Any] = {
+ "type": ContentTypeEnum.IMAGE,
+ "description": StdContentDescEnum.PDF_IMAGE,
+ "page_number": page_idx,
+ "hierarchy": {
+ "page_count": page_count,
+ "page": page_idx,
+ "block": -1,
+ "line": -1,
+ "span": -1,
+ "nearby_objects": [],
+ },
+ }
+
+ # Construct image metadata
+ image_metadata: Dict[str, Any] = {
+ "image_type": "PNG", # This can be dynamic if needed
+ "structured_image_type": ImageTypeEnum.image_type_1,
+ "caption": "",
+ "text": "",
+ "image_location": bbox,
+ "image_location_max_dimensions": (width, height),
+ "height": height,
+ }
+
+ # Update the unified metadata with the extracted image information
+ unified_metadata: Dict[str, Any] = base_unified_metadata.copy()
+ unified_metadata.update(
+ {
+ "content": base64_image,
+ "source_metadata": source_metadata,
+ "content_metadata": content_metadata,
+ "image_metadata": image_metadata,
+ }
+ )
+
+ # Validate and return the unified metadata
+ validated_unified_metadata = validate_metadata(unified_metadata)
+ return [ContentTypeEnum.IMAGE, validated_unified_metadata.dict(), str(uuid.uuid4())]
+
+
+def construct_image_metadata_from_pdf_image(
+ pdf_image: PdfImage,
page_idx: int,
page_count: int,
source_metadata: Dict[str, Any],
@@ -219,7 +310,7 @@ def construct_image_metadata(
------
PdfiumError
If the image cannot be extracted due to an issue with the PdfImage object.
- :param image_base64:
+ :param pdf_image:
"""
# Define the assumed image type (e.g., PNG)
image_type: str = "PNG"
@@ -245,16 +336,16 @@ def construct_image_metadata(
"structured_image_type": ImageTypeEnum.image_type_1,
"caption": "",
"text": "",
- "image_location": image_base64.bbox,
- "image_location_max_dimensions": (max(image_base64.max_width,0), max(image_base64.max_height,0)),
- "height": image_base64.height,
+ "image_location": pdf_image.bbox,
+ "image_location_max_dimensions": (max(pdf_image.max_width, 0), max(pdf_image.max_height, 0)),
+ "height": pdf_image.height,
}
# Update the unified metadata with the extracted image information
unified_metadata: Dict[str, Any] = base_unified_metadata.copy()
unified_metadata.update(
{
- "content": image_base64.image,
+ "content": pdf_image.image,
"source_metadata": source_metadata,
"content_metadata": content_metadata,
"image_metadata": image_metadata,
diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py
index ee8ae3fe..a55f9e95 100644
--- a/src/nv_ingest/util/pipeline/stage_builders.py
+++ b/src/nv_ingest/util/pipeline/stage_builders.py
@@ -23,6 +23,7 @@
from nv_ingest.modules.transforms.embed_extractions import EmbedExtractionsLoaderFactory
from nv_ingest.modules.transforms.nemo_doc_splitter import NemoDocSplitterLoaderFactory
from nv_ingest.stages.docx_extractor_stage import generate_docx_extractor_stage
+from nv_ingest.stages.extractors.image_extractor_stage import generate_image_extractor_stage
from nv_ingest.stages.filters import generate_dedup_stage
from nv_ingest.stages.filters import generate_image_filter_stage
from nv_ingest.stages.nim.chart_extraction import generate_chart_extractor_stage
@@ -247,6 +248,29 @@ def add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, def
return table_extractor_stage
+def add_image_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
+ yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox")
+ image_extractor_config = ingest_config.get("image_extraction_module",
+ {
+ "image_extraction_config": {
+ "yolox_endpoints": (yolox_grpc, yolox_http),
+ "yolox_infer_protocol": yolox_protocol,
+ "auth_token": yolox_auth,
+ # All auth tokens are the same for the moment
+ }
+ })
+ image_extractor_stage = pipe.add_stage(
+ generate_image_extractor_stage(
+ morpheus_pipeline_config,
+ extractor_config=image_extractor_config,
+ pe_count=8,
+ task="extract",
+ task_desc="docx_content_extractor",
+ )
+ )
+ return image_extractor_stage
+
+
def add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count):
docx_extractor_stage = pipe.add_stage(
generate_docx_extractor_stage(
@@ -319,14 +343,25 @@ def add_nemo_splitter_stage(pipe, morpheus_pipeline_config, ingest_config):
def add_image_caption_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
- endpoint_url, model_name = get_caption_classifier_service()
+ auth_token = os.environ.get(
+ "NVIDIA_BUILD_API_KEY",
+ "",
+ ) or os.environ.get(
+ "NGC_API_KEY",
+ "",
+ )
+
+ endpoint_url = os.environ.get("VLM_CAPTION_ENDPOINT")
+
image_caption_config = ingest_config.get(
"image_caption_extraction_module",
{
- "caption_classifier_model_name": model_name,
+ "api_key": auth_token,
"endpoint_url": endpoint_url,
+ "prompt": "Caption the content of this image:",
},
)
+
image_caption_stage = pipe.add_stage(
generate_caption_extraction_stage(
morpheus_pipeline_config,
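With the classifier service gone, the caption stage's configuration now resolves entirely from environment variables. A sketch of the resolution order implemented above; the printed dict corresponds to the `image_caption_config` default built by `add_image_caption_stage`:

```python
import os

# NVIDIA_BUILD_API_KEY wins; NGC_API_KEY is the fallback (both may be unset).
auth_token = os.environ.get("NVIDIA_BUILD_API_KEY", "") or os.environ.get("NGC_API_KEY", "")

# No hard-coded default in the pipeline: an unset VLM_CAPTION_ENDPOINT yields None.
endpoint_url = os.environ.get("VLM_CAPTION_ENDPOINT")

image_caption_config = {
    "api_key": auth_token,
    "endpoint_url": endpoint_url,
    "prompt": "Caption the content of this image:",
}
print(image_caption_config)
```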
diff --git a/src/pipeline.py b/src/pipeline.py
index 4d8289d7..6397765b 100644
--- a/src/pipeline.py
+++ b/src/pipeline.py
@@ -49,6 +49,7 @@ def setup_ingestion_pipeline(
## Primitive extraction
########################################################################################################
pdf_extractor_stage = add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count)
+ image_extractor_stage = add_image_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count)
docx_extractor_stage = add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count)
pptx_extractor_stage = add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count)
########################################################################################################
@@ -60,6 +61,7 @@ def setup_ingestion_pipeline(
image_filter_stage = add_image_filter_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count)
table_extraction_stage = add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count)
chart_extraction_stage = add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count)
+ image_caption_stage = add_image_caption_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count)
########################################################################################################
########################################################################################################
@@ -73,10 +75,10 @@ def setup_ingestion_pipeline(
## Storage and output
########################################################################################################
image_storage_stage = add_image_storage_stage(pipe, morpheus_pipeline_config)
+ vdb_task_sink_stage = add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config)
sink_stage = add_sink_stage(
pipe, morpheus_pipeline_config, ingest_config, message_provider_host, message_provider_port
)
- vdb_task_sink_stage = add_vdb_task_sink_stage(pipe, morpheus_pipeline_config, ingest_config)
########################################################################################################
#######################################################################################################
@@ -93,14 +95,16 @@ def setup_ingestion_pipeline(
pipe.add_edge(source_stage, submitted_job_counter_stage)
pipe.add_edge(submitted_job_counter_stage, metadata_injector_stage)
pipe.add_edge(metadata_injector_stage, pdf_extractor_stage)
- pipe.add_edge(pdf_extractor_stage, docx_extractor_stage)
+ pipe.add_edge(pdf_extractor_stage, image_extractor_stage)
+ pipe.add_edge(image_extractor_stage, docx_extractor_stage)
pipe.add_edge(docx_extractor_stage, pptx_extractor_stage)
pipe.add_edge(pptx_extractor_stage, image_dedup_stage)
pipe.add_edge(image_dedup_stage, image_filter_stage)
pipe.add_edge(image_filter_stage, table_extraction_stage)
pipe.add_edge(table_extraction_stage, chart_extraction_stage)
pipe.add_edge(chart_extraction_stage, nemo_splitter_stage)
- pipe.add_edge(nemo_splitter_stage, embed_extractions_stage)
+ pipe.add_edge(nemo_splitter_stage, image_caption_stage)
+ pipe.add_edge(image_caption_stage, embed_extractions_stage)
pipe.add_edge(embed_extractions_stage, image_storage_stage)
pipe.add_edge(image_storage_stage, vdb_task_sink_stage)
pipe.add_edge(vdb_task_sink_stage, sink_stage)
diff --git a/tests/stages/nims/__init__.py b/tests/nv_ingest/extraction_workflows/image/__init__.py
similarity index 100%
rename from tests/stages/nims/__init__.py
rename to tests/nv_ingest/extraction_workflows/image/__init__.py
diff --git a/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py b/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py
new file mode 100644
index 00000000..b58c235f
--- /dev/null
+++ b/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py
@@ -0,0 +1,354 @@
+import urllib
+from pyexpat import ExpatError
+from xml.etree.ElementTree import ParseError
+
+from wand.exceptions import WandException
+
+from nv_ingest.extraction_workflows.image.image_handlers import load_and_preprocess_image, convert_svg_to_bitmap, \
+ extract_table_and_chart_images
+from PIL import Image
+import io
+import numpy as np
+from typing import List, Tuple
+
+from nv_ingest.extraction_workflows.image.image_handlers import process_inference_results
+from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
+
+
+def test_load_and_preprocess_image_jpeg():
+ """Test loading and preprocessing a JPEG image."""
+ # Create a small sample image and save it to a BytesIO stream as JPEG
+ image = Image.new("RGB", (10, 10), color="red")
+ image_stream = io.BytesIO()
+ image.save(image_stream, format="JPEG")
+ image_stream.seek(0)
+
+ # Load and preprocess the image
+ result = load_and_preprocess_image(image_stream)
+
+ # Check the output type and shape
+ assert isinstance(result, np.ndarray)
+ assert result.shape == (10, 10, 3)
+ assert result.dtype == np.float32
+ assert np.all(result[:, :, 0] == 254) # All red pixels
+ assert np.all(result[:, :, 1] == 0) # No green
+ assert np.all(result[:, :, 2] == 0) # No blue
+
+
+def test_load_and_preprocess_image_png():
+ """Test loading and preprocessing a PNG image."""
+ # Create a small sample image and save it to a BytesIO stream as PNG
+ image = Image.new("RGB", (5, 5), color="blue")
+ image_stream = io.BytesIO()
+ image.save(image_stream, format="PNG")
+ image_stream.seek(0)
+
+ # Load and preprocess the image
+ result = load_and_preprocess_image(image_stream)
+
+ # Check the output type and shape
+ assert isinstance(result, np.ndarray)
+ assert result.shape == (5, 5, 3)
+ assert result.dtype == np.float32
+ assert np.all(result[:, :, 0] == 0) # No red
+ assert np.all(result[:, :, 1] == 0) # No green
+ assert np.all(result[:, :, 2] == 255) # All blue pixels
+
+
+def test_load_and_preprocess_image_invalid_format():
+ """Test that an invalid image format raises an error."""
+ # Create a BytesIO stream with non-image content
+ invalid_stream = io.BytesIO(b"This is not an image file")
+
+ # Expect an OSError when trying to open a non-image stream
+ try:
+ load_and_preprocess_image(invalid_stream)
+ except OSError as e:
+ assert "cannot identify image file" in str(e)
+
+
+def test_load_and_preprocess_image_corrupt_image():
+ """Test that a corrupt image raises an error."""
+ # Create a valid JPEG header but corrupt the rest
+ corrupt_stream = io.BytesIO(b"\xFF\xD8\xFF\xE0" + b"\x00" * 10)
+
+ # Expect an OSError when trying to open a corrupt image stream
+ try:
+ load_and_preprocess_image(corrupt_stream)
+ except OSError as e:
+ assert "cannot identify image file" in str(e)
+
+
+def test_convert_svg_to_bitmap_basic_svg():
+ """Test converting a simple SVG to a bitmap image."""
+ # Sample SVG image data (a small red square)
+ svg_data = b"""
+
+ """
+ image_stream = io.BytesIO(svg_data)
+
+ # Convert SVG to bitmap
+ result = convert_svg_to_bitmap(image_stream)
+
+ # Check the output type, shape, and color values
+ assert isinstance(result, np.ndarray)
+ assert result.shape == (10, 10, 3)
+ assert result.dtype == np.float32
+ assert np.all(result[:, :, 0] == 255) # Red channel fully on
+ assert np.all(result[:, :, 1] == 0) # Green channel off
+ assert np.all(result[:, :, 2] == 0) # Blue channel off
+
+
+def test_convert_svg_to_bitmap_large_svg():
+ """Test converting a larger SVG to ensure scalability."""
+ # Large SVG image data (blue rectangle 100x100)
+ svg_data = b"""
+
+ """
+ image_stream = io.BytesIO(svg_data)
+
+ # Convert SVG to bitmap
+ result = convert_svg_to_bitmap(image_stream)
+
+ # Check the output type and shape
+ assert isinstance(result, np.ndarray)
+ assert result.shape == (100, 100, 3)
+ assert result.dtype == np.float32
+ assert np.all(result[:, :, 0] == 0) # Red channel off
+ assert np.all(result[:, :, 1] == 0) # Green channel off
+ assert np.all(result[:, :, 2] == 255) # Blue channel fully on
+
+
+def test_process_inference_results_basic_case():
+ """Test process_inference_results with a typical valid input."""
+
+ # Simulated model output array for a single image with several detections.
+ # Array format is (batch_size, num_detections, 85) - 80 classes + 5 box coordinates
+ # For simplicity, use random values for the boxes and class predictions.
+ output_array = np.zeros((1, 3, 85), dtype=np.float32)
+
+ # Mock bounding box coordinates
+ output_array[0, 0, :4] = [0.5, 0.5, 0.2, 0.2] # x_center, y_center, width, height
+ output_array[0, 1, :4] = [0.6, 0.6, 0.2, 0.2]
+ output_array[0, 2, :4] = [0.7, 0.7, 0.2, 0.2]
+
+ # Mock object confidence scores
+ output_array[0, :, 4] = [0.8, 0.9, 0.85]
+
+ # Mock class scores (set class 1 with highest confidence for simplicity)
+ output_array[0, 0, 5 + 1] = 0.7
+ output_array[0, 1, 5 + 1] = 0.75
+ output_array[0, 2, 5 + 1] = 0.72
+
+ original_image_shapes = [(640, 640)] # Original shape of the image before resizing
+
+ # Process inference results with thresholds that should retain all mock detections
+ results = process_inference_results(
+ output_array,
+ original_image_shapes,
+ num_classes=80,
+ conf_thresh=0.5,
+ iou_thresh=0.5,
+ min_score=0.1,
+ final_thresh=0.3,
+ )
+
+ # Check output structure
+ assert isinstance(results, list)
+ assert len(results) == 1
+ assert isinstance(results[0], dict)
+
+ # Validate bounding box scaling and structure
+ assert "chart" in results[0] or "table" in results[0]
+ if "chart" in results[0]:
+ assert isinstance(results[0]["chart"], list)
+ assert len(results[0]["chart"]) > 0
+ # Check bounding box format for each detected "chart" item (5 values per box)
+ for bbox in results[0]["chart"]:
+ assert len(bbox) == 5 # [x1, y1, x2, y2, score]
+ assert bbox[4] >= 0.3 # score meets final threshold
+
+ print("Processed inference results:", results)
+
+
+def test_process_inference_results_multiple_images():
+ """Test with multiple images to verify batch processing."""
+ # Simulate model output with 2 images and 3 detections each
+ output_array = np.zeros((2, 3, 85), dtype=np.float32)
+ # Set bounding boxes and confidence for the mock detections
+ output_array[0, 0, :5] = [0.5, 0.5, 0.2, 0.2, 0.8]
+ output_array[0, 1, :5] = [0.6, 0.6, 0.2, 0.2, 0.7]
+ output_array[1, 0, :5] = [0.4, 0.4, 0.1, 0.1, 0.9]
+ # Assign class confidences for classes 0 and 1
+ output_array[0, 0, 5 + 1] = 0.75
+ output_array[0, 1, 5 + 1] = 0.65
+ output_array[1, 0, 5 + 0] = 0.8
+
+ original_image_shapes = [(640, 640), (800, 800)]
+
+ results = process_inference_results(
+ output_array,
+ original_image_shapes,
+ num_classes=80,
+ conf_thresh=0.5,
+ iou_thresh=0.5,
+ min_score=0.1,
+ final_thresh=0.3,
+ )
+
+ assert isinstance(results, list)
+ assert len(results) == 2
+ for result in results:
+ assert isinstance(result, dict)
+ if "chart" in result:
+ assert all(len(bbox) == 5 and bbox[4] >= 0.3 for bbox in result["chart"])
+
+
+def test_process_inference_results_high_confidence_threshold():
+ """Test with a high confidence threshold to verify filtering."""
+ output_array = np.zeros((1, 5, 85), dtype=np.float32)
+ # Set low confidence scores below the threshold
+ output_array[0, :, 4] = [0.2, 0.3, 0.4, 0.4, 0.2]
+ output_array[0, :, 5] = [0.5] * 5 # Class confidence
+
+ original_image_shapes = [(640, 640)]
+
+ results = process_inference_results(
+ output_array,
+ original_image_shapes,
+ num_classes=80,
+ conf_thresh=0.9, # High confidence threshold
+ iou_thresh=0.5,
+ min_score=0.1,
+ final_thresh=0.3,
+ )
+
+ assert isinstance(results, list)
+ assert len(results) == 1
+ assert results[0] == {} # No detections should pass the high confidence threshold
+
+
+def test_process_inference_results_varied_num_classes():
+ """Test compatibility with different model class counts."""
+ output_array = np.zeros((1, 3, 25), dtype=np.float32) # 20 classes + 5 box coords
+ # Assign box, object confidence, and class scores
+ output_array[0, 0, :5] = [0.5, 0.5, 0.2, 0.2, 0.8]
+ output_array[0, 1, :5] = [0.6, 0.6, 0.3, 0.3, 0.7]
+ output_array[0, 0, 5 + 1] = 0.9 # Assign highest confidence to class 1
+
+ original_image_shapes = [(640, 640)]
+
+ results = process_inference_results(
+ output_array,
+ original_image_shapes,
+ num_classes=20, # Different class count
+ conf_thresh=0.5,
+ iou_thresh=0.5,
+ min_score=0.1,
+ final_thresh=0.3,
+ )
+
+ assert isinstance(results, list)
+ assert len(results) == 1
+ assert isinstance(results[0], dict)
+ assert "chart" in results[0]
+ assert len(results[0]["chart"]) > 0 # Verify detections processed correctly with 20 classes
+
+
+def crop_image(image: np.ndarray, bbox: Tuple[int, int, int, int]) -> np.ndarray:
+ """Mock function to simulate cropping an image."""
+ h1, w1, h2, w2 = bbox
+ return image[int(h1):int(h2), int(w1):int(w2)]
+
+
+def test_extract_table_and_chart_images_empty_annotations():
+ """Test when annotation_dict has no objects to extract."""
+ annotation_dict = {"table": [], "chart": []}
+ original_image = np.random.rand(640, 640, 3)
+ tables_and_charts = []
+
+ extract_table_and_chart_images(annotation_dict, original_image, 0, tables_and_charts)
+
+ # Expect no entries added to tables_and_charts since there are no objects
+ assert tables_and_charts == []
+
+
+def test_extract_table_and_chart_images_single_table():
+ """Test extraction with a single table bounding box."""
+ annotation_dict = {"table": [[0.1, 0.1, 0.3, 0.3, 0.8]], "chart": []}
+ original_image = np.random.rand(640, 640, 3)
+ tables_and_charts = []
+
+ extract_table_and_chart_images(annotation_dict, original_image, 0, tables_and_charts)
+
+ # Expect one entry in tables_and_charts for the table
+ assert len(tables_and_charts) == 1
+ page_idx, cropped_image_data = tables_and_charts[0]
+ assert page_idx == 0
+ assert isinstance(cropped_image_data, CroppedImageWithContent)
+
+ # Verify attribute values
+ assert cropped_image_data.content == ""
+ assert cropped_image_data.type_string == "table"
+ assert cropped_image_data.bbox == (64, 64, 192, 192) # Scaled bounding box from (0.1, 0.1, 0.3, 0.3)
+ assert cropped_image_data.max_width == 640
+ assert cropped_image_data.max_height == 640
+ assert isinstance(cropped_image_data.image, str) # Assuming the image is base64-encoded
+
+
+def test_extract_table_and_chart_images_single_chart():
+ """Test extraction with a single chart bounding box."""
+ annotation_dict = {"table": [], "chart": [[0.4, 0.4, 0.6, 0.6, 0.9]]}
+ original_image = np.random.rand(640, 640, 3)
+ tables_and_charts = []
+
+ extract_table_and_chart_images(annotation_dict, original_image, 1, tables_and_charts)
+
+ # Expect one entry in tables_and_charts for the chart
+ assert len(tables_and_charts) == 1
+ page_idx, cropped_image_data = tables_and_charts[0]
+ assert page_idx == 1
+ assert isinstance(cropped_image_data, CroppedImageWithContent)
+ assert cropped_image_data.type_string == "chart"
+ assert cropped_image_data.bbox == (256, 256, 384, 384) # Scaled bounding box
+
+
+def test_extract_table_and_chart_images_multiple_objects():
+ """Test extraction with multiple table and chart objects."""
+ annotation_dict = {
+ "table": [[0.1, 0.1, 0.3, 0.3, 0.8], [0.5, 0.5, 0.7, 0.7, 0.85]],
+ "chart": [[0.2, 0.2, 0.4, 0.4, 0.9]]
+ }
+ original_image = np.random.rand(640, 640, 3)
+ tables_and_charts = []
+
+ extract_table_and_chart_images(annotation_dict, original_image, 2, tables_and_charts)
+
+ # Expect three entries in tables_and_charts: two tables and one chart
+ assert len(tables_and_charts) == 3
+ for page_idx, cropped_image_data in tables_and_charts:
+ assert page_idx == 2
+ assert isinstance(cropped_image_data, CroppedImageWithContent)
+ assert cropped_image_data.type_string in ["table", "chart"]
+ assert cropped_image_data.bbox is not None # Bounding box should be defined
+
+
+def test_extract_table_and_chart_images_invalid_bounding_box():
+ """Test with an invalid bounding box to check handling of incorrect coordinates."""
+ annotation_dict = {"table": [[1.1, 1.1, 1.5, 1.5, 0.9]], "chart": []} # Out of bounds
+ original_image = np.random.rand(640, 640, 3)
+ tables_and_charts = []
+
+ extract_table_and_chart_images(annotation_dict, original_image, 3, tables_and_charts)
+
+ # Verify that the function processes the bounding box as is
+ assert len(tables_and_charts) == 1
+ page_idx, cropped_image_data = tables_and_charts[0]
+ assert page_idx == 3
+ assert isinstance(cropped_image_data, CroppedImageWithContent)
+ assert cropped_image_data.type_string == "table"
+ assert cropped_image_data.bbox == (704, 704, 960, 960) # Scaled bounding box with out-of-bounds values
diff --git a/tests/nv_ingest/extraction_workflows/pdf/test_pdfium_helper.py b/tests/nv_ingest/extraction_workflows/pdf/test_pdfium_helper.py
index 3455271d..b9999603 100644
--- a/tests/nv_ingest/extraction_workflows/pdf/test_pdfium_helper.py
+++ b/tests/nv_ingest/extraction_workflows/pdf/test_pdfium_helper.py
@@ -9,7 +9,7 @@
import pandas as pd
import pytest
-from nv_ingest.extraction_workflows.pdf.pdfium_helper import pdfium
+from nv_ingest.extraction_workflows.pdf.pdfium_helper import pdfium_extractor
from nv_ingest.schemas.metadata_schema import TextTypeEnum
@@ -38,8 +38,8 @@ def pdf_stream_embedded_tables_pdf():
@pytest.mark.xfail(reason="PDFium conversion required")
-def test_pdfium_basic(pdf_stream_test_pdf, document_df):
- extracted_data = pdfium(
+def test_pdfium_extractor_basic(pdf_stream_test_pdf, document_df):
+ extracted_data = pdfium_extractor(
pdf_stream_test_pdf,
extract_text=True,
extract_images=False,
@@ -63,8 +63,8 @@ def test_pdfium_basic(pdf_stream_test_pdf, document_df):
"text_depth",
["span", TextTypeEnum.SPAN, "line", TextTypeEnum.LINE, "block", TextTypeEnum.BLOCK],
)
-def test_pdfium_text_depth_line(pdf_stream_test_pdf, document_df, text_depth):
- extracted_data = pdfium(
+def test_pdfium_extractor_text_depth_line(pdf_stream_test_pdf, document_df, text_depth):
+ extracted_data = pdfium_extractor(
pdf_stream_test_pdf,
extract_text=True,
extract_images=False,
@@ -89,8 +89,8 @@ def test_pdfium_text_depth_line(pdf_stream_test_pdf, document_df, text_depth):
"text_depth",
["page", TextTypeEnum.PAGE, "document", TextTypeEnum.DOCUMENT],
)
-def test_pdfium_text_depth_page(pdf_stream_test_pdf, document_df, text_depth):
- extracted_data = pdfium(
+def test_pdfium_extractor_text_depth_page(pdf_stream_test_pdf, document_df, text_depth):
+ extracted_data = pdfium_extractor(
pdf_stream_test_pdf,
extract_text=True,
extract_images=False,
@@ -111,8 +111,8 @@ def test_pdfium_text_depth_page(pdf_stream_test_pdf, document_df, text_depth):
@pytest.mark.xfail(reason="PDFium conversion required")
-def test_pdfium_extract_image(pdf_stream_test_pdf, document_df):
- extracted_data = pdfium(
+def test_pdfium_extractor_extract_image(pdf_stream_test_pdf, document_df):
+ extracted_data = pdfium_extractor(
pdf_stream_test_pdf,
extract_text=True,
extract_images=True,
@@ -147,8 +147,8 @@ def read_markdown_table(table_str: str) -> pd.DataFrame:
@pytest.mark.xfail(reason="PDFium conversion required")
-def test_pdfium_table_extraction_on_pdf_with_no_tables(pdf_stream_test_pdf, document_df):
- extracted_data = pdfium(
+def test_pdfium_extractor_table_extraction_on_pdf_with_no_tables(pdf_stream_test_pdf, document_df):
+ extracted_data = pdfium_extractor(
pdf_stream_test_pdf,
extract_text=False,
extract_images=False,
@@ -162,11 +162,11 @@ def test_pdfium_table_extraction_on_pdf_with_no_tables(pdf_stream_test_pdf, docu
@pytest.mark.xfail(reason="PDFium conversion required")
-def test_pdfium_table_extraction_on_pdf_with_tables(pdf_stream_embedded_tables_pdf, document_df):
+def test_pdfium_extractor_table_extraction_on_pdf_with_tables(pdf_stream_embedded_tables_pdf, document_df):
"""
Test to ensure pdfium's table extraction is able to extract easy-to-read tables from a PDF.
"""
- extracted_data = pdfium(
+ extracted_data = pdfium_extractor(
pdf_stream_embedded_tables_pdf,
extract_text=False,
extract_images=False,
diff --git a/tests/nv_ingest/schemas/test_image_caption_extraction_schema.py b/tests/nv_ingest/schemas/test_image_caption_extraction_schema.py
new file mode 100644
index 00000000..c059a253
--- /dev/null
+++ b/tests/nv_ingest/schemas/test_image_caption_extraction_schema.py
@@ -0,0 +1,70 @@
+import pytest
+from pydantic import ValidationError
+
+from nv_ingest.schemas import ImageCaptionExtractionSchema
+
+
+def test_valid_schema():
+ # Test with all required fields and optional defaults
+ valid_data = {
+ "api_key": "your-api-key-here",
+ }
+ schema = ImageCaptionExtractionSchema(**valid_data)
+ assert schema.api_key == "your-api-key-here"
+ assert schema.endpoint_url == "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions"
+ assert schema.prompt == "Caption the content of this image:"
+ assert schema.raise_on_failure is False
+
+
+def test_valid_schema_with_custom_values():
+ # Test with all fields including custom values for optional fields
+ valid_data = {
+ "api_key": "your-api-key-here",
+ "endpoint_url": "https://custom.api.endpoint",
+ "prompt": "Describe the image:",
+ "raise_on_failure": True,
+ }
+ schema = ImageCaptionExtractionSchema(**valid_data)
+ assert schema.api_key == "your-api-key-here"
+ assert schema.endpoint_url == "https://custom.api.endpoint"
+ assert schema.prompt == "Describe the image:"
+ assert schema.raise_on_failure is True
+
+
+def test_missing_api_key():
+ # Test with missing required field `api_key`
+ with pytest.raises(ValidationError) as exc_info:
+ ImageCaptionExtractionSchema()
+ assert "field required" in str(exc_info.value)
+
+
+def test_invalid_extra_field():
+ # Test with an additional field that should be forbidden
+ data_with_extra_field = {
+ "api_key": "your-api-key-here",
+ "extra_field": "should_not_be_allowed"
+ }
+ with pytest.raises(ValidationError) as exc_info:
+ ImageCaptionExtractionSchema(**data_with_extra_field)
+ assert "extra fields not permitted" in str(exc_info.value)
+
+
+def test_invalid_field_types():
+ # Test with wrong types for optional fields
+ invalid_data = {
+ "api_key": "your-api-key-here",
+ "endpoint_url": 12345, # invalid type
+ "prompt": 123, # invalid type
+ "raise_on_failure": "not_boolean" # invalid type
+ }
+ with pytest.raises(ValidationError) as exc_info:
+ ImageCaptionExtractionSchema(**invalid_data)
+
+
+def test_default_values():
+ # Test that default values are correctly assigned when not provided
+ data = {"api_key": "your-api-key-here"}
+ schema = ImageCaptionExtractionSchema(**data)
+ assert schema.endpoint_url == "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions"
+ assert schema.prompt == "Caption the content of this image:"
+ assert schema.raise_on_failure is False
diff --git a/tests/nv_ingest/schemas/test_image_extrator_schema.py b/tests/nv_ingest/schemas/test_image_extrator_schema.py
new file mode 100644
index 00000000..bf1ea2d6
--- /dev/null
+++ b/tests/nv_ingest/schemas/test_image_extrator_schema.py
@@ -0,0 +1,119 @@
+import pytest
+from pydantic import ValidationError
+
+from nv_ingest.schemas.image_extractor_schema import ImageConfigSchema, ImageExtractorSchema
+
+
+def test_image_config_schema_valid():
+ # Test valid data with both gRPC and HTTP endpoints
+ config = ImageConfigSchema(
+ auth_token="token123",
+ yolox_endpoints=("grpc_service_url", "http_service_url"),
+ yolox_infer_protocol="http"
+ )
+ assert config.auth_token == "token123"
+ assert config.yolox_endpoints == ("grpc_service_url", "http_service_url")
+ assert config.yolox_infer_protocol == "http"
+
+
+def test_image_config_schema_valid_single_service():
+ # Test valid data with only gRPC service
+ config = ImageConfigSchema(
+ yolox_endpoints=("grpc_service_url", None),
+ )
+ assert config.yolox_endpoints == ("grpc_service_url", None)
+ assert config.yolox_infer_protocol == "grpc"
+
+ # Test valid data with only HTTP service
+ config = ImageConfigSchema(
+ yolox_endpoints=(None, "http_service_url"),
+ )
+ assert config.yolox_endpoints == (None, "http_service_url")
+ assert config.yolox_infer_protocol == "http"
+
+
+def test_image_config_schema_invalid_both_services_empty():
+ # Test invalid data with both gRPC and HTTP services empty
+ with pytest.raises(ValidationError) as exc_info:
+ ImageConfigSchema(yolox_endpoints=(None, None))
+ errors = exc_info.value.errors()
+ assert any("Both gRPC and HTTP services cannot be empty" in error['msg'] for error in errors)
+
+
+def test_image_config_schema_empty_service_strings():
+ # Test services that are empty strings or whitespace
+ config = ImageConfigSchema(
+ yolox_endpoints=(" ", "http_service_url")
+ )
+ assert config.yolox_endpoints == (None, "http_service_url") # Cleaned empty strings are None
+
+
+def test_image_config_schema_missing_infer_protocol():
+ # Test infer_protocol default setting based on available service
+ config = ImageConfigSchema(
+ yolox_endpoints=("grpc_service_url", None)
+ )
+ assert config.yolox_infer_protocol == "grpc"
+
+
+def test_image_config_schema_extra_field():
+ # Test extra fields raise a validation error
+ with pytest.raises(ValidationError):
+ ImageConfigSchema(
+ auth_token="token123",
+ yolox_endpoints=("grpc_service_url", "http_service_url"),
+ extra_field="should_not_be_allowed"
+ )
+
+
+def test_image_extractor_schema_valid():
+ # Test valid data for ImageExtractorSchema with nested ImageConfigSchema
+ config = ImageExtractorSchema(
+ max_queue_size=10,
+ n_workers=4,
+ raise_on_failure=True,
+ image_extraction_config=ImageConfigSchema(
+ auth_token="token123",
+ yolox_endpoints=("grpc_service_url", "http_service_url"),
+ yolox_infer_protocol="http"
+ )
+ )
+ assert config.max_queue_size == 10
+ assert config.n_workers == 4
+ assert config.raise_on_failure is True
+ assert config.image_extraction_config.auth_token == "token123"
+
+
+def test_image_extractor_schema_defaults():
+ # Test default values for optional fields
+ config = ImageExtractorSchema()
+ assert config.max_queue_size == 1
+ assert config.n_workers == 16
+ assert config.raise_on_failure is False
+ assert config.image_extraction_config is None
+
+
+def test_image_extractor_schema_invalid_max_queue_size():
+ # Test invalid type for max_queue_size
+ with pytest.raises(ValidationError) as exc_info:
+ ImageExtractorSchema(max_queue_size="invalid_type")
+ errors = exc_info.value.errors()
+ assert any(error['loc'] == ('max_queue_size',) and error['type'] == 'type_error.integer' for error in errors)
+
+
+def test_image_extractor_schema_invalid_n_workers():
+ # Test invalid type for n_workers
+ with pytest.raises(ValidationError) as exc_info:
+ ImageExtractorSchema(n_workers="invalid_type")
+ errors = exc_info.value.errors()
+ assert any(error['loc'] == ('n_workers',) and error['type'] == 'type_error.integer' for error in errors)
+
+
+def test_image_extractor_schema_invalid_nested_config():
+ # Test invalid nested image_extraction_config
+ with pytest.raises(ValidationError) as exc_info:
+ ImageExtractorSchema(
+ image_extraction_config={"auth_token": "token123", "yolox_endpoints": (None, None)}
+ )
+ errors = exc_info.value.errors()
+ assert any("Both gRPC and HTTP services cannot be empty" in error['msg'] for error in errors)
diff --git a/tests/nv_ingest/stages/__init__.py b/tests/nv_ingest/stages/__init__.py
new file mode 100644
index 00000000..6a35633c
--- /dev/null
+++ b/tests/nv_ingest/stages/__init__.py
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
diff --git a/tests/nv_ingest/stages/nims/__init__.py b/tests/nv_ingest/stages/nims/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/stages/nims/test_chart_extraction.py b/tests/nv_ingest/stages/nims/test_chart_extraction.py
similarity index 100%
rename from tests/stages/nims/test_chart_extraction.py
rename to tests/nv_ingest/stages/nims/test_chart_extraction.py
diff --git a/tests/stages/nims/test_table_extraction.py b/tests/nv_ingest/stages/nims/test_table_extraction.py
similarity index 100%
rename from tests/stages/nims/test_table_extraction.py
rename to tests/nv_ingest/stages/nims/test_table_extraction.py
diff --git a/tests/nv_ingest/stages/test_image_extractor_stage.py b/tests/nv_ingest/stages/test_image_extractor_stage.py
new file mode 100644
index 00000000..6965ec48
--- /dev/null
+++ b/tests/nv_ingest/stages/test_image_extractor_stage.py
@@ -0,0 +1,229 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+import base64
+
+import pandas as pd
+import pytest
+from unittest.mock import MagicMock, patch
+
+from nv_ingest.stages.extractors.image_extractor_stage import decode_and_extract
+from nv_ingest.stages.extractors.image_extractor_stage import process_image
+
+MODULE_UNDER_TEST = 'nv_ingest.stages.extractors.image_extractor_stage'
+
+
+# Define the test function using pytest
+@patch(f'{MODULE_UNDER_TEST}.decode_and_extract')
+def test_process_image_single_row(mock_decode_and_extract):
+ mock_decode_and_extract.return_value = [
+ {"document_type": "type1", "metadata": {"key": "value"}, "uuid": "1234"}
+ ]
+
+ input_df = pd.DataFrame({
+ "source_id": [1],
+ "content": ["base64encodedstring"]
+ })
+
+ task_props = {"method": "some_method"}
+ validated_config = MagicMock()
+ trace_info = {}
+
+ processed_df, trace_info_output = process_image(input_df, task_props, validated_config, trace_info)
+
+ assert len(processed_df) == 1
+ assert "document_type" in processed_df.columns
+ assert "metadata" in processed_df.columns
+ assert "uuid" in processed_df.columns
+ assert processed_df.iloc[0]["document_type"] == "type1"
+ assert processed_df.iloc[0]["metadata"] == {"key": "value"}
+ assert processed_df.iloc[0]["uuid"] == "1234"
+ assert trace_info_output["trace_info"] == trace_info
+
+
+@patch(f'{MODULE_UNDER_TEST}.decode_and_extract')
+def test_process_image_empty_dataframe(mock_decode_and_extract):
+ mock_decode_and_extract.return_value = []
+
+ input_df = pd.DataFrame(columns=["source_id", "content"])
+ task_props = {"method": "some_method"}
+ validated_config = MagicMock()
+ trace_info = {}
+
+ processed_df, trace_info_output = process_image(input_df, task_props, validated_config, trace_info)
+
+ assert processed_df.empty
+ assert "document_type" in processed_df.columns
+ assert "metadata" in processed_df.columns
+ assert "uuid" in processed_df.columns
+ assert trace_info_output["trace_info"] == trace_info
+
+
+@patch(f'{MODULE_UNDER_TEST}.decode_and_extract')
+def test_process_image_multiple_rows(mock_decode_and_extract):
+ mock_decode_and_extract.side_effect = [
+ [{"document_type": "type1", "metadata": {"key": "value1"}, "uuid": "1234"}],
+ [{"document_type": "type2", "metadata": {"key": "value2"}, "uuid": "5678"}]
+ ]
+
+ input_df = pd.DataFrame({
+ "source_id": [1, 2],
+ "content": ["base64encodedstring1", "base64encodedstring2"]
+ })
+
+ task_props = {"method": "some_method"}
+ validated_config = MagicMock()
+ trace_info = {}
+
+ processed_df, trace_info_output = process_image(input_df, task_props, validated_config, trace_info)
+
+ assert len(processed_df) == 2
+ assert processed_df.iloc[0]["document_type"] == "type1"
+ assert processed_df.iloc[0]["metadata"] == {"key": "value1"}
+ assert processed_df.iloc[0]["uuid"] == "1234"
+ assert processed_df.iloc[1]["document_type"] == "type2"
+ assert processed_df.iloc[1]["metadata"] == {"key": "value2"}
+ assert processed_df.iloc[1]["uuid"] == "5678"
+ assert trace_info_output["trace_info"] == trace_info
+
+
+@patch(f'{MODULE_UNDER_TEST}.decode_and_extract')
+def test_process_image_with_exception(mock_decode_and_extract):
+ mock_decode_and_extract.side_effect = Exception("Decoding error")
+
+ input_df = pd.DataFrame({
+ "source_id": [1],
+ "content": ["base64encodedstring"]
+ })
+
+ task_props = {"method": "some_method"}
+ validated_config = MagicMock()
+ trace_info = {}
+
+ with pytest.raises(Exception) as excinfo:
+ process_image(input_df, task_props, validated_config, trace_info)
+
+ assert "Decoding error" in str(excinfo.value)
+
+
+@patch(f'{MODULE_UNDER_TEST}.image_helpers')
+def test_decode_and_extract_valid_method(mock_image_helpers):
+ # Mock the extraction function inside image_helpers
+ mock_func = MagicMock(return_value="extracted_data")
+ mock_image_helpers.image = mock_func # Default extraction method
+
+ # Sample inputs as a pandas Series (row)
+ base64_content = base64.b64encode(b"dummy_image_data").decode('utf-8')
+ base64_row = pd.Series({
+ "content": base64_content,
+ "document_type": "image",
+ "source_id": 1
+ })
+ task_props = {"method": "image", "params": {}}
+ validated_config = MagicMock()
+ trace_info = []
+
+ # Call the function
+ result = decode_and_extract(base64_row, task_props, validated_config, default="image", trace_info=trace_info)
+
+ # Assert that the mocked function was called correctly
+ assert result == "extracted_data"
+ mock_func.assert_called_once()
+
+
+@patch(f'{MODULE_UNDER_TEST}.image_helpers')
+def test_decode_and_extract_missing_content_key(mock_image_helpers):
+ # Sample inputs with missing 'content' key as a pandas Series (row)
+ base64_row = pd.Series({
+ "document_type": "image",
+ "source_id": 1
+ })
+ task_props = {"method": "image", "params": {}}
+ validated_config = MagicMock()
+ trace_info = []
+
+ # Expecting a KeyError
+ with pytest.raises(KeyError):
+ decode_and_extract(base64_row, task_props, validated_config, trace_info=trace_info)
+
+
+@patch(f'{MODULE_UNDER_TEST}.image_helpers')
+def test_decode_and_extract_fallback_to_default_method(mock_image_helpers):
+ # Mock only the default method; other methods will appear as non-existent
+ mock_default_func = MagicMock(return_value="default_extracted_data")
+ setattr(mock_image_helpers, 'default', mock_default_func)
+
+ # Ensure that non_existing_method does not exist on mock_image_helpers
+ if hasattr(mock_image_helpers, "non_existing_method"):
+ delattr(mock_image_helpers, "non_existing_method")
+
+ # Input with a non-existing extraction method as a pandas Series (row)
+ base64_content = base64.b64encode(b"dummy_image_data").decode('utf-8')
+ base64_row = pd.Series({
+ "content": base64_content,
+ "document_type": "image",
+ "source_id": 1
+ })
+ task_props = {"method": "non_existing_method", "params": {}}
+ validated_config = MagicMock()
+ trace_info = []
+
+ # Call the function
+ result = decode_and_extract(base64_row, task_props, validated_config, default="default", trace_info=trace_info)
+
+ # Assert that the default function was called instead of the missing one
+ assert result == "default_extracted_data"
+ mock_default_func.assert_called_once()
+
+
+@patch(f'{MODULE_UNDER_TEST}.image_helpers')
+def test_decode_and_extract_with_trace_info(mock_image_helpers):
+ # Mock the extraction function with trace_info usage
+ mock_func = MagicMock(return_value="extracted_data_with_trace")
+ mock_image_helpers.image = mock_func # Default extraction method
+
+ # Sample inputs with trace_info as a pandas Series (row)
+ base64_content = base64.b64encode(b"dummy_image_data").decode('utf-8')
+ base64_row = pd.Series({
+ "content": base64_content,
+ "document_type": "image",
+ "source_id": 1
+ })
+ task_props = {"method": "image", "params": {}}
+ validated_config = MagicMock()
+ trace_info = [{"some": "trace_info"}]
+
+ # Call the function
+ result = decode_and_extract(base64_row, task_props, validated_config, trace_info=trace_info)
+
+ # Assert that the mocked function was called with trace_info in params
+ assert result == "extracted_data_with_trace"
+ mock_func.assert_called_once()
+ _, _, kwargs = mock_func.mock_calls[0]
+ assert "trace_info" in kwargs
+ assert kwargs["trace_info"] == trace_info
+
+
+@patch(f'{MODULE_UNDER_TEST}.image_helpers')
+def test_decode_and_extract_handles_exception_in_extraction(mock_image_helpers):
+ # Mock the extraction function (using a valid method) to raise an exception
+ mock_func = MagicMock(side_effect=Exception("Extraction error"))
+ mock_image_helpers.image = mock_func # Use the default method or a valid method
+
+ # Sample inputs as a pandas Series (row)
+ base64_content = base64.b64encode(b"dummy_image_data").decode('utf-8')
+ base64_row = pd.Series({
+ "content": base64_content,
+ "document_type": "image",
+ "source_id": 1
+ })
+ task_props = {"method": "image", "params": {}} # Use a valid method name
+ validated_config = MagicMock()
+ trace_info = []
+
+ # Expecting an exception during extraction
+ with pytest.raises(Exception) as excinfo:
+ decode_and_extract(base64_row, task_props, validated_config, trace_info=trace_info)
+
+ # Verify the exception message
+ assert str(excinfo.value) == "Extraction error"
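These stage tests pin down the dispatch contract of `decode_and_extract`: the method named in `task_props` is looked up as an attribute of the patched `image_helpers` module, a missing method falls back to the `default` argument, a missing `content` key surfaces as a `KeyError`, and `trace_info` is forwarded as a keyword argument. A minimal sketch of that contract follows; the stand-in `image_helpers` namespace and the function body are assumptions, not the shipped implementation:

```python
# A sketch of the dispatch behavior the tests above exercise; the real
# decode_and_extract in image_extractor_stage.py handles more cases.
import base64
import io
import types

import pandas as pd

# Stand-in for the helpers module the stage dispatches into (the tests patch
# this attribute, so its real import path is not shown in this diff).
image_helpers = types.SimpleNamespace(
    image=lambda image_stream, **kwargs: "extracted_data",
)


def decode_and_extract_sketch(base64_row: pd.Series, task_props: dict,
                              validated_config=None, default: str = "image",
                              trace_info=None):
    base64_content = base64_row["content"]  # missing 'content' raises KeyError
    extract_method = task_props.get("method", default)
    extract_params = dict(task_props.get("params", {}))
    extract_params["trace_info"] = trace_info  # forwarded as a keyword argument

    # Fall back to the default helper when the named method does not exist.
    if not hasattr(image_helpers, extract_method):
        extract_method = default
    func = getattr(image_helpers, extract_method)

    image_stream = io.BytesIO(base64.b64decode(base64_content))
    return func(image_stream, **extract_params)


row = pd.Series({"content": base64.b64encode(b"dummy").decode("utf-8")})
print(decode_and_extract_sketch(row, {"method": "image", "params": {}}))  # extracted_data
```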
diff --git a/tests/nv_ingest/stages/transforms/__init__.py b/tests/nv_ingest/stages/transforms/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/nv_ingest/stages/transforms/test_image_caption_extraction.py b/tests/nv_ingest/stages/transforms/test_image_caption_extraction.py
new file mode 100644
index 00000000..5b946566
--- /dev/null
+++ b/tests/nv_ingest/stages/transforms/test_image_caption_extraction.py
@@ -0,0 +1,312 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+import base64
+import io
+
+import pandas as pd
+import pytest
+import requests
+from PIL import Image
+from unittest.mock import MagicMock, patch
+
+from nv_ingest.schemas.metadata_schema import ContentTypeEnum
+from nv_ingest.stages.transforms.image_caption_extraction import _generate_captions
+from nv_ingest.stages.transforms.image_caption_extraction import _prepare_dataframes_mod
+from nv_ingest.stages.transforms.image_caption_extraction import caption_extract_stage
+
+MODULE_UNDER_TEST = 'nv_ingest.stages.transforms.image_caption_extraction'
+
+
+def generate_base64_png_image() -> str:
+ """Helper function to generate a base64-encoded PNG image."""
+ img = Image.new("RGB", (10, 10), color="blue") # Create a simple blue image
+ buffered = io.BytesIO()
+ img.save(buffered, format="PNG")
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+
+def test_prepare_dataframes_empty_dataframe():
+ # Test with an empty DataFrame
+ df = pd.DataFrame()
+
+ df_out, df_matched, bool_index = _prepare_dataframes_mod(df)
+
+ assert df_out.equals(df)
+ assert df_matched.empty
+ assert bool_index.empty
+ assert bool_index.dtype == bool
+
+
+def test_prepare_dataframes_missing_document_type_column():
+ # Test with a DataFrame missing the 'document_type' column
+ df = pd.DataFrame({
+ "other_column": [1, 2, 3]
+ })
+
+ df_out, df_matched, bool_index = _prepare_dataframes_mod(df)
+
+ assert df_out.equals(df)
+ assert df_matched.empty
+ assert bool_index.empty
+ assert bool_index.dtype == bool
+
+
+def test_prepare_dataframes_no_matches():
+ # Test with a DataFrame where no 'document_type' matches ContentTypeEnum.IMAGE
+ df = pd.DataFrame({
+ "document_type": [ContentTypeEnum.TEXT, ContentTypeEnum.STRUCTURED, ContentTypeEnum.UNSTRUCTURED]
+ })
+
+ df_out, df_matched, bool_index = _prepare_dataframes_mod(df)
+
+ assert df_out.equals(df)
+ assert df_matched.empty
+ assert bool_index.equals(pd.Series([False, False, False]))
+ assert bool_index.dtype == bool
+
+
+def test_prepare_dataframes_partial_matches():
+ # Test with a DataFrame where some rows match ContentTypeEnum.IMAGE
+ df = pd.DataFrame({
+ "document_type": [ContentTypeEnum.IMAGE, ContentTypeEnum.TEXT, ContentTypeEnum.IMAGE]
+ })
+
+ df_out, df_matched, bool_index = _prepare_dataframes_mod(df)
+
+ assert df_out.equals(df)
+ assert not df_matched.empty
+ assert df_matched.equals(df[df["document_type"] == ContentTypeEnum.IMAGE])
+ assert bool_index.equals(pd.Series([True, False, True]))
+ assert bool_index.dtype == bool
+
+
+def test_prepare_dataframes_all_matches():
+ # Test with a DataFrame where all rows match ContentTypeEnum.IMAGE
+ df = pd.DataFrame({
+ "document_type": [ContentTypeEnum.IMAGE, ContentTypeEnum.IMAGE, ContentTypeEnum.IMAGE]
+ })
+
+ df_out, df_matched, bool_index = _prepare_dataframes_mod(df)
+
+ assert df_out.equals(df)
+ assert df_matched.equals(df)
+ assert bool_index.equals(pd.Series([True, True, True]))
+ assert bool_index.dtype == bool
+
+
+@patch(f'{MODULE_UNDER_TEST}._generate_captions')
+def test_caption_extract_no_image_content(mock_generate_captions):
+ # DataFrame with no image content
+ df = pd.DataFrame({
+ "metadata": [{"content_metadata": {"type": "text"}}, {"content_metadata": {"type": "pdf"}}]
+ })
+ task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"}
+ validated_config = MagicMock()
+ trace_info = {}
+
+ # Call the function
+ result_df = caption_extract_stage(df, task_props, validated_config, trace_info)
+
+ # Check that _generate_captions was not called and df is unchanged
+ mock_generate_captions.assert_not_called()
+ assert result_df.equals(df)
+
+
+@patch(f'{MODULE_UNDER_TEST}._generate_captions')
+def test_caption_extract_with_image_content(mock_generate_captions):
+ # Mock caption generation
+ mock_generate_captions.return_value = "A description of the image."
+
+ # DataFrame with image content
+ df = pd.DataFrame({
+ "metadata": [{"content_metadata": {"type": "image"}, "content": "base64_encoded_image_data"}]
+ })
+ task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"}
+ validated_config = MagicMock()
+ trace_info = {}
+
+ # Call the function
+ result_df = caption_extract_stage(df, task_props, validated_config, trace_info)
+
+ # Check that _generate_captions was called once
+ mock_generate_captions.assert_called_once_with("base64_encoded_image_data", "Describe the image", "test_api_key",
+ "https://api.example.com")
+
+ # Verify that the caption was added to image_metadata
+ assert result_df.loc[0, "metadata"]["image_metadata"]["caption"] == "A description of the image."
+
+
+@patch(f'{MODULE_UNDER_TEST}._generate_captions')
+def test_caption_extract_mixed_content(mock_generate_captions):
+ # Mock caption generation
+ mock_generate_captions.return_value = "A description of the image."
+
+ # DataFrame with mixed content types
+ df = pd.DataFrame({
+ "metadata": [
+ {"content_metadata": {"type": "image"}, "content": "image_data_1"},
+ {"content_metadata": {"type": "text"}, "content": "text_data"},
+ {"content_metadata": {"type": "image"}, "content": "image_data_2"}
+ ]
+ })
+ task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"}
+ validated_config = MagicMock()
+ trace_info = {}
+
+ # Call the function
+ result_df = caption_extract_stage(df, task_props, validated_config, trace_info)
+
+ # Check that _generate_captions was called twice for images only
+ assert mock_generate_captions.call_count == 2
+ mock_generate_captions.assert_any_call("image_data_1", "Describe the image", "test_api_key",
+ "https://api.example.com")
+ mock_generate_captions.assert_any_call("image_data_2", "Describe the image", "test_api_key",
+ "https://api.example.com")
+
+ # Verify that captions were added only for image rows
+ assert result_df.loc[0, "metadata"]["image_metadata"]["caption"] == "A description of the image."
+ assert "caption" not in result_df.loc[1, "metadata"].get("image_metadata", {})
+ assert result_df.loc[2, "metadata"]["image_metadata"]["caption"] == "A description of the image."
+
+
+@patch(f'{MODULE_UNDER_TEST}._generate_captions')
+def test_caption_extract_empty_dataframe(mock_generate_captions):
+ # Empty DataFrame
+ df = pd.DataFrame(columns=["metadata"])
+ task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"}
+ validated_config = MagicMock()
+ trace_info = {}
+
+ # Call the function
+ result_df = caption_extract_stage(df, task_props, validated_config, trace_info)
+
+ # Check that _generate_captions was not called and df is still empty
+ mock_generate_captions.assert_not_called()
+ assert result_df.empty
+
+
+@patch(f'{MODULE_UNDER_TEST}._generate_captions')
+def test_caption_extract_malformed_metadata(mock_generate_captions):
+ # Mock caption generation
+ mock_generate_captions.return_value = "A description of the image."
+
+ # DataFrame with malformed metadata (missing 'content' key in one row)
+ df = pd.DataFrame({
+ "metadata": [{"unexpected_key": "value"}, {"content_metadata": {"type": "image"}}]
+ })
+ task_props = {"api_key": "test_api_key", "prompt": "Describe the image", "endpoint_url": "https://api.example.com"}
+ validated_config = MagicMock()
+ trace_info = {}
+
+ # Expecting KeyError for missing 'content' in the second row
+ with pytest.raises(KeyError, match="'content'"):
+ caption_extract_stage(df, task_props, validated_config, trace_info)
+
+
+@patch(f"{MODULE_UNDER_TEST}.requests.post")
+def test_generate_captions_successful(mock_post):
+ # Mock the successful API response
+ mock_response = MagicMock()
+ mock_response.raise_for_status = MagicMock()
+ mock_response.json.return_value = {
+ "choices": [
+ {"message": {"content": "A beautiful sunset over the mountains."}}
+ ]
+ }
+ mock_post.return_value = mock_response
+
+ # Parameters
+ base64_image = generate_base64_png_image()
+ prompt = "Describe the image"
+ api_key = "test_api_key"
+ endpoint_url = "https://api.example.com"
+
+ # Call the function
+ result = _generate_captions(base64_image, prompt, api_key, endpoint_url)
+
+ # Verify that the correct caption was returned
+ assert result == "A beautiful sunset over the mountains."
+ mock_post.assert_called_once_with(
+ endpoint_url,
+ headers={"Authorization": f"Bearer {api_key}", "Accept": "application/json"},
+ json={
+ "model": 'meta/llama-3.2-90b-vision-instruct',
+ "messages": [
+ {
+ "role": "user",
+ "content": f'{prompt}
'
+ }
+ ],
+ "max_tokens": 512,
+ "temperature": 1.00,
+ "top_p": 1.00,
+ "stream": False
+ }
+ )
+
+
+@patch(f"{MODULE_UNDER_TEST}.requests.post")
+def test_generate_captions_api_error(mock_post):
+ # Mock a 500 Internal Server Error response
+ mock_response = MagicMock()
+ mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError("500 Server Error")
+ mock_post.return_value = mock_response
+
+ # Parameters
+ base64_image = generate_base64_png_image()
+ prompt = "Describe the image"
+ api_key = "test_api_key"
+ endpoint_url = "https://api.example.com"
+
+ # Expect an exception due to the server error
+ with pytest.raises(requests.exceptions.RequestException, match="500 Server Error"):
+ _generate_captions(base64_image, prompt, api_key, endpoint_url)
+
+
+@patch(f"{MODULE_UNDER_TEST}.requests.post")
+def test_generate_captions_malformed_json(mock_post):
+ # Mock a response with an unexpected JSON structure
+ mock_response = MagicMock()
+ mock_response.raise_for_status = MagicMock()
+ mock_response.json.return_value = {"unexpected_key": "unexpected_value"}
+ mock_post.return_value = mock_response
+
+ # Parameters
+ base64_image = generate_base64_png_image()
+ prompt = "Describe the image"
+ api_key = "test_api_key"
+ endpoint_url = "https://api.example.com"
+
+ # Call the function
+ result = _generate_captions(base64_image, prompt, api_key, endpoint_url)
+
+ # Verify fallback response when JSON is malformed
+ assert result == "No caption returned"
+
+
+@patch(f"{MODULE_UNDER_TEST}.requests.post")
+def test_generate_captions_empty_caption_content(mock_post):
+ # Mock a response with empty caption content
+ mock_response = MagicMock()
+ mock_response.raise_for_status = MagicMock()
+ mock_response.json.return_value = {
+ "choices": [
+ {"message": {"content": ""}}
+ ]
+ }
+ mock_post.return_value = mock_response
+
+ # Parameters
+ base64_image = generate_base64_png_image()
+ prompt = "Describe the image"
+ api_key = "test_api_key"
+ endpoint_url = "https://api.example.com"
+
+ # Call the function
+ result = _generate_captions(base64_image, prompt, api_key, endpoint_url)
+
+ # Verify that the fallback response is returned
+ assert result == ""
diff --git a/tests/nv_ingest/util/image_processing/test_transforms.py b/tests/nv_ingest/util/image_processing/test_transforms.py
index ad86c191..55b51585 100644
--- a/tests/nv_ingest/util/image_processing/test_transforms.py
+++ b/tests/nv_ingest/util/image_processing/test_transforms.py
@@ -2,14 +2,18 @@
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
-import pytest
-import numpy as np
-from PIL import Image
import base64
+import io
+import numpy as np
+import pytest
+
from io import BytesIO
+from PIL import Image
+from typing import Tuple
from unittest import mock
-from nv_ingest.util.image_processing.transforms import numpy_to_base64, base64_to_numpy, check_numpy_image_size
+from nv_ingest.util.image_processing.transforms import (
+ base64_to_numpy,
+ check_numpy_image_size,
+ ensure_base64_is_png,
+ numpy_to_base64,
+ scale_image_to_encoding_size,
+)
# Helper function to create a base64-encoded string from an image
@@ -107,3 +111,125 @@ def test_check_numpy_image_size_invalid_dimensions():
img = np.zeros((100,), dtype=np.uint8) # 1D array
with pytest.raises(ValueError, match="The input array does not have sufficient dimensions for an image."):
check_numpy_image_size(img, 50, 50)
+
+
+def generate_base64_image(size: Tuple[int, int]) -> str:
+ """Helper function to generate a base64-encoded PNG image of a specific size."""
+ img = Image.new("RGB", size, color="blue") # Create a simple blue image
+ buffered = io.BytesIO()
+ img.save(buffered, format="PNG")
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+
+def generate_base64_image_with_format(image_format: str = 'PNG', size: Tuple[int, int] = (100, 100)) -> str:
+ """Helper function to generate a base64-encoded image of a specified format and size."""
+ img = Image.new("RGB", size, color="blue") # Simple blue image
+ buffered = io.BytesIO()
+ img.save(buffered, format=image_format)
+ return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+
+def test_resize_image_within_size_limit():
+ # Generate a base64 image within the size limit
+ base64_image = generate_base64_image((100, 100)) # Small image
+ max_base64_size = len(base64_image) + 10 # Set limit slightly above image size
+
+ result = scale_image_to_encoding_size(base64_image, max_base64_size)
+ assert result == base64_image # Should return unchanged
+
+
+def test_resize_image_one_resize_needed():
+ # Generate a large base64 image that requires resizing
+ base64_image = generate_base64_image((500, 500))
+ max_base64_size = len(base64_image) - 1000 # Set limit slightly below current size
+
+ result = scale_image_to_encoding_size(base64_image, max_base64_size)
+ assert len(result) <= max_base64_size # Should be resized within limit
+
+
+def test_resize_image_multiple_resizes_needed():
+ # Generate a very large base64 image that will require multiple reductions
+ base64_image = generate_base64_image((1000, 1000))
+ max_base64_size = len(base64_image) // 2 # Set limit well below current size
+
+ result = scale_image_to_encoding_size(base64_image, max_base64_size)
+ assert len(result) <= max_base64_size # Final size should be within limit
+
+
+def test_resize_image_cannot_be_resized_below_limit():
+ # Generate a small base64 image where further resizing would be impractical
+ base64_image = generate_base64_image((10, 10))
+ max_base64_size = 1 # Unreachable size limit
+
+ with pytest.raises(ValueError, match="height and width must be > 0"):
+ scale_image_to_encoding_size(base64_image, max_base64_size)
+
+
+def test_resize_image_edge_case_minimal_reduction():
+ # Generate an image just above the size limit
+ base64_image = generate_base64_image((500, 500))
+ max_base64_size = len(base64_image) - 50 # Just a slight reduction needed
+
+ result = scale_image_to_encoding_size(base64_image, max_base64_size)
+ assert len(result) <= max_base64_size # Should achieve minimal reduction within limit
+
+
+def test_resize_image_with_invalid_input():
+ # Provide non-image data as input
+ non_image_base64 = base64.b64encode(b"This is not an image").decode("utf-8")
+
+ with pytest.raises(Exception):
+ scale_image_to_encoding_size(non_image_base64)
+
+
+def test_ensure_base64_is_png_already_png():
+ # Generate a base64-encoded PNG image
+ base64_image = generate_base64_image_with_format("PNG")
+
+ result = ensure_base64_is_png(base64_image)
+ assert result == base64_image # Should be unchanged
+
+
+def test_ensure_base64_is_png_jpeg_to_png_conversion():
+ # Generate a base64-encoded JPEG image
+ base64_image = generate_base64_image_with_format("JPEG")
+
+ result = ensure_base64_is_png(base64_image)
+
+ # Decode the result and check format
+ image_data = base64.b64decode(result)
+ image = Image.open(io.BytesIO(image_data))
+ assert image.format == "PNG" # Should be converted to PNG
+
+
+def test_ensure_base64_is_png_invalid_base64():
+ # Provide invalid base64 input
+ invalid_base64 = "This is not base64 encoded data"
+
+ result = ensure_base64_is_png(invalid_base64)
+ assert result is None # Should return None for invalid input
+
+
+def test_ensure_base64_is_png_non_image_base64_data():
+ # Provide valid base64 data that isn't an image
+ non_image_base64 = base64.b64encode(b"This is not an image").decode("utf-8")
+
+ result = ensure_base64_is_png(non_image_base64)
+ assert result is None # Should return None for non-image data
+
+
+def test_ensure_base64_is_png_unsupported_format():
+ # Generate an image in a rare format and base64 encode it
+ img = Image.new("RGB", (100, 100), color="blue")
+ buffered = io.BytesIO()
+ img.save(buffered, format="BMP") # Use an uncommon format like BMP
+ base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+ result = ensure_base64_is_png(base64_image)
+ # Decode result to verify conversion
+ if result:
+ image_data = base64.b64decode(result)
+ image = Image.open(io.BytesIO(image_data))
+ assert image.format == "PNG" # Should be converted to PNG if supported
+ else:
+ assert result is None # If unsupported, result should be None
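These transform tests characterize `scale_image_to_encoding_size` as a loop that re-encodes the image at progressively smaller dimensions until the base64 payload fits the byte budget, with PIL's "height and width must be > 0" error surfacing once the dimensions bottom out, and `ensure_base64_is_png` as returning `None` for anything PIL cannot decode. A rough sketch of the resize loop under those assumptions; the 0.9 reduction factor and the default budget are guesses, not the shipped values:

```python
# Illustrative only: shrinks by a fixed factor until the base64 payload fits.
import base64
import io

from PIL import Image


def scale_image_to_encoding_size_sketch(base64_image: str,
                                        max_base64_size: int = 180_000) -> str:
    # Decoding up front makes non-image input fail fast, as the tests expect.
    img = Image.open(io.BytesIO(base64.b64decode(base64_image)))
    while len(base64_image) > max_base64_size:
        # Shrink by a fixed factor and re-encode; PIL itself raises
        # "height and width must be > 0" once a dimension reaches zero.
        img = img.resize((int(img.width * 0.9), int(img.height * 0.9)))
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
        base64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return base64_image
```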
diff --git a/tests/nv_ingest_client/primitives/tasks/test_caption.py b/tests/nv_ingest_client/primitives/tasks/test_caption.py
new file mode 100644
index 00000000..f9af7469
--- /dev/null
+++ b/tests/nv_ingest_client/primitives/tasks/test_caption.py
@@ -0,0 +1,139 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+from pydantic import ValidationError
+
+from nv_ingest_client.primitives.tasks.caption import CaptionTaskSchema, CaptionTask
+
+
+# Testing CaptionTaskSchema
+
+def test_valid_schema_initialization():
+ """Test valid initialization of CaptionTaskSchema with all fields."""
+ schema = CaptionTaskSchema(api_key="test_key", endpoint_url="http://example.com", prompt="Generate a caption")
+ assert schema.api_key == "test_key"
+ assert schema.endpoint_url == "http://example.com"
+ assert schema.prompt == "Generate a caption"
+
+
+def test_partial_schema_initialization():
+ """Test valid initialization of CaptionTaskSchema with some fields omitted."""
+ schema = CaptionTaskSchema(api_key="test_key")
+ assert schema.api_key == "test_key"
+ assert schema.endpoint_url is None
+ assert schema.prompt is None
+
+
+def test_empty_schema_initialization():
+ """Test valid initialization of CaptionTaskSchema with no fields."""
+ schema = CaptionTaskSchema()
+ assert schema.api_key is None
+ assert schema.endpoint_url is None
+ assert schema.prompt is None
+
+
+def test_schema_invalid_extra_field():
+ """Test that CaptionTaskSchema raises an error with extra fields."""
+ with pytest.raises(ValidationError) as exc_info:
+ CaptionTaskSchema(api_key="test_key", extra_field="invalid")
+ assert "extra_field" in str(exc_info.value)
+
+
+# Testing CaptionTask
+
+def test_caption_task_initialization():
+ """Test initializing CaptionTask with all fields."""
+ task = CaptionTask(api_key="test_key", endpoint_url="http://example.com", prompt="Generate a caption")
+ assert task._api_key == "test_key"
+ assert task._endpoint_url == "http://example.com"
+ assert task._prompt == "Generate a caption"
+
+
+def test_caption_task_partial_initialization():
+ """Test initializing CaptionTask with some fields omitted."""
+ task = CaptionTask(api_key="test_key")
+ assert task._api_key == "test_key"
+ assert task._endpoint_url is None
+ assert task._prompt is None
+
+
+def test_caption_task_empty_initialization():
+ """Test initializing CaptionTask with no fields."""
+ task = CaptionTask()
+ assert task._api_key is None
+ assert task._endpoint_url is None
+ assert task._prompt is None
+
+
+def test_caption_task_str_representation_all_fields():
+ """Test string representation of CaptionTask with all fields."""
+ task = CaptionTask(api_key="test_key", endpoint_url="http://example.com", prompt="Generate a caption")
+ task_str = str(task)
+ assert "Image Caption Task:" in task_str
+ assert "api_key: [redacted]" in task_str
+ assert "endpoint_url: http://example.com" in task_str
+ assert "prompt: Generate a caption" in task_str
+
+
+def test_caption_task_str_representation_partial_fields():
+ """Test string representation of CaptionTask with partial fields."""
+ task = CaptionTask(api_key="test_key")
+ task_str = str(task)
+ assert "Image Caption Task:" in task_str
+ assert "api_key: [redacted]" in task_str
+ assert "endpoint_url" not in task_str
+ assert "prompt" not in task_str
+
+
+def test_caption_task_to_dict_all_fields():
+ """Test to_dict method of CaptionTask with all fields."""
+ task = CaptionTask(api_key="test_key", endpoint_url="http://example.com", prompt="Generate a caption")
+ task_dict = task.to_dict()
+ assert task_dict == {
+ "type": "caption",
+ "task_properties": {
+ "api_key": "test_key",
+ "endpoint_url": "http://example.com",
+ "prompt": "Generate a caption"
+ }
+ }
+
+
+def test_caption_task_to_dict_partial_fields():
+ """Test to_dict method of CaptionTask with partial fields."""
+ task = CaptionTask(api_key="test_key")
+ task_dict = task.to_dict()
+ assert task_dict == {
+ "type": "caption",
+ "task_properties": {
+ "api_key": "test_key"
+ }
+ }
+
+
+def test_caption_task_to_dict_empty_fields():
+ """Test to_dict method of CaptionTask with no fields."""
+ task = CaptionTask()
+ task_dict = task.to_dict()
+ assert task_dict == {
+ "type": "caption",
+ "task_properties": {}
+ }
+
+
+# Allow running this test module directly; pytest discovers each test,
+# so new tests do not need to be added to a manual call list.
+if __name__ == "__main__":
+ raise SystemExit(pytest.main([__file__]))