diff --git a/src/nv_ingest/util/nim/__init__.py b/src/nv_ingest/util/nim/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/nv_ingest/util/nim/cached.py b/src/nv_ingest/util/nim/cached.py new file mode 100644 index 00000000..299f249c --- /dev/null +++ b/src/nv_ingest/util/nim/cached.py @@ -0,0 +1,274 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import base64 +import io +import logging +import PIL.Image as Image +from typing import Any, Dict, Optional, List + +import numpy as np + +from nv_ingest.util.image_processing.transforms import base64_to_numpy +from nv_ingest.util.nim.helpers import ModelInterface + +logger = logging.getLogger(__name__) + + +class CachedModelInterface(ModelInterface): + """ + An interface for handling inference with a Cached model, supporting both gRPC and HTTP + protocols, including batched input. + """ + + def name(self) -> str: + """ + Get the name of the model interface. + + Returns + ------- + str + The name of the model interface ("Cached"). + """ + return "Cached" + + def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Decode base64-encoded images into NumPy arrays, storing them in `data["image_arrays"]`. + + Parameters + ---------- + data : dict of str -> Any + The input data containing either: + - "base64_image": a single base64-encoded image, or + - "base64_images": a list of base64-encoded images. + + Returns + ------- + dict of str -> Any + The updated data dictionary with decoded image arrays stored in + "image_arrays", where each array has shape (H, W, C). + + Raises + ------ + KeyError + If neither 'base64_image' nor 'base64_images' is provided. + ValueError + If 'base64_images' is provided but is not a list. + """ + if "base64_images" in data: + base64_list = data["base64_images"] + if not isinstance(base64_list, list): + raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.") + data["image_arrays"] = [base64_to_numpy(img) for img in base64_list] + + elif "base64_image" in data: + # Fallback to single image case; wrap it in a list to keep the interface consistent + data["image_arrays"] = [base64_to_numpy(data["base64_image"])] + + else: + raise KeyError("Input data must include 'base64_image' or 'base64_images' with base64-encoded images.") + + return data + + def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any: + """ + Format input data for the specified protocol ("grpc" or "http"), handling batched images. + Additionally, returns batched data that coalesces the original image arrays and their dimensions + in the same order as provided. + + Parameters + ---------- + data : dict of str -> Any + The input data dictionary, expected to contain "image_arrays" (a list of np.ndarray). + protocol : str + The protocol to use, "grpc" or "http". + max_batch_size : int + The maximum number of images per batch. + + Returns + ------- + tuple + A tuple (formatted_batches, formatted_batch_data) where: + - For gRPC: formatted_batches is a list of NumPy arrays, each of shape (B, H, W, C) + with B <= max_batch_size. + - For HTTP: formatted_batches is a list of JSON-serializable dict payloads. + - In both cases, formatted_batch_data is a list of dicts with the keys: + "image_arrays": the list of original np.ndarray images for that batch, and + "image_dims": a list of (height, width) tuples for each image in the batch. 
+ + Raises + ------ + KeyError + If "image_arrays" is missing in the data dictionary. + ValueError + If the protocol is invalid, or if no valid images are found. + """ + if "image_arrays" not in data: + raise KeyError("Expected 'image_arrays' in data. Make sure prepare_data_for_inference was called.") + + image_arrays = data["image_arrays"] + # Compute dimensions for each image. + image_dims = [(img.shape[0], img.shape[1]) for img in image_arrays] + + # Helper: chunk a list into sublists of length up to chunk_size. + def chunk_list(lst: list, chunk_size: int) -> List[list]: + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + if protocol == "grpc": + logger.debug("Formatting input for gRPC Cached model (batched).") + batched_images = [] + for arr in image_arrays: + # Expand from (H, W, C) to (1, H, W, C) if needed + if arr.ndim == 3: + arr = np.expand_dims(arr, axis=0) + batched_images.append(arr.astype(np.float32)) + + if not batched_images: + raise ValueError("No valid images found for gRPC formatting.") + + # Chunk the processed images, original arrays, and dimensions. + batched_image_chunks = chunk_list(batched_images, max_batch_size) + orig_chunks = chunk_list(image_arrays, max_batch_size) + dims_chunks = chunk_list(image_dims, max_batch_size) + + batched_inputs = [] + formatted_batch_data = [] + for proc_chunk, orig_chunk, dims_chunk in zip(batched_image_chunks, orig_chunks, dims_chunks): + # Concatenate along the batch dimension => shape (B, H, W, C) + batched_input = np.concatenate(proc_chunk, axis=0) + batched_inputs.append(batched_input) + formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk}) + return batched_inputs, formatted_batch_data + + elif protocol == "http": + logger.debug("Formatting input for HTTP Cached model (batched).") + content_list: List[Dict[str, Any]] = [] + for arr in image_arrays: + # Convert to uint8 if needed, then to PIL Image and base64-encode it. + if arr.dtype != np.uint8: + arr = (arr * 255).astype(np.uint8) + image_pil = Image.fromarray(arr) + buffered = io.BytesIO() + image_pil.save(buffered, format="PNG") + base64_img = base64.b64encode(buffered.getvalue()).decode("utf-8") + image_item = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_img}"}} + content_list.append(image_item) + + # Chunk the content list, original arrays, and dimensions. + content_chunks = chunk_list(content_list, max_batch_size) + orig_chunks = chunk_list(image_arrays, max_batch_size) + dims_chunks = chunk_list(image_dims, max_batch_size) + + payload_batches = [] + formatted_batch_data = [] + for chunk, orig_chunk, dims_chunk in zip(content_chunks, orig_chunks, dims_chunks): + message = {"content": chunk} + payload = {"messages": [message]} + payload_batches.append(payload) + formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk}) + return payload_batches, formatted_batch_data + + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any: + """ + Parse the output from the Cached model's inference response. + + Parameters + ---------- + response : Any + The raw response from the model inference. + protocol : str + The protocol used ("grpc" or "http"). + data : dict of str -> Any, optional + Additional input data (unused here, but available for consistency). + **kwargs : Any + Additional keyword arguments for future compatibility. 
+ + Returns + ------- + Any + The parsed output data (e.g., list of strings), depending on the protocol. + + Raises + ------ + ValueError + If the protocol is invalid. + RuntimeError + If the HTTP response is not as expected (missing 'data' key). + """ + if protocol == "grpc": + logger.debug("Parsing output from gRPC Cached model (batched).") + parsed: List[str] = [] + # Assume `response` is iterable, each element a list/array of byte strings + for single_output in response: + joined_str = " ".join(o.decode("utf-8") for o in single_output) + parsed.append(joined_str) + return parsed + + elif protocol == "http": + logger.debug("Parsing output from HTTP Cached model (batched).") + if not isinstance(response, dict): + raise RuntimeError("Expected JSON/dict response for HTTP, got something else.") + if "data" not in response or not response["data"]: + raise RuntimeError("Unexpected response format: 'data' key missing or empty.") + + contents: List[str] = [] + for item in response["data"]: + # Each "item" might have a "content" key + content = item.get("content", "") + contents.append(content) + + return contents + + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def process_inference_results(self, output: Any, protocol: str, **kwargs: Any) -> Any: + """ + Process inference results for the Cached model. + + Parameters + ---------- + output : Any + The raw output from the model. + protocol : str + The inference protocol used ("grpc" or "http"). + **kwargs : Any + Additional parameters for post-processing (not used here). + + Returns + ------- + Any + The processed inference results, which here is simply returned as-is. + """ + # For Cached model, we simply return what we parsed (e.g., a list of strings or a single string) + return output + + def _extract_content_from_nim_response(self, json_response: Dict[str, Any]) -> Any: + """ + Extract content from the JSON response of a NIM (HTTP) API request. + + Parameters + ---------- + json_response : dict of str -> Any + The JSON response from the NIM API. + + Returns + ------- + Any + The extracted content from the response. + + Raises + ------ + RuntimeError + If the response format is unexpected (missing 'data' or empty). + """ + if "data" not in json_response or not json_response["data"]: + raise RuntimeError("Unexpected response format: 'data' key is missing or empty.") + + return json_response["data"][0]["content"] diff --git a/src/nv_ingest/util/nim/decorators.py b/src/nv_ingest/util/nim/decorators.py new file mode 100644 index 00000000..869c9e54 --- /dev/null +++ b/src/nv_ingest/util/nim/decorators.py @@ -0,0 +1,52 @@ +import logging +from functools import wraps +from multiprocessing import Lock +from multiprocessing import Manager + +logger = logging.getLogger(__name__) + +# Create a shared manager and lock for thread-safe access +manager = Manager() +global_cache = manager.dict() +lock = Lock() + + +def multiprocessing_cache(max_calls): + """ + A decorator that creates a global cache shared between multiple processes. + The cache is invalidated after `max_calls` number of accesses. + + Args: + max_calls (int): The number of calls after which the cache is cleared. + + Returns: + function: The decorated function with global cache and invalidation logic. 
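+
+    Example (illustrative sketch; `fetch_model_metadata` and its body are placeholders,
+    not part of this patch):
+
+        @multiprocessing_cache(max_calls=100)
+        def fetch_model_metadata(endpoint: str) -> dict:
+            ...  # expensive call whose result is safe to reuse across processes
+
+        Repeated calls with the same arguments return the cached result; once the
+        shared call counter exceeds `max_calls`, the cache is cleared and rebuilt.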
+ """ + + def decorator(func): + call_count = manager.Value("i", 0) # Shared integer for call counting + + @wraps(func) + def wrapper(*args, **kwargs): + key = (func.__name__, args, frozenset(kwargs.items())) + + with lock: + call_count.value += 1 + + if call_count.value > max_calls: + global_cache.clear() + call_count.value = 0 + + if key in global_cache: + return global_cache[key] + + result = func(*args, **kwargs) + + with lock: + global_cache[key] = result + + return result + + return wrapper + + return decorator diff --git a/src/nv_ingest/util/nim/deplot.py b/src/nv_ingest/util/nim/deplot.py new file mode 100644 index 00000000..3a8414d7 --- /dev/null +++ b/src/nv_ingest/util/nim/deplot.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Any, Optional, List + +import numpy as np +import logging + +from nv_ingest.util.image_processing.transforms import base64_to_numpy +from nv_ingest.util.nim.helpers import ModelInterface + +logger = logging.getLogger(__name__) + + +class DeplotModelInterface(ModelInterface): + """ + An interface for handling inference with a Deplot model, supporting both gRPC and HTTP protocols, + now updated to handle multiple base64 images ('base64_images'). + """ + + def name(self) -> str: + """ + Get the name of the model interface. + + Returns + ------- + str + The name of the model interface ("Deplot"). + """ + return "Deplot" + + def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Prepare input data by decoding one or more base64-encoded images into NumPy arrays. + + Parameters + ---------- + data : dict + The input data containing either 'base64_image' (single image) + or 'base64_images' (multiple images). + + Returns + ------- + dict + The updated data dictionary with 'image_arrays': a list of decoded NumPy arrays. + """ + + # Handle a single base64_image or multiple base64_images + if "base64_images" in data: + base64_list = data["base64_images"] + if not isinstance(base64_list, list): + raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.") + image_arrays = [base64_to_numpy(b64) for b64 in base64_list] + + elif "base64_image" in data: + # Fallback for single image + image_arrays = [base64_to_numpy(data["base64_image"])] + else: + raise KeyError("Input data must include 'base64_image' or 'base64_images'.") + + data["image_arrays"] = image_arrays + + return data + + def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any: + """ + Format input data for the specified protocol (gRPC or HTTP) for Deplot. + For HTTP, we now construct multiple messages—one per image batch—along with + corresponding batch data carrying the original image arrays and their dimensions. + + Parameters + ---------- + data : dict of str -> Any + The input data dictionary, expected to contain "image_arrays" (a list of np.ndarray). + protocol : str + The protocol to use, "grpc" or "http". + max_batch_size : int + The maximum number of images per batch. + kwargs : dict + Additional parameters to pass to the payload preparation (for HTTP). + + Returns + ------- + tuple + (formatted_batches, formatted_batch_data) where: + - For gRPC: formatted_batches is a list of NumPy arrays, each of shape (B, H, W, C) + with B <= max_batch_size. + - For HTTP: formatted_batches is a list of JSON-serializable payload dicts. 
+ - In both cases, formatted_batch_data is a list of dicts containing: + "image_arrays": the list of original np.ndarray images for that batch, and + "image_dims": a list of (height, width) tuples for each image in the batch. + + Raises + ------ + KeyError + If "image_arrays" is missing in the data dictionary. + ValueError + If the protocol is invalid, or if no valid images are found. + """ + if "image_arrays" not in data: + raise KeyError("Expected 'image_arrays' in data. Call prepare_data_for_inference first.") + + image_arrays = data["image_arrays"] + # Compute image dimensions from each image array. + image_dims = [(img.shape[0], img.shape[1]) for img in image_arrays] + + # Helper function: chunk a list into sublists of length <= chunk_size. + def chunk_list(lst: list, chunk_size: int) -> List[list]: + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + if protocol == "grpc": + logger.debug("Formatting input for gRPC Deplot model (potentially batched).") + processed = [] + for arr in image_arrays: + # Ensure each image has shape (1, H, W, C) + if arr.ndim == 3: + arr = np.expand_dims(arr, axis=0) + arr = arr.astype(np.float32) + arr /= 255.0 # Normalize to [0,1] + processed.append(arr) + + if not processed: + raise ValueError("No valid images found for gRPC formatting.") + + formatted_batches = [] + formatted_batch_data = [] + proc_chunks = chunk_list(processed, max_batch_size) + orig_chunks = chunk_list(image_arrays, max_batch_size) + dims_chunks = chunk_list(image_dims, max_batch_size) + + for proc_chunk, orig_chunk, dims_chunk in zip(proc_chunks, orig_chunks, dims_chunks): + # Concatenate along the batch dimension to form a single input. + batched_input = np.concatenate(proc_chunk, axis=0) + formatted_batches.append(batched_input) + formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk}) + return formatted_batches, formatted_batch_data + + elif protocol == "http": + logger.debug("Formatting input for HTTP Deplot model (multiple messages).") + if "base64_images" in data: + base64_list = data["base64_images"] + else: + base64_list = [data["base64_image"]] + + formatted_batches = [] + formatted_batch_data = [] + b64_chunks = chunk_list(base64_list, max_batch_size) + orig_chunks = chunk_list(image_arrays, max_batch_size) + dims_chunks = chunk_list(image_dims, max_batch_size) + + for b64_chunk, orig_chunk, dims_chunk in zip(b64_chunks, orig_chunks, dims_chunks): + payload = self._prepare_deplot_payload( + base64_list=b64_chunk, + max_tokens=kwargs.get("max_tokens", 500), + temperature=kwargs.get("temperature", 0.5), + top_p=kwargs.get("top_p", 0.9), + ) + formatted_batches.append(payload) + formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk}) + return formatted_batches, formatted_batch_data + + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any: + """ + Parse the model's inference response. + """ + if protocol == "grpc": + logger.debug("Parsing output from gRPC Deplot model (batched).") + # Each batch element might be returned as a list of bytes. Combine or keep separate as needed. 
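+            # Illustrative (assumed) response structure, one entry per image in the batch:
+            #   response = [[b"<table text for image 0>"], b"<table text for image 1>", ...]
+            # Each entry is decoded and joined into a single string in the loop below.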
+ results = [] + for item in response: + # If item is [b'...'], decode and join + if isinstance(item, list): + joined_str = " ".join(o.decode("utf-8") for o in item) + results.append(joined_str) + else: + # single bytes or str + val = item.decode("utf-8") if isinstance(item, bytes) else str(item) + results.append(val) + return results # Return a list of strings, one per image. + + elif protocol == "http": + logger.debug("Parsing output from HTTP Deplot model.") + return self._extract_content_from_deplot_response(response) + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any: + """ + Process inference results for the Deplot model. + + Parameters + ---------- + output : Any + The raw output from the model. + protocol : str + The protocol used for inference (gRPC or HTTP). + + Returns + ------- + Any + The processed inference results. + """ + + # For Deplot, the output is the chart content as a string + return output + + @staticmethod + def _prepare_deplot_payload( + base64_list: list, + max_tokens: int = 500, + temperature: float = 0.5, + top_p: float = 0.9, + ) -> Dict[str, Any]: + """ + Prepare an HTTP payload for Deplot that includes one message per image, + matching the original single-image style: + + messages = [ + { + "role": "user", + "content": "Generate ... " + }, + { + "role": "user", + "content": "Generate ... " + }, + ... + ] + + If your backend expects multiple messages in a single request, this keeps + the same structure as the single-image code repeated N times. + """ + messages = [] + # Note: deplot NIM currently only supports a single message per request + for b64_img in base64_list: + messages.append( + { + "role": "user", + "content": ( + "Generate the underlying data table of the figure below: " + f'' + ), + } + ) + + payload = { + "model": "google/deplot", + "messages": messages, # multiple user messages now + "max_tokens": max_tokens, + "stream": False, + "temperature": temperature, + "top_p": top_p, + } + return payload + + @staticmethod + def _extract_content_from_deplot_response(json_response: Dict[str, Any]) -> Any: + """ + Extract content from the JSON response of a Deplot HTTP API request. + The original code expected a single choice with a single textual content. + """ + if "choices" not in json_response or not json_response["choices"]: + raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.") + + # If the service only returns one textual result, we return that one. + return json_response["choices"][0]["message"]["content"] diff --git a/src/nv_ingest/util/nim/doughnut.py b/src/nv_ingest/util/nim/doughnut.py new file mode 100644 index 00000000..84cafe3b --- /dev/null +++ b/src/nv_ingest/util/nim/doughnut.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging +import re +from typing import List +from typing import Tuple + +ACCEPTED_TEXT_CLASSES = set( + [ + "Text", + "Title", + "Section-header", + "List-item", + "TOC", + "Bibliography", + "Formula", + "Page-header", + "Page-footer", + "Caption", + "Footnote", + "Floating-text", + ] +) +ACCEPTED_TABLE_CLASSES = set( + [ + "Table", + ] +) +ACCEPTED_IMAGE_CLASSES = set( + [ + "Picture", + ] +) +ACCEPTED_CLASSES = ACCEPTED_TEXT_CLASSES | ACCEPTED_TABLE_CLASSES | ACCEPTED_IMAGE_CLASSES + +_re_extract_class_bbox = re.compile( + r"((?:|.(?:(?", # noqa: E501 + re.MULTILINE | re.DOTALL, +) + +logger = logging.getLogger(__name__) + + +def extract_classes_bboxes(text: str) -> Tuple[List[str], List[Tuple[int, int, int, int]], List[str]]: + classes: List[str] = [] + bboxes: List[Tuple[int, int, int, int]] = [] + texts: List[str] = [] + + last_end = 0 + + for m in _re_extract_class_bbox.finditer(text): + start, end = m.span() + + # [Bad box] Add the non-match chunk (text between the last match and the current match) + if start > last_end: + bad_text = text[last_end:start].strip() + classes.append("Bad-box") + bboxes.append((0, 0, 0, 0)) + texts.append(bad_text) + + last_end = end + + x1, y1, text, x2, y2, cls = m.groups() + + bbox = tuple(map(int, (x1, y1, x2, y2))) + + # [Bad box] check if the class is a valid class. + if cls not in ACCEPTED_CLASSES: + logger.debug(f"Dropped a bad box: invalid class {cls} at {bbox}.") + classes.append("Bad-box") + bboxes.append(bbox) + texts.append(text) + continue + + # Drop bad box: drop if the box is invalid. + if (bbox[0] >= bbox[2]) or (bbox[1] >= bbox[3]): + logger.debug(f"Dropped a bad box: invalid box {cls} at {bbox}.") + classes.append("Bad-box") + bboxes.append(bbox) + texts.append(text) + continue + + classes.append(cls) + bboxes.append(bbox) + texts.append(text) + + if last_end < len(text): + bad_text = text[last_end:].strip() + if len(bad_text) > 0: + classes.append("Bad-box") + bboxes.append((0, 0, 0, 0)) + texts.append(bad_text) + + return classes, bboxes, texts + + +def _fix_dots(m): + # Remove spaces between dots. + s = m.group(0) + return s.startswith(" ") * " " + min(5, s.count(".")) * "." + s.endswith(" ") * " " + + +def strip_markdown_formatting(text): + # Remove headers (e.g., # Header, ## Header, ### Header) + text = re.sub(r"^(#+)\s*(.*)", r"\2", text, flags=re.MULTILINE) + + # Remove bold formatting (e.g., **bold text** or __bold text__) + text = re.sub(r"\*\*(.*?)\*\*", r"\1", text) + text = re.sub(r"__(.*?)__", r"\1", text) + + # Remove italic formatting (e.g., *italic text* or _italic text_) + text = re.sub(r"\*(.*?)\*", r"\1", text) + text = re.sub(r"_(.*?)_", r"\1", text) + + # Remove strikethrough formatting (e.g., ~~strikethrough~~) + text = re.sub(r"~~(.*?)~~", r"\1", text) + + # Remove list items (e.g., - item, * item, 1. 
item) + text = re.sub(r"^\s*([-*+]|[0-9]+\.)\s+", "", text, flags=re.MULTILINE) + + # Remove hyperlinks (e.g., [link text](http://example.com)) + text = re.sub(r"\[(.*?)\]\(.*?\)", r"\1", text) + + # Remove inline code (e.g., `code`) + text = re.sub(r"`(.*?)`", r"\1", text) + + # Remove blockquotes (e.g., > quote) + text = re.sub(r"^\s*>\s*(.*)", r"\1", text, flags=re.MULTILINE) + + # Remove multiple newlines + text = re.sub(r"\n{3,}", "\n\n", text) + + # Limit dots sequences to max 5 dots + text = re.sub(r"(?:\s*\.\s*){3,}", _fix_dots, text, flags=re.DOTALL) + + return text + + +def reverse_transform_bbox( + bbox: Tuple[int, int, int, int], + bbox_offset: Tuple[int, int], + original_width: int, + original_height: int, +) -> Tuple[int, int, int, int]: + width_ratio = (original_width - 2 * bbox_offset[0]) / original_width + height_ratio = (original_height - 2 * bbox_offset[1]) / original_height + w1, h1, w2, h2 = bbox + w1 = int((w1 - bbox_offset[0]) / width_ratio) + h1 = int((h1 - bbox_offset[1]) / height_ratio) + w2 = int((w2 - bbox_offset[0]) / width_ratio) + h2 = int((h2 - bbox_offset[1]) / height_ratio) + + return (w1, h1, w2, h2) + + +def postprocess_text(txt: str, cls: str): + if cls in ACCEPTED_CLASSES: + txt = txt.replace("", "").strip() # remove tokens (continued paragraphs) + txt = strip_markdown_formatting(txt) + else: + txt = "" + + return txt diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py new file mode 100644 index 00000000..22505b08 --- /dev/null +++ b/src/nv_ingest/util/nim/helpers.py @@ -0,0 +1,708 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import logging +import re +import threading +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Any +from typing import Optional +from typing import Tuple + +import backoff +import cv2 +import numpy as np +import requests +import tritonclient.grpc as grpcclient + +from nv_ingest.util.image_processing.transforms import normalize_image +from nv_ingest.util.image_processing.transforms import pad_image +from nv_ingest.util.nim.decorators import multiprocessing_cache +from nv_ingest.util.tracing.tagging import traceable_func + +logger = logging.getLogger(__name__) + +DEPLOT_MAX_TOKENS = 128 +DEPLOT_TEMPERATURE = 1.0 +DEPLOT_TOP_P = 1.0 + + +class ModelInterface: + """ + Base class for defining a model interface that supports preparing input data, formatting it for + inference, parsing output, and processing inference results. + """ + + def format_input(self, data: dict, protocol: str, max_batch_size: int): + """ + Format the input data for the specified protocol. + + Parameters + ---------- + data : dict + The input data to format. + protocol : str + The protocol to format the data for. + """ + + raise NotImplementedError("Subclasses should implement this method") + + def parse_output(self, response, protocol: str, data: Optional[dict] = None, **kwargs): + """ + Parse the output data from the model's inference response. + + Parameters + ---------- + response : Any + The response from the model inference. + protocol : str + The protocol used ("grpc" or "http"). + data : dict, optional + Additional input data passed to the function. + """ + + raise NotImplementedError("Subclasses should implement this method") + + def prepare_data_for_inference(self, data: dict): + """ + Prepare input data for inference by processing or transforming it as required. 
+ + Parameters + ---------- + data : dict + The input data to prepare. + """ + raise NotImplementedError("Subclasses should implement this method") + + def process_inference_results(self, output_array, protocol: str, **kwargs): + """ + Process the inference results from the model. + + Parameters + ---------- + output_array : Any + The raw output from the model. + kwargs : dict + Additional parameters for processing. + """ + raise NotImplementedError("Subclasses should implement this method") + + def name(self) -> str: + """ + Get the name of the model interface. + + Returns + ------- + str + The name of the model interface. + """ + raise NotImplementedError("Subclasses should implement this method") + + +class NimClient: + """ + A client for interfacing with a model inference server using gRPC or HTTP protocols. + """ + + def __init__( + self, + model_interface, + protocol: str, + endpoints: Tuple[str, str], + auth_token: Optional[str] = None, + timeout: float = 120.0, + max_retries: int = 5, + ): + """ + Initialize the NimClient with the specified model interface, protocol, and server endpoints. + + Parameters + ---------- + model_interface : ModelInterface + The model interface implementation to use. + protocol : str + The protocol to use ("grpc" or "http"). + endpoints : tuple + A tuple containing the gRPC and HTTP endpoints. + auth_token : str, optional + Authorization token for HTTP requests (default: None). + timeout : float, optional + Timeout for HTTP requests in seconds (default: 30.0). + + Raises + ------ + ValueError + If an invalid protocol is specified or if required endpoints are missing. + """ + + self.client = None + self.model_interface = model_interface + self.protocol = protocol.lower() + self.auth_token = auth_token + self.timeout = timeout # Timeout for HTTP requests + self.max_retries = max_retries + self._grpc_endpoint, self._http_endpoint = endpoints + self._max_batch_sizes = {} + self._lock = threading.Lock() + + if self.protocol == "grpc": + if not self._grpc_endpoint: + raise ValueError("gRPC endpoint must be provided for gRPC protocol") + logger.debug(f"Creating gRPC client with {self._grpc_endpoint}") + self.client = grpcclient.InferenceServerClient(url=self._grpc_endpoint) + elif self.protocol == "http": + if not self._http_endpoint: + raise ValueError("HTTP endpoint must be provided for HTTP protocol") + logger.debug(f"Creating HTTP client with {self._http_endpoint}") + self.endpoint_url = generate_url(self._http_endpoint) + self.headers = {"accept": "application/json", "content-type": "application/json"} + if self.auth_token: + self.headers["Authorization"] = f"Bearer {self.auth_token}" + else: + raise ValueError("Invalid protocol specified. 
Must be 'grpc' or 'http'.") + + def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int: + """Fetch the maximum batch size from the Triton model configuration in a thread-safe manner.""" + if model_name in self._max_batch_sizes: + return self._max_batch_sizes[model_name] + + with self._lock: + # Double check, just in case another thread set the value while we were waiting + if model_name in self._max_batch_sizes: + return self._max_batch_sizes[model_name] + + if not self._grpc_endpoint: + self._max_batch_sizes[model_name] = 1 + return 1 + + try: + client = self.client if self.client else grpcclient.InferenceServerClient(url=self._grpc_endpoint) + model_config = client.get_model_config(model_name=model_name, model_version=model_version) + self._max_batch_sizes[model_name] = model_config.config.max_batch_size + logger.debug(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}") + except Exception as e: + self._max_batch_sizes[model_name] = 1 + logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1") + + return self._max_batch_sizes[model_name] + + def _process_batch(self, batch_input, *, batch_data, model_name, **kwargs): + """ + Process a single batch input for inference using its corresponding batch_data. + + Parameters + ---------- + batch_input : Any + The input data for this batch. + batch_data : Any + The corresponding scratch-pad data for this batch as returned by format_input. + model_name : str + The model name for inference. + kwargs : dict + Additional parameters. + + Returns + ------- + tuple + A tuple (parsed_output, batch_data) for subsequent post-processing. + """ + if self.protocol == "grpc": + logger.debug("Performing gRPC inference for a batch...") + response = self._grpc_infer(batch_input, model_name) + logger.debug("gRPC inference received response for a batch") + elif self.protocol == "http": + logger.debug("Performing HTTP inference for a batch...") + response = self._http_infer(batch_input) + logger.debug("HTTP inference received response for a batch") + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + parsed_output = self.model_interface.parse_output(response, protocol=self.protocol, data=batch_data, **kwargs) + return parsed_output, batch_data + + def try_set_max_batch_size(self, model_name, model_version: str = ""): + """Attempt to set the max batch size for the model if it is not already set, ensuring thread safety.""" + self._fetch_max_batch_size(model_name, model_version) + + @traceable_func(trace_name="{stage_name}::{model_name}") + def infer(self, data: dict, model_name: str, **kwargs) -> Any: + """ + Perform inference using the specified model and input data. + + Parameters + ---------- + data : dict + The input data for inference. + model_name : str + The model name. + kwargs : dict + Additional parameters for inference. + + Returns + ------- + Any + The processed inference results, coalesced in the same order as the input images. + """ + try: + # 1. Retrieve or default to the model's maximum batch size. + batch_size = self._fetch_max_batch_size(model_name) + max_requested_batch_size = kwargs.get("max_batch_size", batch_size) + force_requested_batch_size = kwargs.get("force_max_batch_size", False) + max_batch_size = ( + min(batch_size, max_requested_batch_size) + if not force_requested_batch_size + else max_requested_batch_size + ) + + # 2. Prepare data for inference. + data = self.model_interface.prepare_data_for_inference(data) + + # 3. 
Format the input based on protocol. + formatted_batches, formatted_batch_data = self.model_interface.format_input( + data, protocol=self.protocol, max_batch_size=max_batch_size, model_name=model_name + ) + + # Check for a custom maximum pool worker count, and remove it from kwargs. + max_pool_workers = kwargs.pop("max_pool_workers", 16) + + # 4. Process each batch concurrently using a thread pool. + # We enumerate the batches so that we can later reassemble results in order. + results = [None] * len(formatted_batches) + with ThreadPoolExecutor(max_workers=max_pool_workers) as executor: + futures = [] + for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)): + future = executor.submit( + self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs + ) + futures.append((idx, future)) + for idx, future in futures: + results[idx] = future.result() + + # 5. Process the parsed outputs for each batch using its corresponding batch_data. + # As the batches are in order, we coalesce their outputs accordingly. + all_results = [] + for parsed_output, batch_data in results: + batch_results = self.model_interface.process_inference_results( + parsed_output, + original_image_shapes=batch_data.get("original_image_shapes"), + protocol=self.protocol, + **kwargs, + ) + if isinstance(batch_results, list): + all_results.extend(batch_results) + else: + all_results.append(batch_results) + + except Exception as err: + error_str = f"Error during NimClient inference [{self.model_interface.name()}, {self.protocol}]: {err}" + logger.error(error_str) + raise RuntimeError(error_str) + + return all_results + + def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray: + """ + Perform inference using the gRPC protocol. + + Parameters + ---------- + formatted_input : np.ndarray + The input data formatted as a numpy array. + model_name : str + The name of the model to use for inference. + + Returns + ------- + np.ndarray + The output of the model as a numpy array. + """ + + input_tensors = [grpcclient.InferInput("input", formatted_input.shape, datatype="FP32")] + input_tensors[0].set_data_from_numpy(formatted_input) + + outputs = [grpcclient.InferRequestedOutput("output")] + response = self.client.infer(model_name=model_name, inputs=input_tensors, outputs=outputs) + logger.debug(f"gRPC inference response: {response}") + + return response.as_numpy("output") + + def _http_infer(self, formatted_input: dict) -> dict: + """ + Perform inference using the HTTP protocol, retrying for timeouts or 5xx errors up to 5 times. + + Parameters + ---------- + formatted_input : dict + The input data formatted as a dictionary. + + Returns + ------- + dict + The output of the model as a dictionary. + + Raises + ------ + TimeoutError + If the HTTP request times out repeatedly, up to the max retries. + requests.RequestException + For other HTTP-related errors that persist after max retries. + """ + + base_delay = 2.0 + attempt = 0 + + while attempt < self.max_retries: + try: + response = requests.post( + self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout + ) + status_code = response.status_code + + # Check for server-side or rate-limit type errors + # e.g. 5xx => server error, 429 => too many requests + if status_code == 429 or status_code == 503 or (500 <= status_code < 600): + logger.warning( + f"Received HTTP {status_code} ({response.reason}) from " + f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}." 
+ ) + if attempt == self.max_retries - 1: + # No more retries left + logger.error(f"Max retries exceeded after receiving HTTP {status_code}.") + response.raise_for_status() # raise the appropriate HTTPError + else: + # Exponential backoff + backoff_time = base_delay * (2**attempt) + time.sleep(backoff_time) + attempt += 1 + continue + else: + # Not in our "retry" category => just raise_for_status or return + response.raise_for_status() + logger.debug(f"HTTP inference response: {response.json()}") + return response.json() + + except requests.Timeout: + # Treat timeouts similarly to 5xx => attempt a retry + logger.warning( + f"HTTP request timed out after {self.timeout} seconds during {self.model_interface.name()} " + f"inference. Attempt {attempt + 1} of {self.max_retries}." + ) + if attempt == self.max_retries - 1: + logger.error("Max retries exceeded after repeated timeouts.") + raise TimeoutError( + f"Repeated timeouts for {self.model_interface.name()} after {attempt + 1} attempts." + ) + # Exponential backoff + backoff_time = base_delay * (2**attempt) + time.sleep(backoff_time) + attempt += 1 + + except requests.HTTPError as http_err: + # If we ended up here, it's a non-retryable 4xx or final 5xx after final attempt + logger.error(f"HTTP request failed with status code {response.status_code}: {http_err}") + raise + + except requests.RequestException as e: + # ConnectionError or other non-HTTPError + logger.error(f"HTTP request encountered a network issue: {e}") + if attempt == self.max_retries - 1: + raise + # Else retry on next loop iteration + backoff_time = base_delay * (2**attempt) + time.sleep(backoff_time) + attempt += 1 + + # If we exit the loop without returning, we've exhausted all attempts + logger.error(f"Failed to get a successful response after {self.max_retries} retries.") + raise Exception(f"Failed to get a successful response after {self.max_retries} retries.") + + def close(self): + if self.protocol == "grpc" and hasattr(self.client, "close"): + self.client.close() + + +def create_inference_client( + endpoints: Tuple[str, str], + model_interface: ModelInterface, + auth_token: Optional[str] = None, + infer_protocol: Optional[str] = None, +) -> NimClient: + """ + Create a NimClient for interfacing with a model inference server. + + Parameters + ---------- + endpoints : tuple + A tuple containing the gRPC and HTTP endpoints. + model_interface : ModelInterface + The model interface implementation to use. + auth_token : str, optional + Authorization token for HTTP requests (default: None). + infer_protocol : str, optional + The protocol to use ("grpc" or "http"). If not specified, it is inferred from the endpoints. + + Returns + ------- + NimClient + The initialized NimClient. + + Raises + ------ + ValueError + If an invalid infer_protocol is specified. + """ + + grpc_endpoint, http_endpoint = endpoints + + if (infer_protocol is None) and (grpc_endpoint and grpc_endpoint.strip()): + infer_protocol = "grpc" + elif infer_protocol is None and http_endpoint: + infer_protocol = "http" + + if infer_protocol not in ["grpc", "http"]: + raise ValueError("Invalid infer_protocol specified. Must be 'grpc' or 'http'.") + + return NimClient(model_interface, infer_protocol, endpoints, auth_token) + + +def preprocess_image_for_paddle(array: np.ndarray, image_max_dimension: int = 960) -> np.ndarray: + """ + Preprocesses an input image to be suitable for use with PaddleOCR by resizing, normalizing, padding, + and transposing it into the required format. 
+ + This function is intended for preprocessing images to be passed as input to PaddleOCR using GRPC. + It is not necessary when using the HTTP endpoint. + + Steps: + ----- + 1. Resizes the image while maintaining aspect ratio such that its largest dimension is scaled to 960 pixels. + 2. Normalizes the image using the `normalize_image` function. + 3. Pads the image to ensure both its height and width are multiples of 32, as required by PaddleOCR. + 4. Transposes the image from (height, width, channel) to (channel, height, width), the format expected by PaddleOCR. + + Parameters: + ---------- + array : np.ndarray + The input image array of shape (height, width, channels). It should have pixel values in the range [0, 255]. + + Returns: + ------- + np.ndarray + A preprocessed image with the shape (channels, height, width) and normalized pixel values. + The image will be padded to have dimensions that are multiples of 32, with the padding color set to 0. + + Notes: + ----- + - The image is resized so that its largest dimension becomes 960 pixels, maintaining the aspect ratio. + - After normalization, the image is padded to the nearest multiple of 32 in both dimensions, which is + a requirement for PaddleOCR. + - The normalized pixel values are scaled between 0 and 1 before padding and transposing the image. + """ + height, width = array.shape[:2] + scale_factor = image_max_dimension / max(height, width) + new_height = int(height * scale_factor) + new_width = int(width * scale_factor) + resized = cv2.resize(array, (new_width, new_height)) + + normalized = normalize_image(resized) + + # PaddleOCR NIM (GRPC) requires input shapes to be multiples of 32. + new_height = (normalized.shape[0] + 31) // 32 * 32 + new_width = (normalized.shape[1] + 31) // 32 * 32 + padded, (pad_width, pad_height) = pad_image( + normalized, target_height=new_height, target_width=new_width, background_color=0, dtype=np.float32 + ) + + # PaddleOCR NIM (GRPC) requires input to be (channel, height, width). + transposed = padded.transpose((2, 0, 1)) + + # Metadata can used for inverting transformations on the resulting bounding boxes. + metadata = { + "original_height": height, + "original_width": width, + "scale_factor": scale_factor, + "new_height": transposed.shape[1], + "new_width": transposed.shape[2], + "pad_height": pad_height, + "pad_width": pad_width, + } + + return transposed, metadata + + +def remove_url_endpoints(url) -> str: + """Some configurations provide the full endpoint in the URL. + Ex: http://deplot:8000/v1/chat/completions. For hitting the + health endpoint we need to get just the hostname:port combo + that we can append the health/ready endpoint to so we attempt + to parse that information here. + + Args: + url str: Incoming URL + + Returns: + str: URL with just the hostname:port portion remaining + """ + if "/v1" in url: + url = url.split("/v1")[0] + + return url + + +def generate_url(url) -> str: + """Examines the user defined URL for http*://. If that + pattern is detected the URL is used as provided by the user. + If that pattern does not exist then the assumption is made that + the endpoint is simply `http://` and that is prepended + to the user supplied endpoint. 
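+    For example (illustrative), an endpoint supplied as `deplot:8000` becomes
+    `http://deplot:8000`, while `https://deplot:8000/v1/chat/completions` is
+    returned unchanged.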
+ + Args: + url str: Endpoint where the Rest service is running + + Returns: + str: Fully validated URL + """ + if not re.match(r"^https?://", url): + # Add the default `http://` if it's not already present in the URL + url = f"http://{url}" + + return url + + +def is_ready(http_endpoint: str, ready_endpoint: str) -> bool: + """ + Check if the server at the given endpoint is ready. + + Parameters + ---------- + http_endpoint : str + The HTTP endpoint of the server. + ready_endpoint : str + The specific ready-check endpoint. + + Returns + ------- + bool + True if the server is ready, False otherwise. + """ + + # IF the url is empty or None that means the service was not configured + # and is therefore automatically marked as "ready" + if http_endpoint is None or http_endpoint == "": + return True + + # If the url is for build.nvidia.com, it is automatically assumed "ready" + if "ai.api.nvidia.com" in http_endpoint: + return True + + url = generate_url(http_endpoint) + url = remove_url_endpoints(url) + + if not ready_endpoint.startswith("/") and not url.endswith("/"): + ready_endpoint = "/" + ready_endpoint + + url = url + ready_endpoint + + # Call the ready endpoint of the NIM + try: + # Use a short timeout to prevent long hanging calls. 5 seconds seems resonable + resp = requests.get(url, timeout=5) + if resp.status_code == 200: + # The NIM is saying it is ready to serve + return True + elif resp.status_code == 503: + # NIM is explicitly saying it is not ready. + return False + else: + # Any other code is confusing. We should log it with a warning + # as it could be something that might hold up ready state + logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.json()}") + return False + except requests.HTTPError as http_err: + logger.warning(f"'{url}' produced a HTTP error: {http_err}") + return False + except requests.Timeout: + logger.warning(f"'{url}' request timed out") + return False + except ConnectionError: + logger.warning(f"A connection error for '{url}' occurred") + return False + except requests.RequestException as err: + logger.warning(f"An error occurred: {err} for '{url}'") + return False + except Exception as ex: + # Don't let anything squeeze by + logger.warning(f"Exception: {ex}") + return False + + +@multiprocessing_cache(max_calls=100) # Cache results first to avoid redundant retries from backoff +@backoff.on_predicate(backoff.expo, max_time=30) +def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", version_field: str = "version") -> str: + """ + Get the version of the server from its metadata endpoint. + + Parameters + ---------- + http_endpoint : str + The HTTP endpoint of the server. + metadata_endpoint : str, optional + The metadata endpoint to query (default: "/v1/metadata"). + version_field : str, optional + The field containing the version in the response (default: "version"). + + Returns + ------- + str + The version of the server, or an empty string if unavailable. + """ + + if (http_endpoint is None) or (http_endpoint == ""): + return "" + + # TODO: Need a way to match NIM versions to API versions. 
+ if "ai.api.nvidia.com" in http_endpoint or "api.nvcf.nvidia.com" in http_endpoint: + return "1.0.0" + + url = generate_url(http_endpoint) + url = remove_url_endpoints(url) + + if not metadata_endpoint.startswith("/") and not url.endswith("/"): + metadata_endpoint = "/" + metadata_endpoint + + url = url + metadata_endpoint + + # Call the metadata endpoint of the NIM + try: + # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable + resp = requests.get(url, timeout=5) + if resp.status_code == 200: + version = resp.json().get(version_field, "") + if version: + return version + else: + # If version field is empty, retry + logger.warning(f"No version field in response from '{url}'. Retrying.") + return "" + else: + # Any other code is confusing. We should log it with a warning + logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.text}") + return "" + except requests.HTTPError as http_err: + logger.warning(f"'{url}' produced a HTTP error: {http_err}") + return "" + except requests.Timeout: + logger.warning(f"'{url}' request timed out") + return "" + except ConnectionError: + logger.warning(f"A connection error for '{url}' occurred") + return "" + except requests.RequestException as err: + logger.warning(f"An error occurred: {err} for '{url}'") + return "" + except Exception as ex: + # Don't let anything squeeze by + logger.warning(f"Exception: {ex}") + return "" diff --git a/src/nv_ingest/util/nim/paddle.py b/src/nv_ingest/util/nim/paddle.py new file mode 100644 index 00000000..98b99efd --- /dev/null +++ b/src/nv_ingest/util/nim/paddle.py @@ -0,0 +1,446 @@ +import json +import logging +from typing import Any, List, Tuple +from typing import Dict +from typing import Optional + +import numpy as np + +from nv_ingest.util.image_processing.transforms import base64_to_numpy +from nv_ingest.util.nim.helpers import ModelInterface +from nv_ingest.util.nim.helpers import preprocess_image_for_paddle + +logger = logging.getLogger(__name__) + + +class PaddleOCRModelInterface(ModelInterface): + """ + An interface for handling inference with a PaddleOCR model, supporting both gRPC and HTTP protocols. + """ + + def name(self) -> str: + """ + Get the name of the model interface. + + Returns + ------- + str + The name of the model interface. + """ + return "PaddleOCR" + + def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Decode one or more base64-encoded images into NumPy arrays, storing them + alongside their dimensions in `data`. + + Parameters + ---------- + data : dict of str -> Any + The input data containing either: + - 'base64_image': a single base64-encoded image, or + - 'base64_images': a list of base64-encoded images. + + Returns + ------- + dict of str -> Any + The updated data dictionary with the following keys added: + - "image_arrays": List of decoded NumPy arrays of shape (H, W, C). + - "image_dims": List of (height, width) tuples for each decoded image. + + Raises + ------ + KeyError + If neither 'base64_image' nor 'base64_images' is found in `data`. + ValueError + If 'base64_images' is present but is not a list. 
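+
+        Examples
+        --------
+        Illustrative call; the base64 strings are placeholders::
+
+            data = {"base64_images": [b64_png_a, b64_png_b]}
+            data = PaddleOCRModelInterface().prepare_data_for_inference(data)
+            # data["image_arrays"] now holds one (H, W, C) np.ndarray per input image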
+ """ + if "base64_images" in data: + base64_list = data["base64_images"] + if not isinstance(base64_list, list): + raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.") + + image_arrays: List[np.ndarray] = [] + for b64 in base64_list: + img = base64_to_numpy(b64) + image_arrays.append(img) + + data["image_arrays"] = image_arrays + + elif "base64_image" in data: + # Single-image fallback + img = base64_to_numpy(data["base64_image"]) + data["image_arrays"] = [img] + + else: + raise KeyError("Input data must include 'base64_image' or 'base64_images'.") + + return data + + def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any: + """ + Format input data for the specified protocol ("grpc" or "http"), supporting batched data. + + Parameters + ---------- + data : dict of str -> Any + The input data dictionary, expected to contain "image_arrays" (list of np.ndarray) + and "image_dims" (list of (height, width) tuples), as produced by prepare_data_for_inference. + protocol : str + The inference protocol, either "grpc" or "http". + max_batch_size : int + The maximum batch size for batching. + + Returns + ------- + tuple + A tuple (formatted_batches, formatted_batch_data) where: + - formatted_batches is a list of batches ready for inference. + - formatted_batch_data is a list of scratch-pad dictionaries corresponding to each batch, + containing the keys "image_arrays" and "image_dims" for later post-processing. + + Raises + ------ + KeyError + If either "image_arrays" or "image_dims" is not found in `data`. + ValueError + If an invalid protocol is specified. + """ + + images = data["image_arrays"] + + dims: List[Dict[str, Any]] = [] + data["image_dims"] = dims + + # Helper function to split a list into chunks of size up to chunk_size. + def chunk_list(lst, chunk_size): + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + if "image_arrays" not in data or "image_dims" not in data: + raise KeyError("Expected 'image_arrays' and 'image_dims' in data. 
Call prepare_data_for_inference first.") + + images = data["image_arrays"] + dims = data["image_dims"] + + if protocol == "grpc": + logger.debug("Formatting input for gRPC PaddleOCR model (batched).") + processed: List[np.ndarray] = [] + for img in images: + arr, _dims = preprocess_image_for_paddle(img) + dims.append(_dims) + arr = arr.astype(np.float32) + arr = np.expand_dims(arr, axis=0) # => shape (1, H, W, C) + processed.append(arr) + + batches = [] + batch_data_list = [] + for proc_chunk, orig_chunk, dims_chunk in zip( + chunk_list(processed, max_batch_size), + chunk_list(images, max_batch_size), + chunk_list(dims, max_batch_size), + ): + batched_input = np.concatenate(proc_chunk, axis=0) + batches.append(batched_input) + batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk}) + return batches, batch_data_list + + elif protocol == "http": + logger.debug("Formatting input for HTTP PaddleOCR model (batched).") + if "base64_images" in data: + base64_list = data["base64_images"] + else: + base64_list = [data["base64_image"]] + + input_list: List[Dict[str, Any]] = [] + for b64 in base64_list: + image_url = f"data:image/png;base64,{b64}" + image_obj = {"type": "image_url", "url": image_url} + input_list.append(image_obj) + + batches = [] + batch_data_list = [] + for input_chunk, orig_chunk, dims_chunk in zip( + chunk_list(input_list, max_batch_size), + chunk_list(images, max_batch_size), + chunk_list(dims, max_batch_size), + ): + payload = {"input": input_chunk} + batches.append(payload) + batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk}) + return batches, batch_data_list + + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any: + """ + Parse the model's inference response for the given protocol. The parsing + may handle batched outputs for multiple images. + + Parameters + ---------- + response : Any + The raw response from the PaddleOCR model. + protocol : str + The protocol used for inference, "grpc" or "http". + data : dict of str -> Any, optional + Additional data dictionary that may include "image_dims" for bounding box scaling. + **kwargs : Any + Additional keyword arguments, such as custom `table_content_format`. + + Returns + ------- + Any + The parsed output, typically a list of (content, table_content_format) tuples. + + Raises + ------ + ValueError + If an invalid protocol is specified. + """ + # Retrieve image dimensions if available + dims: Optional[List[Tuple[int, int]]] = data.get("image_dims") if data else None + + if protocol == "grpc": + logger.debug("Parsing output from gRPC PaddleOCR model (batched).") + return self._extract_content_from_paddle_grpc_response(response, dims) + + elif protocol == "http": + logger.debug("Parsing output from HTTP PaddleOCR model (batched).") + return self._extract_content_from_paddle_http_response(response) + + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def process_inference_results(self, output: Any, **kwargs: Any) -> Any: + """ + Process inference results for the PaddleOCR model. + + Parameters + ---------- + output : Any + The raw output parsed from the PaddleOCR model. + **kwargs : Any + Additional keyword arguments for customization. + + Returns + ------- + Any + The post-processed inference results. By default, this simply returns the output + as the table content (or content list). 
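+            For this interface the output is typically a list containing one
+            `[bounding_boxes, text_predictions]` pair per image, as produced by
+            `parse_output`.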
+ """ + return output + + def _prepare_paddle_payload(self, base64_img: str) -> Dict[str, Any]: + """ + DEPRECATED by batch logic in format_input. Kept here if you need single-image direct calls. + + Parameters + ---------- + base64_img : str + A single base64-encoded image string. + + Returns + ------- + dict of str -> Any + The payload in either legacy or new format for PaddleOCR's HTTP endpoint. + """ + image_url = f"data:image/png;base64,{base64_img}" + + image = {"type": "image_url", "url": image_url} + payload = {"input": [image]} + + return payload + + def _extract_content_from_paddle_http_response( + self, + json_response: Dict[str, Any], + table_content_format: Optional[str], + ) -> List[Tuple[str, str]]: + """ + Extract content from the JSON response of a PaddleOCR HTTP API request. + + Parameters + ---------- + json_response : dict of str -> Any + The JSON response returned by the PaddleOCR endpoint. + table_content_format : str or None + The specified format for table content (e.g., 'simple' or 'pseudo_markdown'). + + Returns + ------- + list of (str, str) + A list of (content, table_content_format) tuples, one for each image result. + + Raises + ------ + RuntimeError + If the response format is missing or invalid. + ValueError + If the `table_content_format` is unrecognized. + """ + if "data" not in json_response or not json_response["data"]: + raise RuntimeError("Unexpected response format: 'data' key is missing or empty.") + + results: List[str] = [] + for item_idx, item in enumerate(json_response["data"]): + text_detections = item.get("text_detections", []) + text_predictions = [] + bounding_boxes = [] + for td in text_detections: + text_predictions.append(td["text_prediction"]["text"]) + bounding_boxes.append([[pt["x"], pt["y"]] for pt in td["bounding_box"]["points"]]) + + results.append([bounding_boxes, text_predictions]) + + return results + + def _extract_content_from_paddle_grpc_response( + self, + response: np.ndarray, + dimensions: List[Dict[str, Any]], + ) -> List[Tuple[str, str]]: + """ + Parse a gRPC response for one or more images. The response can have two possible shapes: + - (3,) for batch_size=1 + - (3, n) for batch_size=n + + In either case: + response[0, i]: byte string containing bounding box data + response[1, i]: byte string containing text prediction data + response[2, i]: (Optional) additional data/metadata (ignored here) + + Parameters + ---------- + response : np.ndarray + The raw NumPy array from gRPC. Expected shape: (3,) or (3, n). + table_content_format : str + The format of the output text content, e.g. 'simple' or 'pseudo_markdown'. + dims : list of dict, optional + A list of dict for each corresponding image, used for bounding box scaling. + + Returns + ------- + list of (str, str) + A list of (content, table_content_format) for each image. + + Raises + ------ + ValueError + If the response is not a NumPy array or has an unexpected shape, + or if the `table_content_format` is unrecognized. + """ + if not isinstance(response, np.ndarray): + raise ValueError("Unexpected response format: response is not a NumPy array.") + + # If we have shape (3,), convert to (3, 1) + if response.ndim == 1 and response.shape == (3,): + response = response.reshape(3, 1) + elif response.ndim != 2 or response.shape[0] != 3: + raise ValueError(f"Unexpected response shape: {response.shape}. 
Expecting (3,) or (3, n).") + + batch_size = response.shape[1] + results: List[Tuple[str, str]] = [] + + for i in range(batch_size): + # 1) Parse bounding boxes + bboxes_bytestr: bytes = response[0, i] + bounding_boxes = json.loads(bboxes_bytestr.decode("utf8")) + + # 2) Parse text predictions + texts_bytestr: bytes = response[1, i] + text_predictions = json.loads(texts_bytestr.decode("utf8")) + + # 3) Log the third element (extra data/metadata) if needed + extra_data_bytestr: bytes = response[2, i] + logger.debug(f"Ignoring extra_data for image {i}: {extra_data_bytestr}") + + # Some gRPC responses nest single-item lists; flatten them if needed + if isinstance(bounding_boxes, list) and len(bounding_boxes) == 1: + bounding_boxes = bounding_boxes[0] + if isinstance(text_predictions, list) and len(text_predictions) == 1: + text_predictions = text_predictions[0] + + bounding_boxes, text_predictions = self._postprocess_paddle_response( + bounding_boxes, + text_predictions, + dimensions, + img_index=i, + ) + + results.append([bounding_boxes, text_predictions]) + + return results + + @staticmethod + def _postprocess_paddle_response( + bounding_boxes: List[Any], + text_predictions: List[str], + dims: Optional[List[Dict[str, Any]]] = None, + img_index: int = 0, + ) -> Tuple[List[Any], List[str]]: + """ + Convert bounding boxes with normalized coordinates to pixel cooridnates by using + the dimensions. Also shift the coorindates if the inputs were padded. For multiple images, + the correct image dimensions (height, width) are retrieved from `dims[img_index]`. + + Parameters + ---------- + bounding_boxes : list of Any + A list (per line of text) of bounding boxes, each a list of (x, y) points. + text_predictions : list of str + A list of text predictions, one for each bounding box. + img_index : int, optional + The index of the image for which bounding boxes are being converted. Default is 0. + dims : list of dict, optional + A list of dictionaries, where each dictionary contains image-specific dimensions + and scaling information: + - "new_width" (int): The width of the image after processing. + - "new_height" (int): The height of the image after processing. + - "pad_width" (int, optional): The width of padding added to the image. + - "pad_height" (int, optional): The height of padding added to the image. + - "scale_factor" (float, optional): The scaling factor applied to the image. + + Returns + ------- + Tuple[List[Any], List[str]] + Bounding boxes scaled backed to the original dimensions and detected text lines. + + Notes + ----- + - If `dims` is None or `img_index` is out of range, bounding boxes will not be scaled properly. + """ + # Default to no scaling if dims are missing or out of range + if not dims: + raise ValueError("No image_dims provided.") + else: + if img_index >= len(dims): + logger.warning("Image index out of range for stored dimensions. 
Using first image dims by default.") + img_index = 0 + + max_width = dims[img_index]["new_width"] + max_height = dims[img_index]["new_height"] + pad_width = dims[img_index].get("pad_width", 0) + pad_height = dims[img_index].get("pad_height", 0) + scale_factor = dims[img_index].get("scale_factor", 1.0) + + bboxes: List[List[float]] = [] + texts: List[str] = [] + + # Convert normalized coords back to actual pixel coords + for box, txt in zip(bounding_boxes, text_predictions): + if box == "nan": + continue + points: List[List[float]] = [] + for point in box: + # Convert normalized coords back to actual pixel coords, + # and shift them back to their original positions if padded. + x_pixels = float(point[0]) * max_width - pad_width + y_pixels = float(point[1]) * max_height - pad_height + x_original = x_pixels / scale_factor + y_original = y_pixels / scale_factor + points.append([x_original, y_original]) + bboxes.append(points) + texts.append(txt) + + return bboxes, texts diff --git a/src/nv_ingest/util/nim/text_embedding.py b/src/nv_ingest/util/nim/text_embedding.py new file mode 100644 index 00000000..11a16789 --- /dev/null +++ b/src/nv_ingest/util/nim/text_embedding.py @@ -0,0 +1,128 @@ +from typing import Any, Dict, List, Optional, Tuple + +from nv_ingest.util.nim.helpers import ModelInterface + + +# Assume ModelInterface is defined elsewhere in the project. +class EmbeddingModelInterface(ModelInterface): + """ + An interface for handling inference with an embedding model endpoint. + This implementation supports HTTP inference for generating embeddings from text prompts. + """ + + def name(self) -> str: + """ + Return the name of this model interface. + """ + return "Embedding" + + def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Prepare input data for embedding inference. Ensures that a 'prompts' key is provided + and that its value is a list. + + Raises + ------ + KeyError + If the 'prompts' key is missing. + """ + if "prompts" not in data: + raise KeyError("Input data must include 'prompts'.") + # Ensure the prompts are in list format. + if not isinstance(data["prompts"], list): + data["prompts"] = [data["prompts"]] + return data + + def format_input( + self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs + ) -> Tuple[List[Any], List[Dict[str, Any]]]: + """ + Format the input payload for the embedding endpoint. This method constructs one payload per batch, + where each payload includes a list of text prompts. + Additionally, it returns batch data that preserves the original order of prompts. + + Parameters + ---------- + data : dict + The input data containing "prompts" (a list of text prompts). + protocol : str + Only "http" is supported. + max_batch_size : int + Maximum number of prompts per payload. + kwargs : dict + Additional parameters including model_name, encoding_format, input_type, and truncate. + + Returns + ------- + tuple + A tuple (payloads, batch_data_list) where: + - payloads is a list of JSON-serializable payload dictionaries. + - batch_data_list is a list of dictionaries containing the key "prompts" corresponding to each batch. 
+ """ + if protocol != "http": + raise ValueError("EmbeddingModelInterface only supports HTTP protocol.") + + prompts = data.get("prompts", []) + + def chunk_list(lst, chunk_size): + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + batches = chunk_list(prompts, max_batch_size) + payloads = [] + batch_data_list = [] + for batch in batches: + payload = { + "model": kwargs.get("model_name"), + "input": batch, + "encoding_format": kwargs.get("encoding_format", "float"), + "extra_body": { + "input_type": kwargs.get("input_type", "query"), + "truncate": kwargs.get("truncate", "NONE"), + }, + } + payloads.append(payload) + batch_data_list.append({"prompts": batch}) + return payloads, batch_data_list + + def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any: + """ + Parse the HTTP response from the embedding endpoint. Expects a response structure with a "data" key. + + Parameters + ---------- + response : Any + The raw HTTP response (assumed to be already decoded as JSON). + protocol : str + Only "http" is supported. + data : dict, optional + The original input data. + kwargs : dict + Additional keyword arguments. + + Returns + ------- + list + A list of generated embeddings extracted from the response. + """ + if protocol != "http": + raise ValueError("EmbeddingModelInterface only supports HTTP protocol.") + if isinstance(response, dict): + embeddings = response.get("data") + if not embeddings: + raise RuntimeError("Unexpected response format: 'data' key is missing or empty.") + # Each item in embeddings is expected to have an 'embedding' field. + return [item.get("embedding", None) for item in embeddings] + else: + return [str(response)] + + def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any: + """ + Process inference results for the embedding model. + For this implementation, the output is expected to be a list of embeddings. + + Returns + ------- + list + The processed list of embeddings. + """ + return output diff --git a/src/nv_ingest/util/nim/vlm.py b/src/nv_ingest/util/nim/vlm.py new file mode 100644 index 00000000..cffefb0f --- /dev/null +++ b/src/nv_ingest/util/nim/vlm.py @@ -0,0 +1,148 @@ +from typing import Dict, Any, Optional, Tuple, List + +import logging + +from nv_ingest.util.nim.helpers import ModelInterface + +logger = logging.getLogger(__name__) + + +class VLMModelInterface(ModelInterface): + """ + An interface for handling inference with a VLM model endpoint (e.g., NVIDIA LLaMA-based VLM). + This implementation supports HTTP inference with one or more base64-encoded images and a caption prompt. + """ + + def name(self) -> str: + """ + Return the name of this model interface. + """ + return "VLM" + + def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Prepare input data for VLM inference. Accepts either a single base64 image or a list of images. + Ensures that a 'prompt' is provided. + + Raises + ------ + KeyError + If neither "base64_image" nor "base64_images" is provided or if "prompt" is missing. + ValueError + If "base64_images" exists but is not a list. + """ + # Allow either a single image with "base64_image" or multiple images with "base64_images". + if "base64_images" in data: + if not isinstance(data["base64_images"], list): + raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.") + elif "base64_image" in data: + # Convert a single image into a list. 
+ data["base64_images"] = [data["base64_image"]] + else: + raise KeyError("Input data must include 'base64_image' or 'base64_images'.") + + if "prompt" not in data: + raise KeyError("Input data must include 'prompt'.") + return data + + def format_input( + self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs + ) -> Tuple[List[Any], List[Dict[str, Any]]]: + """ + Format the input payload for the VLM endpoint. This method constructs one payload per batch, + where each payload includes one message per image in the batch. + Additionally, it returns batch data that preserves the original order of images by including + the list of base64 images and the prompt for each batch. + + Parameters + ---------- + data : dict + The input data containing "base64_images" (a list of base64-encoded images) and "prompt". + protocol : str + Only "http" is supported. + max_batch_size : int + Maximum number of images per payload. + kwargs : dict + Additional parameters including model_name, max_tokens, temperature, top_p, and stream. + + Returns + ------- + tuple + A tuple (payloads, batch_data_list) where: + - payloads is a list of JSON-serializable payload dictionaries. + - batch_data_list is a list of dictionaries containing the keys "base64_images" and "prompt" + corresponding to each batch. + """ + if protocol != "http": + raise ValueError("VLMModelInterface only supports HTTP protocol.") + + images = data.get("base64_images", []) + prompt = data["prompt"] + + # Helper function to chunk the list into batches. + def chunk_list(lst, chunk_size): + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + batches = chunk_list(images, max_batch_size) + payloads = [] + batch_data_list = [] + for batch in batches: + # Create one message per image in the batch. + messages = [ + {"role": "user", "content": f'{prompt} '} for img in batch + ] + payload = { + "model": kwargs.get("model_name"), + "messages": messages, + "max_tokens": kwargs.get("max_tokens", 512), + "temperature": kwargs.get("temperature", 1.0), + "top_p": kwargs.get("top_p", 1.0), + "stream": kwargs.get("stream", False), + } + payloads.append(payload) + batch_data_list.append({"base64_images": batch, "prompt": prompt}) + return payloads, batch_data_list + + def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any: + """ + Parse the HTTP response from the VLM endpoint. Expects a response structure with a "choices" key. + + Parameters + ---------- + response : Any + The raw HTTP response (assumed to be already decoded as JSON). + protocol : str + Only "http" is supported. + data : dict, optional + The original input data. + kwargs : dict + Additional keyword arguments. + + Returns + ------- + list + A list of generated captions extracted from the response. + """ + if protocol != "http": + raise ValueError("VLMModelInterface only supports HTTP protocol.") + if isinstance(response, dict): + choices = response.get("choices", []) + if not choices: + raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.") + # Return a list of captions, one per choice. + return [choice.get("message", {}).get("content", "No caption returned") for choice in choices] + else: + # If response is not a dict, return its string representation in a list. + return [str(response)] + + def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any: + """ + Process inference results for the VLM model. 
+ For this implementation, the output is expected to be a list of captions. + + Returns + ------- + list + The processed list of captions. + """ + return output diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py new file mode 100644 index 00000000..47df3569 --- /dev/null +++ b/src/nv_ingest/util/nim/yolox.py @@ -0,0 +1,1186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import base64 +import io +import logging +import warnings +from math import log +from typing import Any, Tuple +from typing import Dict +from typing import List +from typing import Optional + +import cv2 +import numpy as np +import pandas as pd +import torch +import torchvision +from PIL import Image + +from nv_ingest.util.image_processing.transforms import scale_image_to_encoding_size +from nv_ingest.util.nim.helpers import ModelInterface + +logger = logging.getLogger(__name__) + +# yolox-page-elements-v1 contants +YOLOX_PAGE_NUM_CLASSES = 3 +YOLOX_PAGE_CONF_THRESHOLD = 0.01 +YOLOX_PAGE_IOU_THRESHOLD = 0.5 +YOLOX_PAGE_MIN_SCORE = 0.1 +YOLOX_PAGE_FINAL_SCORE = 0.48 +YOLOX_PAGE_NIM_MAX_IMAGE_SIZE = 512_000 + +YOLOX_PAGE_IMAGE_PREPROC_HEIGHT = 1024 +YOLOX_PAGE_IMAGE_PREPROC_WIDTH = 1024 + +YOLOX_PAGE_CLASS_LABELS = [ + "table", + "chart", + "title", +] + +# yolox-graphic-elements-v1 contants +YOLOX_GRAPHIC_NUM_CLASSES = 10 +YOLOX_GRAPHIC_CONF_THRESHOLD = 0.01 +YOLOX_GRAPHIC_IOU_THRESHOLD = 0.25 +YOLOX_GRAPHIC_MIN_SCORE = 0.1 +YOLOX_GRAPHIC_FINAL_SCORE = 0.0 +YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE = 512_000 + +YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 768 +YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 768 + +YOLOX_GRAPHIC_CLASS_LABELS = [ + "chart_title", + "x_title", + "y_title", + "xlabel", + "ylabel", + "other", + "legend_label", + "legend_title", + "mark_label", + "value_label", +] + + +# YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements +class YoloxModelInterfaceBase(ModelInterface): + """ + An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols. + """ + + def __init__( + self, + image_preproc_width: Optional[int] = None, + image_preproc_height: Optional[int] = None, + nim_max_image_size: Optional[int] = None, + num_classes: Optional[int] = None, + conf_threshold: Optional[float] = None, + iou_threshold: Optional[float] = None, + min_score: Optional[float] = None, + final_score: Optional[float] = None, + class_labels: Optional[List[str]] = None, + ): + """ + Initialize the YOLOX model interface. + Parameters + ---------- + """ + self.image_preproc_width = image_preproc_width + self.image_preproc_height = image_preproc_height + self.nim_max_image_size = nim_max_image_size + self.num_classes = num_classes + self.conf_threshold = conf_threshold + self.iou_threshold = iou_threshold + self.min_score = min_score + self.final_score = final_score + self.class_labels = class_labels + + def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: + """ + Prepare input data for inference by resizing images and storing their original shapes. + + Parameters + ---------- + data : dict + The input data containing a list of images. + + Returns + ------- + dict + The updated data dictionary with resized images and original image shapes. 
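+
+        Notes
+        -----
+        The resizing itself is applied later in ``format_input``; this step validates the input
+        and records the original shapes needed to map detections back to image coordinates.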
+ """ + if (not isinstance(data, dict)) or ("images" not in data): + raise KeyError("Input data must be a dictionary containing an 'images' key with a list of images.") + + if not all(isinstance(x, np.ndarray) for x in data["images"]): + raise ValueError("All elements in the 'images' list must be numpy.ndarray objects.") + + original_images = data["images"] + data["original_image_shapes"] = [image.shape for image in original_images] + + return data + + def format_input( + self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs + ) -> Tuple[List[Any], List[Dict[str, Any]]]: + """ + Format input data for the specified protocol, returning a tuple of: + (formatted_batches, formatted_batch_data) + where: + - For gRPC: formatted_batches is a list of NumPy arrays, each of shape (B, H, W, C) + with B <= max_batch_size. + - For HTTP: formatted_batches is a list of JSON-serializable dict payloads. + - In both cases, formatted_batch_data is a list of dicts that coalesce the original + images and their original shapes in the same order as provided. + + Parameters + ---------- + data : dict + The input data to format. Must include: + - "images": a list of numpy.ndarray images. + - "original_image_shapes": a list of tuples with each image's (height, width), + as set by prepare_data_for_inference. + protocol : str + The protocol to use ("grpc" or "http"). + max_batch_size : int + The maximum number of images per batch. + + Returns + ------- + tuple + A tuple (formatted_batches, formatted_batch_data). + + Raises + ------ + ValueError + If the protocol is invalid. + """ + + # Helper functions to chunk a list into sublists of length up to chunk_size. + def chunk_list(lst: list, chunk_size: int) -> List[list]: + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + def chunk_list_geometrically(lst: list, max_size: int) -> List[list]: + # TRT engine in Yolox NIM (gRPC) only allows a batch size in powers of 2. + chunks = [] + i = 0 + while i < len(lst): + chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size) + chunks.append(lst[i : i + chunk_size]) + i += chunk_size + return chunks + + if protocol == "grpc": + logger.debug("Formatting input for gRPC Yolox model") + # Resize images for model input (Yolox expects 1024x1024). + resized_images = [ + resize_image(image, (self.image_preproc_width, self.image_preproc_height)) for image in data["images"] + ] + # Chunk the resized images, the original images, and their shapes. + resized_chunks = chunk_list_geometrically(resized_images, max_batch_size) + original_chunks = chunk_list_geometrically(data["images"], max_batch_size) + shape_chunks = chunk_list_geometrically(data["original_image_shapes"], max_batch_size) + + batched_inputs = [] + formatted_batch_data = [] + for r_chunk, orig_chunk, shapes in zip(resized_chunks, original_chunks, shape_chunks): + # Reorder axes from (B, H, W, C) to (B, C, H, W) as expected by the model. + input_array = np.einsum("bijk->bkij", r_chunk).astype(np.float32) + batched_inputs.append(input_array) + formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes}) + return batched_inputs, formatted_batch_data + + elif protocol == "http": + logger.debug("Formatting input for HTTP Yolox model") + content_list: List[Dict[str, Any]] = [] + for image in data["images"]: + # Convert the numpy array to a PIL Image. Assume images are in [0,1]. 
+ image_pil = Image.fromarray((image * 255).astype(np.uint8)) + original_size = image_pil.size + + # Save the image to a buffer and encode to base64. + buffered = io.BytesIO() + image_pil.save(buffered, format="PNG") + image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8") + + # Scale the image if necessary. + scaled_image_b64, new_size = scale_image_to_encoding_size( + image_b64, max_base64_size=self.nim_max_image_size + ) + if new_size != original_size: + logger.debug(f"Image was scaled from {original_size} to {new_size}.") + + content_list.append({"type": "image_url", "url": f"data:image/png;base64,{scaled_image_b64}"}) + + # Chunk the payload content, the original images, and their shapes. + content_chunks = chunk_list(content_list, max_batch_size) + original_chunks = chunk_list(data["images"], max_batch_size) + shape_chunks = chunk_list(data["original_image_shapes"], max_batch_size) + + payload_batches = [] + formatted_batch_data = [] + for chunk, orig_chunk, shapes in zip(content_chunks, original_chunks, shape_chunks): + payload = {"input": chunk} + payload_batches.append(payload) + formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes}) + return payload_batches, formatted_batch_data + + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any: + """ + Parse the output from the model's inference response. + + Parameters + ---------- + response : Any + The response from the model inference. + protocol : str + The protocol used ("grpc" or "http"). + data : dict, optional + Additional input data passed to the function. + + Returns + ------- + Any + The parsed output data. + + Raises + ------ + ValueError + If an invalid protocol is specified or the response format is unexpected. + """ + + if protocol == "grpc": + logger.debug("Parsing output from gRPC Yolox model") + return response # For gRPC, response is already a numpy array + elif protocol == "http": + logger.debug("Parsing output from HTTP Yolox model") + + processed_outputs = [] + + batch_results = response.get("data", []) + for detections in batch_results: + new_bounding_boxes = {label: [] for label in self.class_labels} + + bounding_boxes = detections.get("bounding_boxes", []) + for obj_type, bboxes in bounding_boxes.items(): + for bbox in bboxes: + xmin = bbox["x_min"] + ymin = bbox["y_min"] + xmax = bbox["x_max"] + ymax = bbox["y_max"] + confidence = bbox["confidence"] + + new_bounding_boxes[obj_type].append([xmin, ymin, xmax, ymax, confidence]) + + processed_outputs.append(new_bounding_boxes) + + return processed_outputs + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + def process_inference_results(self, output: Any, protocol: str, **kwargs) -> List[Dict[str, Any]]: + """ + Process the results of the Yolox model inference and return the final annotations. + + Parameters + ---------- + output_array : np.ndarray + The raw output from the Yolox model. + kwargs : dict + Additional parameters for processing, including thresholds and number of classes. + + Returns + ------- + list[dict] + A list of annotation dictionaries for each image in the batch. + """ + original_image_shapes = kwargs.get("original_image_shapes", []) + + if protocol == "http": + # For http, the output already has postprocessing applied. Skip to table/chart expansion. 
+ results = output + + elif protocol == "grpc": + # For grpc, apply the same NIM postprocessing. + pred = postprocess_model_prediction( + output, self.num_classes, self.conf_threshold, self.iou_threshold, class_agnostic=True + ) + results = postprocess_results( + pred, + original_image_shapes, + self.image_preproc_width, + self.image_preproc_height, + self.class_labels, + min_score=self.min_score, + ) + + inference_results = self.postprocess_annotations(results, **kwargs) + + return inference_results + + def postprocess_annotations(self, annotation_dicts, **kwargs): + raise NotImplementedError() + + def transform_normalized_coordinates_to_original(self, results, original_image_shapes): + """ """ + transformed_results = [] + + for annotation_dict, shape in zip(results, original_image_shapes): + new_dict = {} + for label, bboxes_and_scores in annotation_dict.items(): + new_dict[label] = [] + for bbox_and_score in bboxes_and_scores: + bbox = bbox_and_score[:4] + transformed_bbox = [ + bbox[0] * shape[1], + bbox[1] * shape[0], + bbox[2] * shape[1], + bbox[3] * shape[0], + ] + transformed_bbox += bbox_and_score[4:] + new_dict[label].append(transformed_bbox) + transformed_results.append(new_dict) + + return transformed_results + + +class YoloxPageElementsModelInterface(YoloxModelInterfaceBase): + """ + An interface for handling inference with yolox-page-elements model, supporting both gRPC and HTTP protocols. + """ + + def __init__(self): + """ + Initialize the yolox-page-elements model interface. + """ + super().__init__( + image_preproc_width=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT, + image_preproc_height=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT, + nim_max_image_size=YOLOX_PAGE_NIM_MAX_IMAGE_SIZE, + num_classes=YOLOX_PAGE_NUM_CLASSES, + conf_threshold=YOLOX_PAGE_CONF_THRESHOLD, + iou_threshold=YOLOX_PAGE_IOU_THRESHOLD, + min_score=YOLOX_PAGE_MIN_SCORE, + final_score=YOLOX_PAGE_FINAL_SCORE, + class_labels=YOLOX_PAGE_CLASS_LABELS, + ) + + def name( + self, + ) -> str: + """ + Returns the name of the Yolox model interface. + + Returns + ------- + str + The name of the model interface. + """ + + return "yolox-page-elements" + + def postprocess_annotations(self, annotation_dicts, **kwargs): + original_image_shapes = kwargs.get("original_image_shapes", []) + + # Table/chart expansion is "business logic" specific to nv-ingest + annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in annotation_dicts] + annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in annotation_dicts] + inference_results = [] + + # Filter out bounding boxes below the final threshold + # This final thresholding is "business logic" specific to nv-ingest + for annotation_dict in annotation_dicts: + new_dict = {} + if "table" in annotation_dict: + new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= self.final_score] + if "chart" in annotation_dict: + new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= self.final_score] + if "title" in annotation_dict: + new_dict["title"] = annotation_dict["title"] + inference_results.append(new_dict) + + inference_results = self.transform_normalized_coordinates_to_original(inference_results, original_image_shapes) + + return inference_results + + +class YoloxGraphicElementsModelInterface(YoloxModelInterfaceBase): + """ + An interface for handling inference with yolox-graphic-elemenents model, supporting both gRPC and HTTP protocols. + """ + + def __init__(self): + """ + Initialize the yolox-graphic-elements model interface. 
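+
+        All defaults come from the yolox-graphic-elements constants defined at module level
+        (10 classes, 768x768 preprocessing, and the graphic-element score thresholds).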
+ """ + super().__init__( + image_preproc_width=YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT, + image_preproc_height=YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT, + nim_max_image_size=YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE, + num_classes=YOLOX_GRAPHIC_NUM_CLASSES, + conf_threshold=YOLOX_GRAPHIC_CONF_THRESHOLD, + iou_threshold=YOLOX_GRAPHIC_IOU_THRESHOLD, + min_score=YOLOX_GRAPHIC_MIN_SCORE, + final_score=YOLOX_GRAPHIC_FINAL_SCORE, + class_labels=YOLOX_GRAPHIC_CLASS_LABELS, + ) + + def name( + self, + ) -> str: + """ + Returns the name of the Yolox model interface. + + Returns + ------- + str + The name of the model interface. + """ + + return "yolox-graphic-elements" + + def postprocess_annotations(self, annotation_dicts, **kwargs): + original_image_shapes = kwargs.get("original_image_shapes", []) + + annotation_dicts = self.transform_normalized_coordinates_to_original(annotation_dicts, original_image_shapes) + + inference_results = [] + + # bbox extraction: additional postprocessing speicifc to nv-ingest + for pred, shape in zip(annotation_dicts, original_image_shapes): + bbox_dict = get_bbox_dict_yolox_graphic( + pred, + shape, + self.class_labels, + self.min_score, + ) + # convert numpy arrays to list + bbox_dict = { + label: array.tolist() if isinstance(array, np.ndarray) else array for label, array in bbox_dict.items() + } + inference_results.append(bbox_dict) + + return inference_results + + +def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False): + # Convert numpy array to torch tensor + prediction = torch.from_numpy(prediction.copy()) + + # Compute box corners + box_corner = prediction.new(prediction.shape) + box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2 + box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2 + box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2 + box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2 + prediction[:, :, :4] = box_corner[:, :, :4] + + output = [None for _ in range(len(prediction))] + + for i, image_pred in enumerate(prediction): + # If no detections, continue to the next image + if not image_pred.size(0): + continue + + # Ensure image_pred is 2D + if image_pred.ndim == 1: + image_pred = image_pred.unsqueeze(0) + + # Get score and class with highest confidence + class_conf, class_pred = torch.max(image_pred[:, 5 : 5 + num_classes], 1, keepdim=True) + + # Confidence mask + squeezed_conf = class_conf.squeeze(dim=1) + conf_mask = image_pred[:, 4] * squeezed_conf >= conf_thre + + # Apply confidence mask + detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1) + detections = detections[conf_mask] + + if not detections.size(0): + continue + + # Apply Non-Maximum Suppression (NMS) + if class_agnostic: + nms_out_index = torchvision.ops.nms( + detections[:, :4], + detections[:, 4] * detections[:, 5], + nms_thre, + ) + else: + nms_out_index = torchvision.ops.batched_nms( + detections[:, :4], + detections[:, 4] * detections[:, 5], + detections[:, 6], + nms_thre, + ) + detections = detections[nms_out_index] + + # Append detections to output + output[i] = detections + + return output + + +def postprocess_results( + results, original_image_shapes, image_preproc_width, image_preproc_height, class_labels, min_score=0.0 +): + """ + For each item (==image) in results, computes annotations in the form + + {"table": [[0.0107, 0.0859, 0.7537, 0.1219, 0.9861], ...], + "figure": [...], + "title": [...] 
+ } + where each list of 5 floats represents a bounding box in the format [x1, y1, x2, y2, confidence] + + Keep only bboxes with high enough confidence. + """ + out = [] + + for original_image_shape, result in zip(original_image_shapes, results): + annotation_dict = {label: [] for label in class_labels} + + if result is None: + out.append(annotation_dict) + continue + + try: + result = result.cpu().numpy() + scores = result[:, 4] * result[:, 5] + result = result[scores > min_score] + + # ratio is used when image was padded + ratio = min( + image_preproc_width / original_image_shape[0], + image_preproc_height / original_image_shape[1], + ) + bboxes = result[:, :4] / ratio + + bboxes[:, [0, 2]] /= original_image_shape[1] + bboxes[:, [1, 3]] /= original_image_shape[0] + bboxes = np.clip(bboxes, 0.0, 1.0) + + labels = result[:, 6] + scores = scores[scores > min_score] + except Exception as e: + raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}") + + for box, score, label in zip(bboxes, scores, labels): + class_name = class_labels[int(label)] + annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))]) + + out.append(annotation_dict) + + return out + + +def resize_image(image, target_img_size): + w, h, _ = np.array(image).shape + + if target_img_size is not None: # Resize + Pad + r = min(target_img_size[0] / w, target_img_size[1] / h) + image = cv2.resize( + image, + (int(h * r), int(w * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + image = np.pad( + image, + ((0, target_img_size[0] - image.shape[0]), (0, target_img_size[1] - image.shape[1]), (0, 0)), + mode="constant", + constant_values=114, + ) + + return image + + +def expand_table_bboxes(annotation_dict, labels=None): + """ + Additional preprocessing for tables: extend the upper bounds to capture titles if any. + Args: + annotation_dict: output of postprocess_results, a dictionary with keys "table", "figure", "title" + + Returns: + annotation_dict: same as input, with expanded bboxes for charts + + """ + if not labels: + labels = ["table", "chart", "title"] + + if not annotation_dict or len(annotation_dict["table"]) == 0: + return annotation_dict + + new_annotation_dict = {label: [] for label in labels} + + for label, bboxes in annotation_dict.items(): + for bbox_and_score in bboxes: + bbox, score = bbox_and_score[:4], bbox_and_score[4] + + if label == "table": + height = bbox[3] - bbox[1] + bbox[1] = max(0.0, min(1.0, bbox[1] - height * 0.2)) + + new_annotation_dict[label].append([round(float(x), 4) for x in bbox + [score]]) + + return new_annotation_dict + + +def expand_chart_bboxes(annotation_dict, labels=None): + """ + Expand bounding boxes of charts and titles based on the bounding boxes of the other class. 
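+    Chart and title boxes are first merged with weighted_boxes_fusion, each chart is then matched
+    to a nearby title, and the resulting boxes are slightly enlarged (more aggressively when no
+    title was found).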
+ Args: + annotation_dict: output of postprocess_results, a dictionary with keys "table", "figure", "title" + + Returns: + annotation_dict: same as input, with expanded bboxes for charts + + """ + if not labels: + labels = ["table", "chart", "title"] + + if not annotation_dict or len(annotation_dict["chart"]) == 0: + return annotation_dict + + bboxes = [] + confidences = [] + label_idxs = [] + for i, label in enumerate(labels): + label_annotations = np.array(annotation_dict[label]) + + if len(label_annotations) > 0: + bboxes.append(label_annotations[:, :4]) + confidences.append(label_annotations[:, 4]) + label_idxs.append(np.full(len(label_annotations), i)) + bboxes = np.concatenate(bboxes) + confidences = np.concatenate(confidences) + label_idxs = np.concatenate(label_idxs) + + pred_wbf, confidences_wbf, labels_wbf = weighted_boxes_fusion( + bboxes[:, None], + confidences[:, None], + label_idxs[:, None], + merge_type="biggest", + conf_type="max", + iou_thr=0.01, + class_agnostic=False, + ) + chart_bboxes = pred_wbf[labels_wbf == 1] + chart_confidences = confidences_wbf[labels_wbf == 1] + title_bboxes = pred_wbf[labels_wbf == 2] + + found_title_idxs, no_found_title_idxs = [], [] + for i in range(len(chart_bboxes)): + match = match_with_title(chart_bboxes[i], title_bboxes, iou_th=0.01) + if match is not None: + chart_bboxes[i] = match[0] + title_bboxes = match[1] + found_title_idxs.append(i) + else: + no_found_title_idxs.append(i) + + chart_bboxes[found_title_idxs] = expand_boxes(chart_bboxes[found_title_idxs], r_x=1.05, r_y=1.1) + chart_bboxes[no_found_title_idxs] = expand_boxes(chart_bboxes[no_found_title_idxs], r_x=1.1, r_y=1.25) + + annotation_dict = { + "table": annotation_dict["table"], + "chart": np.concatenate([chart_bboxes, chart_confidences[:, None]], axis=1).tolist(), + "title": annotation_dict["title"], + } + return annotation_dict + + +def weighted_boxes_fusion( + boxes_list, + scores_list, + labels_list, + iou_thr=0.5, + skip_box_thr=0.0, + conf_type="avg", + merge_type="weighted", + class_agnostic=False, +): + """ + Custom wbf implementation that supports a class_agnostic mode and a biggest box fusion. + Boxes are expected to be in normalized (x0, y0, x1, y1) format. + + Args: + boxes_list (list[np array[n x 4]]): List of boxes. One list per model. + scores_list (list[np array[n]]): List of confidences. + labels_list (list[np array[n]]): List of labels + iou_thr (float, optional): IoU threshold for matching. Defaults to 0.55. + skip_box_thr (float, optional): Exclude boxes with score < skip_box_thr. Defaults to 0.0. + conf_type (str, optional): Confidence merging type. Defaults to "avg". + merge_type (str, optional): Merge type "weighted" or "biggest". Defaults to "weighted". + class_agnostic (bool, optional): If True, merge boxes from different classes. Defaults to False. + + Returns: + np array[N x 4]: Merged boxes, + np array[N]: Merged confidences, + np array[N]: Merged labels. 
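+
+    Notes:
+        When class_agnostic is True, boxes from all classes are clustered together
+        (prefilter_boxes stores them under a single "*" key), so a chart box can be
+        merged with a title box.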
+ """ + weights = np.ones(len(boxes_list)) + + assert conf_type in ["avg", "max"], 'Conf type must be "avg" or "max"' + assert merge_type in [ + "weighted", + "biggest", + ], 'Conf type must be "weighted" or "biggest"' + + filtered_boxes = prefilter_boxes( + boxes_list, + scores_list, + labels_list, + weights, + skip_box_thr, + class_agnostic=class_agnostic, + ) + if len(filtered_boxes) == 0: + return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,)) + + overall_boxes = [] + for label in filtered_boxes: + boxes = filtered_boxes[label] + np.empty((0, 8)) + + clusters = [] + + # Clusterize boxes + for j in range(len(boxes)): + ids = [i for i in range(len(boxes)) if i != j] + index, best_iou = find_matching_box_fast(boxes[ids], boxes[j], iou_thr) + + if index != -1: + index = ids[index] + cluster_idx = [clust_idx for clust_idx, clust in enumerate(clusters) if (j in clust or index in clust)] + if len(cluster_idx): + cluster_idx = cluster_idx[0] + clusters[cluster_idx] = list(set(clusters[cluster_idx] + [index, j])) + else: + clusters.append([index, j]) + else: + clusters.append([j]) + + for j, c in enumerate(clusters): + if merge_type == "weighted": + weighted_box = get_weighted_box(boxes[c], conf_type) + elif merge_type == "biggest": + weighted_box = get_biggest_box(boxes[c], conf_type) + + if conf_type == "max": + weighted_box[1] = weighted_box[1] / weights.max() + else: # avg + weighted_box[1] = weighted_box[1] * len(c) / weights.sum() + overall_boxes.append(weighted_box) + + overall_boxes = np.array(overall_boxes) + overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]] + boxes = overall_boxes[:, 4:] + scores = overall_boxes[:, 1] + labels = overall_boxes[:, 0] + return boxes, scores, labels + + +def prefilter_boxes(boxes, scores, labels, weights, thr, class_agnostic=False): + """ + Reformats and filters boxes. + Output is a dict of boxes to merge separately. + + Args: + boxes (list[np array[n x 4]]): List of boxes. One list per model. + scores (list[np array[n]]): List of confidences. + labels (list[np array[n]]): List of labels. + weights (list): Model weights. + thr (float): Confidence threshold + class_agnostic (bool, optional): If True, merge boxes from different classes. Defaults to False. + + Returns: + dict[np array [? x 8]]: Filtered boxes. + """ + # Create dict with boxes stored by its label + new_boxes = dict() + + for t in range(len(boxes)): + if len(boxes[t]) != len(scores[t]): + print( + "Error. Length of boxes arrays not equal to length of scores array: {} != {}".format( + len(boxes[t]), len(scores[t]) + ) + ) + exit() + + if len(boxes[t]) != len(labels[t]): + print( + "Error. Length of boxes arrays not equal to length of labels array: {} != {}".format( + len(boxes[t]), len(labels[t]) + ) + ) + exit() + + for j in range(len(boxes[t])): + score = scores[t][j] + if score < thr: + continue + label = int(labels[t][j]) + box_part = boxes[t][j] + x1 = float(box_part[0]) + y1 = float(box_part[1]) + x2 = float(box_part[2]) + y2 = float(box_part[3]) + + # Box data checks + if x2 < x1: + warnings.warn("X2 < X1 value in box. Swap them.") + x1, x2 = x2, x1 + if y2 < y1: + warnings.warn("Y2 < Y1 value in box. Swap them.") + y1, y2 = y2, y1 + if x1 < 0: + warnings.warn("X1 < 0 in box. Set it to 0.") + x1 = 0 + if x1 > 1: + warnings.warn("X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.") + x1 = 1 + if x2 < 0: + warnings.warn("X2 < 0 in box. Set it to 0.") + x2 = 0 + if x2 > 1: + warnings.warn("X2 > 1 in box. Set it to 1. 
Check that you normalize boxes in [0, 1] range.") + x2 = 1 + if y1 < 0: + warnings.warn("Y1 < 0 in box. Set it to 0.") + y1 = 0 + if y1 > 1: + warnings.warn("Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.") + y1 = 1 + if y2 < 0: + warnings.warn("Y2 < 0 in box. Set it to 0.") + y2 = 0 + if y2 > 1: + warnings.warn("Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.") + y2 = 1 + if (x2 - x1) * (y2 - y1) == 0.0: + warnings.warn("Zero area box skipped: {}.".format(box_part)) + continue + + # [label, score, weight, model index, x1, y1, x2, y2] + b = [int(label), float(score) * weights[t], weights[t], t, x1, y1, x2, y2] + + label_k = "*" if class_agnostic else label + if label_k not in new_boxes: + new_boxes[label_k] = [] + new_boxes[label_k].append(b) + + # Sort each list in dict by score and transform it to numpy array + for k in new_boxes: + current_boxes = np.array(new_boxes[k]) + new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]] + + return new_boxes + + +def find_matching_box_fast(boxes_list, new_box, match_iou): + """ + Reimplementation of find_matching_box with numpy instead of loops. Gives significant speed up for larger arrays + (~100x). This was previously the bottleneck since the function is called for every entry in the array. + """ + + def bb_iou_array(boxes, new_box): + # bb interesection over union + xA = np.maximum(boxes[:, 0], new_box[0]) + yA = np.maximum(boxes[:, 1], new_box[1]) + xB = np.minimum(boxes[:, 2], new_box[2]) + yB = np.minimum(boxes[:, 3], new_box[3]) + + interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0) + + # compute the area of both the prediction and ground-truth rectangles + boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1]) + + iou = interArea / (boxAArea + boxBArea - interArea) + + return iou + + if boxes_list.shape[0] == 0: + return -1, match_iou + + ious = bb_iou_array(boxes_list[:, 4:], new_box[4:]) + # ious[boxes[:, 0] != new_box[0]] = -1 + + best_idx = np.argmax(ious) + best_iou = ious[best_idx] + + if best_iou <= match_iou: + best_iou = match_iou + best_idx = -1 + + return best_idx, best_iou + + +def get_biggest_box(boxes, conf_type="avg"): + """ + Merges boxes by using the biggest box. + + Args: + boxes (np array [n x 8]): Boxes to merge. + conf_type (str, optional): Confidence merging type. Defaults to "avg". + + Returns: + np array [8]: Merged box. + """ + box = np.zeros(8, dtype=np.float32) + box[4:] = boxes[0][4:] + conf_list = [] + w = 0 + for b in boxes: + box[4] = min(box[4], b[4]) + box[5] = min(box[5], b[5]) + box[6] = max(box[6], b[6]) + box[7] = max(box[7], b[7]) + conf_list.append(b[1]) + w += b[2] + + box[0] = merge_labels(np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes])) + # print(box[0], np.array([b[0] for b in boxes])) + + box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list) + box[2] = w + box[3] = -1 # model index field is retained for consistency but is not used. + return box + + +def merge_labels(labels, confs): + """ + Custom function for merging labels. + If all labels are the same, return the unique value. + Else, return the label of the most confident non-title (class 2) box. + + Args: + labels (np array [n]): Labels. + confs (np array [n]): Confidence. + + Returns: + int: Label. 
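+
+    Example:
+        labels=[0, 1, 1] and confs=[0.2, 0.7, 0.4] returns 1, the label of the most
+        confident box.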
+ """ + if len(np.unique(labels)) == 1: + return labels[0] + else: # Most confident and not a title + confs = confs[confs != 2] + labels = labels[labels != 2] + return labels[np.argmax(confs)] + + +def match_with_title(chart_bbox, title_bboxes, iou_th=0.01): + if not len(title_bboxes): + return None + + dist_above = np.abs(title_bboxes[:, 3] - chart_bbox[1]) + dist_below = np.abs(chart_bbox[3] - title_bboxes[:, 1]) + + dist_left = np.abs(title_bboxes[:, 0] - chart_bbox[0]) + + ious = bb_iou_array(title_bboxes, chart_bbox) + + matches = None + if np.max(ious) > iou_th: + matches = np.where(ious > iou_th)[0] + else: + dists = np.min([dist_above, dist_below], 0) + dists += dist_left + # print(dists) + if np.min(dists) < 0.1: + matches = [np.argmin(dists)] + + if matches is not None: + new_bbox = chart_bbox + for match in matches: + new_bbox = merge_boxes(new_bbox, title_bboxes[match]) + title_bboxes = title_bboxes[[i for i in range(len(title_bboxes)) if i not in matches]] + return new_bbox, title_bboxes + + else: + return None + + +def bb_iou_array(boxes, new_box): + # bb interesection over union + xA = np.maximum(boxes[:, 0], new_box[0]) + yA = np.maximum(boxes[:, 1], new_box[1]) + xB = np.minimum(boxes[:, 2], new_box[2]) + yB = np.minimum(boxes[:, 3], new_box[3]) + + interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0) + + # compute the area of both the prediction and ground-truth rectangles + boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) + boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1]) + + iou = interArea / (boxAArea + boxBArea - interArea) + + return iou + + +def merge_boxes(b1, b2): + b = b1.copy() + b[0] = min(b1[0], b2[0]) + b[1] = min(b1[1], b2[1]) + b[2] = max(b1[2], b2[2]) + b[3] = max(b1[3], b2[3]) + return b + + +def expand_boxes(boxes, r_x=1, r_y=1): + dw = (boxes[:, 2] - boxes[:, 0]) / 2 * (r_x - 1) + boxes[:, 0] -= dw + boxes[:, 2] += dw + + dh = (boxes[:, 3] - boxes[:, 1]) / 2 * (r_y - 1) + boxes[:, 1] -= dh + boxes[:, 3] += dh + + boxes = np.clip(boxes, 0, 1) + return boxes + + +def get_weighted_box(boxes, conf_type="avg"): + """ + Merges boxes by using the weighted fusion. + + Args: + boxes (np array [n x 8]): Boxes to merge. + conf_type (str, optional): Confidence merging type. Defaults to "avg". + + Returns: + np array [8]: Merged box. + """ + box = np.zeros(8, dtype=np.float32) + conf = 0 + conf_list = [] + w = 0 + for b in boxes: + box[4:] += b[1] * b[4:] + conf += b[1] + conf_list.append(b[1]) + w += b[2] + + box[0] = merge_labels(np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes])) + + box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list) + box[2] = w + box[3] = -1 # model index field is retained for consistency but is not used. + box[4:] /= conf + return box + + +def batched_overlaps(A, B): + """ + Calculate the Intersection over Union (IoU) between + two sets of bounding boxes in a batched manner. + Normalization is modified to only use the area of A boxes, hence computing the overlaps. + Args: + A (ndarray): Array of bounding boxes of shape (N, 4) in format [x1, y1, x2, y2]. + B (ndarray): Array of bounding boxes of shape (M, 4) in format [x1, y1, x2, y2]. + Returns: + ndarray: Array of IoU values of shape (N, M) representing the overlaps + between each pair of bounding boxes. 
+ """ + A = A.copy() + B = B.copy() + + A = A[None].repeat(B.shape[0], 0) + B = B[:, None].repeat(A.shape[1], 1) + + low = np.s_[..., :2] + high = np.s_[..., 2:] + + A, B = A.copy(), B.copy() + A[high] += 1 + B[high] += 1 + + intrs = (np.maximum(0, np.minimum(A[high], B[high]) - np.maximum(A[low], B[low]))).prod(-1) + ious = intrs / (A[high] - A[low]).prod(-1) + + return ious + + +def find_boxes_inside(boxes, boxes_to_check, threshold=0.9): + """ + Find all boxes that are inside another box based on + the intersection area divided by the area of the smaller box, + and removes them. + """ + overlaps = batched_overlaps(boxes_to_check, boxes) + to_keep = (overlaps >= threshold).sum(0) <= 1 + return boxes_to_check[to_keep] + + +def get_bbox_dict_yolox_graphic(preds, shape, class_labels, threshold_=0.1) -> Dict[str, np.ndarray]: + """ + Extracts bounding boxes from YOLOX model predictions: + - Applies thresholding + - Reformats boxes + - Cleans the `other` detections: removes the ones that are included in other detections. + - If no title is found, the biggest `other` box is used if it is larger than 0.3*img_w. + Args: + preds (np.ndarray): YOLOX model predictions including bounding boxes, scores, and labels. + shape (tuple): Original image shape. + threshold_ (float): Score threshold to filter bounding boxes. + Returns: + Dict[str, np.ndarray]: Dictionary of bounding boxes, organized by class. + """ + bbox_dict = {label: np.array([]) for label in class_labels} + + for i, label in enumerate(class_labels): + bboxes_class = np.array(preds[label]) + + if bboxes_class.size == 0: + continue + + # Try to find a chart_title box + threshold = threshold_ if label != "chart_title" else min(threshold_, bboxes_class[:, -1].max()) + bboxes_class = bboxes_class[bboxes_class[:, -1] >= threshold][:, :4].astype(int) + + sort = ["x0", "y0"] if label != "ylabel" else ["y0", "x0"] + idxs = ( + pd.DataFrame( + { + "y0": bboxes_class[:, 1], + "x0": bboxes_class[:, 0], + } + ) + .sort_values(sort, ascending=label != "ylabel") + .index + ) + bboxes_class = bboxes_class[idxs] + bbox_dict[label] = bboxes_class + + # Remove other included + if len(bbox_dict.get("other", [])): + other = find_boxes_inside( + np.concatenate(list([v for v in bbox_dict.values() if len(v)])), bbox_dict["other"], threshold=0.7 + ) + del bbox_dict["other"] + if len(other): + bbox_dict["other"] = other + + # Biggest other is title if no title + if not len(bbox_dict.get("chart_title", [])) and len(bbox_dict.get("other", [])): + boxes = bbox_dict["other"] + ws = boxes[:, 2] - boxes[:, 0] + if np.max(ws) > shape[1] * 0.3: + bbox_dict["chart_title"] = boxes[np.argmax(ws)][None].copy() + bbox_dict["other"] = np.delete(boxes, (np.argmax(ws)), axis=0) + + # Make sure other key not lost + bbox_dict["other"] = bbox_dict.get("other", []) + + return bbox_dict