diff --git a/src/nv_ingest/util/nim/__init__.py b/src/nv_ingest/util/nim/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/nv_ingest/util/nim/cached.py b/src/nv_ingest/util/nim/cached.py
new file mode 100644
index 00000000..299f249c
--- /dev/null
+++ b/src/nv_ingest/util/nim/cached.py
@@ -0,0 +1,274 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import base64
+import io
+import logging
+import PIL.Image as Image
+from typing import Any, Dict, Optional, List
+
+import numpy as np
+
+from nv_ingest.util.image_processing.transforms import base64_to_numpy
+from nv_ingest.util.nim.helpers import ModelInterface
+
+logger = logging.getLogger(__name__)
+
+
+class CachedModelInterface(ModelInterface):
+ """
+ An interface for handling inference with a Cached model, supporting both gRPC and HTTP
+ protocols, including batched input.
+ """
+
+ def name(self) -> str:
+ """
+ Get the name of the model interface.
+
+ Returns
+ -------
+ str
+ The name of the model interface ("Cached").
+ """
+ return "Cached"
+
+ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Decode base64-encoded images into NumPy arrays, storing them in `data["image_arrays"]`.
+
+ Parameters
+ ----------
+ data : dict of str -> Any
+ The input data containing either:
+ - "base64_image": a single base64-encoded image, or
+ - "base64_images": a list of base64-encoded images.
+
+ Returns
+ -------
+ dict of str -> Any
+ The updated data dictionary with decoded image arrays stored in
+ "image_arrays", where each array has shape (H, W, C).
+
+ Raises
+ ------
+ KeyError
+ If neither 'base64_image' nor 'base64_images' is provided.
+ ValueError
+ If 'base64_images' is provided but is not a list.
+ """
+ if "base64_images" in data:
+ base64_list = data["base64_images"]
+ if not isinstance(base64_list, list):
+ raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
+ data["image_arrays"] = [base64_to_numpy(img) for img in base64_list]
+
+ elif "base64_image" in data:
+ # Fallback to single image case; wrap it in a list to keep the interface consistent
+ data["image_arrays"] = [base64_to_numpy(data["base64_image"])]
+
+ else:
+ raise KeyError("Input data must include 'base64_image' or 'base64_images' with base64-encoded images.")
+
+ return data
+
+ def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
+ """
+ Format input data for the specified protocol ("grpc" or "http"), handling batched images.
+ Additionally, returns batched data that coalesces the original image arrays and their dimensions
+ in the same order as provided.
+
+ Parameters
+ ----------
+ data : dict of str -> Any
+ The input data dictionary, expected to contain "image_arrays" (a list of np.ndarray).
+ protocol : str
+ The protocol to use, "grpc" or "http".
+ max_batch_size : int
+ The maximum number of images per batch.
+
+ Returns
+ -------
+ tuple
+ A tuple (formatted_batches, formatted_batch_data) where:
+ - For gRPC: formatted_batches is a list of NumPy arrays, each of shape (B, H, W, C)
+ with B <= max_batch_size.
+ - For HTTP: formatted_batches is a list of JSON-serializable dict payloads.
+ - In both cases, formatted_batch_data is a list of dicts with the keys:
+ "image_arrays": the list of original np.ndarray images for that batch, and
+ "image_dims": a list of (height, width) tuples for each image in the batch.
+
+ Raises
+ ------
+ KeyError
+ If "image_arrays" is missing in the data dictionary.
+ ValueError
+ If the protocol is invalid, or if no valid images are found.
+ """
+ if "image_arrays" not in data:
+ raise KeyError("Expected 'image_arrays' in data. Make sure prepare_data_for_inference was called.")
+
+ image_arrays = data["image_arrays"]
+ # Compute dimensions for each image.
+ image_dims = [(img.shape[0], img.shape[1]) for img in image_arrays]
+
+ # Helper: chunk a list into sublists of length up to chunk_size.
+ def chunk_list(lst: list, chunk_size: int) -> List[list]:
+ return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+ if protocol == "grpc":
+ logger.debug("Formatting input for gRPC Cached model (batched).")
+ batched_images = []
+ for arr in image_arrays:
+ # Expand from (H, W, C) to (1, H, W, C) if needed
+ if arr.ndim == 3:
+ arr = np.expand_dims(arr, axis=0)
+ batched_images.append(arr.astype(np.float32))
+
+ if not batched_images:
+ raise ValueError("No valid images found for gRPC formatting.")
+
+ # Chunk the processed images, original arrays, and dimensions.
+ batched_image_chunks = chunk_list(batched_images, max_batch_size)
+ orig_chunks = chunk_list(image_arrays, max_batch_size)
+ dims_chunks = chunk_list(image_dims, max_batch_size)
+
+ batched_inputs = []
+ formatted_batch_data = []
+ for proc_chunk, orig_chunk, dims_chunk in zip(batched_image_chunks, orig_chunks, dims_chunks):
+ # Concatenate along the batch dimension => shape (B, H, W, C)
+ batched_input = np.concatenate(proc_chunk, axis=0)
+ batched_inputs.append(batched_input)
+ formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
+ return batched_inputs, formatted_batch_data
+
+ elif protocol == "http":
+ logger.debug("Formatting input for HTTP Cached model (batched).")
+ content_list: List[Dict[str, Any]] = []
+ for arr in image_arrays:
+ # Convert to uint8 if needed, then to PIL Image and base64-encode it.
+ if arr.dtype != np.uint8:
+ arr = (arr * 255).astype(np.uint8)
+ image_pil = Image.fromarray(arr)
+ buffered = io.BytesIO()
+ image_pil.save(buffered, format="PNG")
+ base64_img = base64.b64encode(buffered.getvalue()).decode("utf-8")
+ image_item = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_img}"}}
+ content_list.append(image_item)
+
+ # Chunk the content list, original arrays, and dimensions.
+ content_chunks = chunk_list(content_list, max_batch_size)
+ orig_chunks = chunk_list(image_arrays, max_batch_size)
+ dims_chunks = chunk_list(image_dims, max_batch_size)
+
+ payload_batches = []
+ formatted_batch_data = []
+ for chunk, orig_chunk, dims_chunk in zip(content_chunks, orig_chunks, dims_chunks):
+ message = {"content": chunk}
+ payload = {"messages": [message]}
+ payload_batches.append(payload)
+ formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
+ return payload_batches, formatted_batch_data
+
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
+ """
+ Parse the output from the Cached model's inference response.
+
+ Parameters
+ ----------
+ response : Any
+ The raw response from the model inference.
+ protocol : str
+ The protocol used ("grpc" or "http").
+ data : dict of str -> Any, optional
+ Additional input data (unused here, but available for consistency).
+ **kwargs : Any
+ Additional keyword arguments for future compatibility.
+
+ Returns
+ -------
+ Any
+ The parsed output data (e.g., list of strings), depending on the protocol.
+
+ Raises
+ ------
+ ValueError
+ If the protocol is invalid.
+ RuntimeError
+ If the HTTP response is not as expected (missing 'data' key).
+ """
+ if protocol == "grpc":
+ logger.debug("Parsing output from gRPC Cached model (batched).")
+ parsed: List[str] = []
+ # Assume `response` is iterable, each element a list/array of byte strings
+ for single_output in response:
+ joined_str = " ".join(o.decode("utf-8") for o in single_output)
+ parsed.append(joined_str)
+ return parsed
+
+ elif protocol == "http":
+ logger.debug("Parsing output from HTTP Cached model (batched).")
+ if not isinstance(response, dict):
+ raise RuntimeError("Expected JSON/dict response for HTTP, got something else.")
+ if "data" not in response or not response["data"]:
+ raise RuntimeError("Unexpected response format: 'data' key missing or empty.")
+
+ contents: List[str] = []
+ for item in response["data"]:
+ # Each "item" might have a "content" key
+ content = item.get("content", "")
+ contents.append(content)
+
+ return contents
+
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+ def process_inference_results(self, output: Any, protocol: str, **kwargs: Any) -> Any:
+ """
+ Process inference results for the Cached model.
+
+ Parameters
+ ----------
+ output : Any
+ The raw output from the model.
+ protocol : str
+ The inference protocol used ("grpc" or "http").
+ **kwargs : Any
+ Additional parameters for post-processing (not used here).
+
+ Returns
+ -------
+ Any
+ The processed inference results, which here is simply returned as-is.
+ """
+ # For Cached model, we simply return what we parsed (e.g., a list of strings or a single string)
+ return output
+
+ def _extract_content_from_nim_response(self, json_response: Dict[str, Any]) -> Any:
+ """
+ Extract content from the JSON response of a NIM (HTTP) API request.
+
+ Parameters
+ ----------
+ json_response : dict of str -> Any
+ The JSON response from the NIM API.
+
+ Returns
+ -------
+ Any
+ The extracted content from the response.
+
+ Raises
+ ------
+ RuntimeError
+ If the response format is unexpected (missing 'data' or empty).
+ """
+ if "data" not in json_response or not json_response["data"]:
+ raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
+
+        return json_response["data"][0]["content"]
+
+
+# Illustrative end-to-end sketch (not part of this module). `b64_png` is assumed to be a
+# base64-encoded PNG string; the flow mirrors how NimClient drives this interface:
+#
+#     iface = CachedModelInterface()
+#     data = iface.prepare_data_for_inference({"base64_images": [b64_png]})
+#     payloads, batch_data = iface.format_input(data, protocol="http", max_batch_size=2)
+#     # payloads[0]   -> {"messages": [{"content": [{"type": "image_url", "image_url": {...}}]}]}
+#     # batch_data[0] -> {"image_arrays": [<np.ndarray>], "image_dims": [(H, W)]}
diff --git a/src/nv_ingest/util/nim/decorators.py b/src/nv_ingest/util/nim/decorators.py
new file mode 100644
index 00000000..869c9e54
--- /dev/null
+++ b/src/nv_ingest/util/nim/decorators.py
@@ -0,0 +1,52 @@
+import logging
+from functools import wraps
+from multiprocessing import Lock
+from multiprocessing import Manager
+
+logger = logging.getLogger(__name__)
+
+# Create a shared manager and lock for thread-safe access
+manager = Manager()
+global_cache = manager.dict()
+lock = Lock()
+
+
+def multiprocessing_cache(max_calls):
+ """
+ A decorator that creates a global cache shared between multiple processes.
+ The cache is invalidated after `max_calls` number of accesses.
+
+ Args:
+ max_calls (int): The number of calls after which the cache is cleared.
+
+ Returns:
+ function: The decorated function with global cache and invalidation logic.
+ """
+
+ def decorator(func):
+ call_count = manager.Value("i", 0) # Shared integer for call counting
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ key = (func.__name__, args, frozenset(kwargs.items()))
+
+ with lock:
+ call_count.value += 1
+
+ if call_count.value > max_calls:
+ global_cache.clear()
+ call_count.value = 0
+
+ if key in global_cache:
+ return global_cache[key]
+
+ result = func(*args, **kwargs)
+
+ with lock:
+ global_cache[key] = result
+
+ return result
+
+ return wrapper
+
+ return decorator
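+
+
+# Illustrative usage sketch (not part of this module); `fetch_model_metadata` and its
+# endpoint argument are hypothetical:
+#
+#     @multiprocessing_cache(max_calls=100)
+#     def fetch_model_metadata(endpoint: str) -> dict:
+#         return {"endpoint": endpoint}  # stand-in for an expensive remote call
+#
+#     fetch_model_metadata("http://paddle:8000")  # computed once, then served from the
+#     fetch_model_metadata("http://paddle:8000")  # shared cache until 100 calls accumulate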
diff --git a/src/nv_ingest/util/nim/deplot.py b/src/nv_ingest/util/nim/deplot.py
new file mode 100644
index 00000000..3a8414d7
--- /dev/null
+++ b/src/nv_ingest/util/nim/deplot.py
@@ -0,0 +1,270 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Dict, Any, Optional, List
+
+import numpy as np
+import logging
+
+from nv_ingest.util.image_processing.transforms import base64_to_numpy
+from nv_ingest.util.nim.helpers import ModelInterface
+
+logger = logging.getLogger(__name__)
+
+
+class DeplotModelInterface(ModelInterface):
+ """
+ An interface for handling inference with a Deplot model, supporting both gRPC and HTTP protocols,
+ now updated to handle multiple base64 images ('base64_images').
+ """
+
+ def name(self) -> str:
+ """
+ Get the name of the model interface.
+
+ Returns
+ -------
+ str
+ The name of the model interface ("Deplot").
+ """
+ return "Deplot"
+
+ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Prepare input data by decoding one or more base64-encoded images into NumPy arrays.
+
+ Parameters
+ ----------
+ data : dict
+ The input data containing either 'base64_image' (single image)
+ or 'base64_images' (multiple images).
+
+ Returns
+ -------
+ dict
+ The updated data dictionary with 'image_arrays': a list of decoded NumPy arrays.
+ """
+
+ # Handle a single base64_image or multiple base64_images
+ if "base64_images" in data:
+ base64_list = data["base64_images"]
+ if not isinstance(base64_list, list):
+ raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
+ image_arrays = [base64_to_numpy(b64) for b64 in base64_list]
+
+ elif "base64_image" in data:
+ # Fallback for single image
+ image_arrays = [base64_to_numpy(data["base64_image"])]
+ else:
+ raise KeyError("Input data must include 'base64_image' or 'base64_images'.")
+
+ data["image_arrays"] = image_arrays
+
+ return data
+
+ def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
+ """
+ Format input data for the specified protocol (gRPC or HTTP) for Deplot.
+ For HTTP, we now construct multiple messages—one per image batch—along with
+ corresponding batch data carrying the original image arrays and their dimensions.
+
+ Parameters
+ ----------
+ data : dict of str -> Any
+ The input data dictionary, expected to contain "image_arrays" (a list of np.ndarray).
+ protocol : str
+ The protocol to use, "grpc" or "http".
+ max_batch_size : int
+ The maximum number of images per batch.
+ kwargs : dict
+ Additional parameters to pass to the payload preparation (for HTTP).
+
+ Returns
+ -------
+ tuple
+ (formatted_batches, formatted_batch_data) where:
+ - For gRPC: formatted_batches is a list of NumPy arrays, each of shape (B, H, W, C)
+ with B <= max_batch_size.
+ - For HTTP: formatted_batches is a list of JSON-serializable payload dicts.
+ - In both cases, formatted_batch_data is a list of dicts containing:
+ "image_arrays": the list of original np.ndarray images for that batch, and
+ "image_dims": a list of (height, width) tuples for each image in the batch.
+
+ Raises
+ ------
+ KeyError
+ If "image_arrays" is missing in the data dictionary.
+ ValueError
+ If the protocol is invalid, or if no valid images are found.
+ """
+ if "image_arrays" not in data:
+ raise KeyError("Expected 'image_arrays' in data. Call prepare_data_for_inference first.")
+
+ image_arrays = data["image_arrays"]
+ # Compute image dimensions from each image array.
+ image_dims = [(img.shape[0], img.shape[1]) for img in image_arrays]
+
+ # Helper function: chunk a list into sublists of length <= chunk_size.
+ def chunk_list(lst: list, chunk_size: int) -> List[list]:
+ return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+ if protocol == "grpc":
+ logger.debug("Formatting input for gRPC Deplot model (potentially batched).")
+ processed = []
+ for arr in image_arrays:
+ # Ensure each image has shape (1, H, W, C)
+ if arr.ndim == 3:
+ arr = np.expand_dims(arr, axis=0)
+ arr = arr.astype(np.float32)
+ arr /= 255.0 # Normalize to [0,1]
+ processed.append(arr)
+
+ if not processed:
+ raise ValueError("No valid images found for gRPC formatting.")
+
+ formatted_batches = []
+ formatted_batch_data = []
+ proc_chunks = chunk_list(processed, max_batch_size)
+ orig_chunks = chunk_list(image_arrays, max_batch_size)
+ dims_chunks = chunk_list(image_dims, max_batch_size)
+
+ for proc_chunk, orig_chunk, dims_chunk in zip(proc_chunks, orig_chunks, dims_chunks):
+ # Concatenate along the batch dimension to form a single input.
+ batched_input = np.concatenate(proc_chunk, axis=0)
+ formatted_batches.append(batched_input)
+ formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
+ return formatted_batches, formatted_batch_data
+
+ elif protocol == "http":
+ logger.debug("Formatting input for HTTP Deplot model (multiple messages).")
+ if "base64_images" in data:
+ base64_list = data["base64_images"]
+ else:
+ base64_list = [data["base64_image"]]
+
+ formatted_batches = []
+ formatted_batch_data = []
+ b64_chunks = chunk_list(base64_list, max_batch_size)
+ orig_chunks = chunk_list(image_arrays, max_batch_size)
+ dims_chunks = chunk_list(image_dims, max_batch_size)
+
+ for b64_chunk, orig_chunk, dims_chunk in zip(b64_chunks, orig_chunks, dims_chunks):
+ payload = self._prepare_deplot_payload(
+ base64_list=b64_chunk,
+ max_tokens=kwargs.get("max_tokens", 500),
+ temperature=kwargs.get("temperature", 0.5),
+ top_p=kwargs.get("top_p", 0.9),
+ )
+ formatted_batches.append(payload)
+ formatted_batch_data.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
+ return formatted_batches, formatted_batch_data
+
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+ """
+ Parse the model's inference response.
+ """
+ if protocol == "grpc":
+ logger.debug("Parsing output from gRPC Deplot model (batched).")
+ # Each batch element might be returned as a list of bytes. Combine or keep separate as needed.
+ results = []
+ for item in response:
+ # If item is [b'...'], decode and join
+ if isinstance(item, list):
+ joined_str = " ".join(o.decode("utf-8") for o in item)
+ results.append(joined_str)
+ else:
+ # single bytes or str
+ val = item.decode("utf-8") if isinstance(item, bytes) else str(item)
+ results.append(val)
+ return results # Return a list of strings, one per image.
+
+ elif protocol == "http":
+ logger.debug("Parsing output from HTTP Deplot model.")
+ return self._extract_content_from_deplot_response(response)
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
+ """
+ Process inference results for the Deplot model.
+
+ Parameters
+ ----------
+ output : Any
+ The raw output from the model.
+ protocol : str
+ The protocol used for inference (gRPC or HTTP).
+
+ Returns
+ -------
+ Any
+ The processed inference results.
+ """
+
+ # For Deplot, the output is the chart content as a string
+ return output
+
+ @staticmethod
+ def _prepare_deplot_payload(
+ base64_list: list,
+ max_tokens: int = 500,
+ temperature: float = 0.5,
+ top_p: float = 0.9,
+ ) -> Dict[str, Any]:
+ """
+ Prepare an HTTP payload for Deplot that includes one message per image,
+ matching the original single-image style:
+
+ messages = [
+ {
+ "role": "user",
+ "content": "Generate ...
"
+ },
+ {
+ "role": "user",
+ "content": "Generate ...
"
+ },
+ ...
+ ]
+
+ If your backend expects multiple messages in a single request, this keeps
+ the same structure as the single-image code repeated N times.
+ """
+ messages = []
+ # Note: deplot NIM currently only supports a single message per request
+ for b64_img in base64_list:
+ messages.append(
+ {
+ "role": "user",
+ "content": (
+ "Generate the underlying data table of the figure below: "
+                        f'<img src="data:image/png;base64,{b64_img}" />'
+ ),
+ }
+ )
+
+ payload = {
+ "model": "google/deplot",
+ "messages": messages, # multiple user messages now
+ "max_tokens": max_tokens,
+ "stream": False,
+ "temperature": temperature,
+ "top_p": top_p,
+ }
+ return payload
+
+ @staticmethod
+ def _extract_content_from_deplot_response(json_response: Dict[str, Any]) -> Any:
+ """
+ Extract content from the JSON response of a Deplot HTTP API request.
+ The original code expected a single choice with a single textual content.
+ """
+ if "choices" not in json_response or not json_response["choices"]:
+ raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")
+
+ # If the service only returns one textual result, we return that one.
+ return json_response["choices"][0]["message"]["content"]
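+
+
+# Illustrative sketch (not part of this module) of the HTTP payload built for one image;
+# `b64_png` is assumed to be a base64-encoded PNG string, and the prompt string follows
+# the reconstruction shown in _prepare_deplot_payload above:
+#
+#     payload = DeplotModelInterface._prepare_deplot_payload([b64_png], max_tokens=500)
+#     # payload["model"] == "google/deplot"
+#     # payload["messages"][0]["content"].startswith(
+#     #     "Generate the underlying data table of the figure below:")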
diff --git a/src/nv_ingest/util/nim/doughnut.py b/src/nv_ingest/util/nim/doughnut.py
new file mode 100644
index 00000000..84cafe3b
--- /dev/null
+++ b/src/nv_ingest/util/nim/doughnut.py
@@ -0,0 +1,165 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import re
+from typing import List
+from typing import Tuple
+
+ACCEPTED_TEXT_CLASSES = set(
+ [
+ "Text",
+ "Title",
+ "Section-header",
+ "List-item",
+ "TOC",
+ "Bibliography",
+ "Formula",
+ "Page-header",
+ "Page-footer",
+ "Caption",
+ "Footnote",
+ "Floating-text",
+ ]
+)
+ACCEPTED_TABLE_CLASSES = set(
+ [
+ "Table",
+ ]
+)
+ACCEPTED_IMAGE_CLASSES = set(
+ [
+ "Picture",
+ ]
+)
+ACCEPTED_CLASSES = ACCEPTED_TEXT_CLASSES | ACCEPTED_TABLE_CLASSES | ACCEPTED_IMAGE_CLASSES
+
+# Doughnut detections are expected in the form "<x_{x1}><y_{y1}>{text}<x_{x2}><y_{y2}><class_{label}>";
+# capture the two corner coordinates, the enclosed text, and the class label.
+_re_extract_class_bbox = re.compile(
+    r"<x_(\d+)><y_(\d+)>(.*?)<x_(\d+)><y_(\d+)><class_([^>]+)>",
+    re.MULTILINE | re.DOTALL,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def extract_classes_bboxes(text: str) -> Tuple[List[str], List[Tuple[int, int, int, int]], List[str]]:
+ classes: List[str] = []
+ bboxes: List[Tuple[int, int, int, int]] = []
+ texts: List[str] = []
+
+ last_end = 0
+
+ for m in _re_extract_class_bbox.finditer(text):
+ start, end = m.span()
+
+ # [Bad box] Add the non-match chunk (text between the last match and the current match)
+ if start > last_end:
+ bad_text = text[last_end:start].strip()
+ classes.append("Bad-box")
+ bboxes.append((0, 0, 0, 0))
+ texts.append(bad_text)
+
+ last_end = end
+
+        # Use a distinct name for the matched text so the input `text` is not shadowed.
+        x1, y1, txt, x2, y2, cls = m.groups()
+
+ bbox = tuple(map(int, (x1, y1, x2, y2)))
+
+ # [Bad box] check if the class is a valid class.
+ if cls not in ACCEPTED_CLASSES:
+ logger.debug(f"Dropped a bad box: invalid class {cls} at {bbox}.")
+ classes.append("Bad-box")
+ bboxes.append(bbox)
+            texts.append(txt)
+ continue
+
+ # Drop bad box: drop if the box is invalid.
+ if (bbox[0] >= bbox[2]) or (bbox[1] >= bbox[3]):
+ logger.debug(f"Dropped a bad box: invalid box {cls} at {bbox}.")
+ classes.append("Bad-box")
+ bboxes.append(bbox)
+            texts.append(txt)
+ continue
+
+ classes.append(cls)
+ bboxes.append(bbox)
+        texts.append(txt)
+
+ if last_end < len(text):
+ bad_text = text[last_end:].strip()
+ if len(bad_text) > 0:
+ classes.append("Bad-box")
+ bboxes.append((0, 0, 0, 0))
+ texts.append(bad_text)
+
+ return classes, bboxes, texts
+
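+# Illustrative sketch (not part of this module), assuming the <x_..><y_..>text<x_..><y_..><class_..>
+# markup targeted by the regex above:
+#
+#     extract_classes_bboxes("<x_10><y_20>Hello world<x_110><y_40><class_Text>")
+#     # -> (["Text"], [(10, 20, 110, 40)], ["Hello world"])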
+
+def _fix_dots(m):
+ # Remove spaces between dots.
+ s = m.group(0)
+ return s.startswith(" ") * " " + min(5, s.count(".")) * "." + s.endswith(" ") * " "
+
+
+def strip_markdown_formatting(text):
+ # Remove headers (e.g., # Header, ## Header, ### Header)
+ text = re.sub(r"^(#+)\s*(.*)", r"\2", text, flags=re.MULTILINE)
+
+ # Remove bold formatting (e.g., **bold text** or __bold text__)
+ text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)
+ text = re.sub(r"__(.*?)__", r"\1", text)
+
+ # Remove italic formatting (e.g., *italic text* or _italic text_)
+ text = re.sub(r"\*(.*?)\*", r"\1", text)
+ text = re.sub(r"_(.*?)_", r"\1", text)
+
+ # Remove strikethrough formatting (e.g., ~~strikethrough~~)
+ text = re.sub(r"~~(.*?)~~", r"\1", text)
+
+ # Remove list items (e.g., - item, * item, 1. item)
+ text = re.sub(r"^\s*([-*+]|[0-9]+\.)\s+", "", text, flags=re.MULTILINE)
+
+ # Remove hyperlinks (e.g., [link text](http://example.com))
+ text = re.sub(r"\[(.*?)\]\(.*?\)", r"\1", text)
+
+ # Remove inline code (e.g., `code`)
+ text = re.sub(r"`(.*?)`", r"\1", text)
+
+ # Remove blockquotes (e.g., > quote)
+ text = re.sub(r"^\s*>\s*(.*)", r"\1", text, flags=re.MULTILINE)
+
+ # Remove multiple newlines
+ text = re.sub(r"\n{3,}", "\n\n", text)
+
+ # Limit dots sequences to max 5 dots
+ text = re.sub(r"(?:\s*\.\s*){3,}", _fix_dots, text, flags=re.DOTALL)
+
+ return text
+
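+# Illustrative sketch (not part of this module):
+#
+#     strip_markdown_formatting("# Title\n**bold** and [link](http://example.com)")
+#     # -> "Title\nbold and link"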
+
+def reverse_transform_bbox(
+ bbox: Tuple[int, int, int, int],
+ bbox_offset: Tuple[int, int],
+ original_width: int,
+ original_height: int,
+) -> Tuple[int, int, int, int]:
+ width_ratio = (original_width - 2 * bbox_offset[0]) / original_width
+ height_ratio = (original_height - 2 * bbox_offset[1]) / original_height
+ w1, h1, w2, h2 = bbox
+ w1 = int((w1 - bbox_offset[0]) / width_ratio)
+ h1 = int((h1 - bbox_offset[1]) / height_ratio)
+ w2 = int((w2 - bbox_offset[0]) / width_ratio)
+ h2 = int((h2 - bbox_offset[1]) / height_ratio)
+
+ return (w1, h1, w2, h2)
+
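+# Illustrative sketch (not part of this module): a box covering the central 100x100 content
+# region of a 200x200 canvas (symmetric 50-pixel offset) maps back to the full original image:
+#
+#     reverse_transform_bbox((50, 50, 150, 150), (50, 50), 200, 200)
+#     # -> (0, 0, 200, 200)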
+
+def postprocess_text(txt: str, cls: str):
+ if cls in ACCEPTED_CLASSES:
+        txt = txt.replace("<tbc>", "").strip()  # remove <tbc> tokens (continued paragraphs)
+ txt = strip_markdown_formatting(txt)
+ else:
+ txt = ""
+
+ return txt
diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py
new file mode 100644
index 00000000..22505b08
--- /dev/null
+++ b/src/nv_ingest/util/nim/helpers.py
@@ -0,0 +1,708 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import re
+import threading
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any
+from typing import Optional
+from typing import Tuple
+
+import backoff
+import cv2
+import numpy as np
+import requests
+import tritonclient.grpc as grpcclient
+
+from nv_ingest.util.image_processing.transforms import normalize_image
+from nv_ingest.util.image_processing.transforms import pad_image
+from nv_ingest.util.nim.decorators import multiprocessing_cache
+from nv_ingest.util.tracing.tagging import traceable_func
+
+logger = logging.getLogger(__name__)
+
+DEPLOT_MAX_TOKENS = 128
+DEPLOT_TEMPERATURE = 1.0
+DEPLOT_TOP_P = 1.0
+
+
+class ModelInterface:
+ """
+ Base class for defining a model interface that supports preparing input data, formatting it for
+ inference, parsing output, and processing inference results.
+ """
+
+ def format_input(self, data: dict, protocol: str, max_batch_size: int):
+ """
+ Format the input data for the specified protocol.
+
+ Parameters
+ ----------
+ data : dict
+ The input data to format.
+ protocol : str
+            The protocol to format the data for ("grpc" or "http").
+        max_batch_size : int
+            The maximum number of items to include in a single formatted batch.
+        """
+
+ raise NotImplementedError("Subclasses should implement this method")
+
+ def parse_output(self, response, protocol: str, data: Optional[dict] = None, **kwargs):
+ """
+ Parse the output data from the model's inference response.
+
+ Parameters
+ ----------
+ response : Any
+ The response from the model inference.
+ protocol : str
+ The protocol used ("grpc" or "http").
+ data : dict, optional
+ Additional input data passed to the function.
+ """
+
+ raise NotImplementedError("Subclasses should implement this method")
+
+ def prepare_data_for_inference(self, data: dict):
+ """
+ Prepare input data for inference by processing or transforming it as required.
+
+ Parameters
+ ----------
+ data : dict
+ The input data to prepare.
+ """
+ raise NotImplementedError("Subclasses should implement this method")
+
+ def process_inference_results(self, output_array, protocol: str, **kwargs):
+ """
+ Process the inference results from the model.
+
+ Parameters
+ ----------
+ output_array : Any
+ The raw output from the model.
+        protocol : str
+            The protocol used ("grpc" or "http").
+        kwargs : dict
+ Additional parameters for processing.
+ """
+ raise NotImplementedError("Subclasses should implement this method")
+
+ def name(self) -> str:
+ """
+ Get the name of the model interface.
+
+ Returns
+ -------
+ str
+ The name of the model interface.
+ """
+ raise NotImplementedError("Subclasses should implement this method")
+
+
+class NimClient:
+ """
+ A client for interfacing with a model inference server using gRPC or HTTP protocols.
+ """
+
+ def __init__(
+ self,
+ model_interface,
+ protocol: str,
+ endpoints: Tuple[str, str],
+ auth_token: Optional[str] = None,
+ timeout: float = 120.0,
+ max_retries: int = 5,
+ ):
+ """
+ Initialize the NimClient with the specified model interface, protocol, and server endpoints.
+
+ Parameters
+ ----------
+ model_interface : ModelInterface
+ The model interface implementation to use.
+ protocol : str
+ The protocol to use ("grpc" or "http").
+ endpoints : tuple
+ A tuple containing the gRPC and HTTP endpoints.
+ auth_token : str, optional
+ Authorization token for HTTP requests (default: None).
+ timeout : float, optional
+            Timeout for HTTP requests in seconds (default: 120.0).
+        max_retries : int, optional
+            Maximum number of attempts for HTTP requests that time out or return retryable errors (default: 5).
+
+ Raises
+ ------
+ ValueError
+ If an invalid protocol is specified or if required endpoints are missing.
+ """
+
+ self.client = None
+ self.model_interface = model_interface
+ self.protocol = protocol.lower()
+ self.auth_token = auth_token
+ self.timeout = timeout # Timeout for HTTP requests
+ self.max_retries = max_retries
+ self._grpc_endpoint, self._http_endpoint = endpoints
+ self._max_batch_sizes = {}
+ self._lock = threading.Lock()
+
+ if self.protocol == "grpc":
+ if not self._grpc_endpoint:
+ raise ValueError("gRPC endpoint must be provided for gRPC protocol")
+ logger.debug(f"Creating gRPC client with {self._grpc_endpoint}")
+ self.client = grpcclient.InferenceServerClient(url=self._grpc_endpoint)
+ elif self.protocol == "http":
+ if not self._http_endpoint:
+ raise ValueError("HTTP endpoint must be provided for HTTP protocol")
+ logger.debug(f"Creating HTTP client with {self._http_endpoint}")
+ self.endpoint_url = generate_url(self._http_endpoint)
+ self.headers = {"accept": "application/json", "content-type": "application/json"}
+ if self.auth_token:
+ self.headers["Authorization"] = f"Bearer {self.auth_token}"
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+ def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int:
+ """Fetch the maximum batch size from the Triton model configuration in a thread-safe manner."""
+ if model_name in self._max_batch_sizes:
+ return self._max_batch_sizes[model_name]
+
+ with self._lock:
+ # Double check, just in case another thread set the value while we were waiting
+ if model_name in self._max_batch_sizes:
+ return self._max_batch_sizes[model_name]
+
+ if not self._grpc_endpoint:
+ self._max_batch_sizes[model_name] = 1
+ return 1
+
+ try:
+ client = self.client if self.client else grpcclient.InferenceServerClient(url=self._grpc_endpoint)
+ model_config = client.get_model_config(model_name=model_name, model_version=model_version)
+ self._max_batch_sizes[model_name] = model_config.config.max_batch_size
+ logger.debug(f"Max batch size for model '{model_name}': {self._max_batch_sizes[model_name]}")
+ except Exception as e:
+ self._max_batch_sizes[model_name] = 1
+ logger.warning(f"Failed to retrieve max batch size: {e}, defaulting to 1")
+
+ return self._max_batch_sizes[model_name]
+
+ def _process_batch(self, batch_input, *, batch_data, model_name, **kwargs):
+ """
+ Process a single batch input for inference using its corresponding batch_data.
+
+ Parameters
+ ----------
+ batch_input : Any
+ The input data for this batch.
+ batch_data : Any
+ The corresponding scratch-pad data for this batch as returned by format_input.
+ model_name : str
+ The model name for inference.
+ kwargs : dict
+ Additional parameters.
+
+ Returns
+ -------
+ tuple
+ A tuple (parsed_output, batch_data) for subsequent post-processing.
+ """
+ if self.protocol == "grpc":
+ logger.debug("Performing gRPC inference for a batch...")
+ response = self._grpc_infer(batch_input, model_name)
+ logger.debug("gRPC inference received response for a batch")
+ elif self.protocol == "http":
+ logger.debug("Performing HTTP inference for a batch...")
+ response = self._http_infer(batch_input)
+ logger.debug("HTTP inference received response for a batch")
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+ parsed_output = self.model_interface.parse_output(response, protocol=self.protocol, data=batch_data, **kwargs)
+ return parsed_output, batch_data
+
+ def try_set_max_batch_size(self, model_name, model_version: str = ""):
+ """Attempt to set the max batch size for the model if it is not already set, ensuring thread safety."""
+ self._fetch_max_batch_size(model_name, model_version)
+
+ @traceable_func(trace_name="{stage_name}::{model_name}")
+ def infer(self, data: dict, model_name: str, **kwargs) -> Any:
+ """
+ Perform inference using the specified model and input data.
+
+ Parameters
+ ----------
+ data : dict
+ The input data for inference.
+ model_name : str
+ The model name.
+ kwargs : dict
+ Additional parameters for inference.
+
+ Returns
+ -------
+ Any
+ The processed inference results, coalesced in the same order as the input images.
+ """
+ try:
+ # 1. Retrieve or default to the model's maximum batch size.
+ batch_size = self._fetch_max_batch_size(model_name)
+ max_requested_batch_size = kwargs.get("max_batch_size", batch_size)
+ force_requested_batch_size = kwargs.get("force_max_batch_size", False)
+ max_batch_size = (
+ min(batch_size, max_requested_batch_size)
+ if not force_requested_batch_size
+ else max_requested_batch_size
+ )
+
+ # 2. Prepare data for inference.
+ data = self.model_interface.prepare_data_for_inference(data)
+
+ # 3. Format the input based on protocol.
+ formatted_batches, formatted_batch_data = self.model_interface.format_input(
+ data, protocol=self.protocol, max_batch_size=max_batch_size, model_name=model_name
+ )
+
+ # Check for a custom maximum pool worker count, and remove it from kwargs.
+ max_pool_workers = kwargs.pop("max_pool_workers", 16)
+
+ # 4. Process each batch concurrently using a thread pool.
+ # We enumerate the batches so that we can later reassemble results in order.
+ results = [None] * len(formatted_batches)
+ with ThreadPoolExecutor(max_workers=max_pool_workers) as executor:
+ futures = []
+ for idx, (batch, batch_data) in enumerate(zip(formatted_batches, formatted_batch_data)):
+ future = executor.submit(
+ self._process_batch, batch, batch_data=batch_data, model_name=model_name, **kwargs
+ )
+ futures.append((idx, future))
+ for idx, future in futures:
+ results[idx] = future.result()
+
+ # 5. Process the parsed outputs for each batch using its corresponding batch_data.
+ # As the batches are in order, we coalesce their outputs accordingly.
+ all_results = []
+ for parsed_output, batch_data in results:
+ batch_results = self.model_interface.process_inference_results(
+ parsed_output,
+ original_image_shapes=batch_data.get("original_image_shapes"),
+ protocol=self.protocol,
+ **kwargs,
+ )
+ if isinstance(batch_results, list):
+ all_results.extend(batch_results)
+ else:
+ all_results.append(batch_results)
+
+ except Exception as err:
+ error_str = f"Error during NimClient inference [{self.model_interface.name()}, {self.protocol}]: {err}"
+ logger.error(error_str)
+ raise RuntimeError(error_str)
+
+ return all_results
+
+ def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray:
+ """
+ Perform inference using the gRPC protocol.
+
+ Parameters
+ ----------
+ formatted_input : np.ndarray
+ The input data formatted as a numpy array.
+ model_name : str
+ The name of the model to use for inference.
+
+ Returns
+ -------
+ np.ndarray
+ The output of the model as a numpy array.
+ """
+
+ input_tensors = [grpcclient.InferInput("input", formatted_input.shape, datatype="FP32")]
+ input_tensors[0].set_data_from_numpy(formatted_input)
+
+ outputs = [grpcclient.InferRequestedOutput("output")]
+ response = self.client.infer(model_name=model_name, inputs=input_tensors, outputs=outputs)
+ logger.debug(f"gRPC inference response: {response}")
+
+ return response.as_numpy("output")
+
+ def _http_infer(self, formatted_input: dict) -> dict:
+ """
+        Perform inference using the HTTP protocol, retrying timeouts, 429/5xx responses,
+        and transient network errors up to ``max_retries`` times with exponential backoff.
+
+ Parameters
+ ----------
+ formatted_input : dict
+ The input data formatted as a dictionary.
+
+ Returns
+ -------
+ dict
+ The output of the model as a dictionary.
+
+ Raises
+ ------
+ TimeoutError
+ If the HTTP request times out repeatedly, up to the max retries.
+ requests.RequestException
+ For other HTTP-related errors that persist after max retries.
+ """
+
+ base_delay = 2.0
+ attempt = 0
+
+ while attempt < self.max_retries:
+ try:
+ response = requests.post(
+ self.endpoint_url, json=formatted_input, headers=self.headers, timeout=self.timeout
+ )
+ status_code = response.status_code
+
+ # Check for server-side or rate-limit type errors
+ # e.g. 5xx => server error, 429 => too many requests
+ if status_code == 429 or status_code == 503 or (500 <= status_code < 600):
+ logger.warning(
+ f"Received HTTP {status_code} ({response.reason}) from "
+ f"{self.model_interface.name()}. Attempt {attempt + 1} of {self.max_retries}."
+ )
+ if attempt == self.max_retries - 1:
+ # No more retries left
+ logger.error(f"Max retries exceeded after receiving HTTP {status_code}.")
+ response.raise_for_status() # raise the appropriate HTTPError
+ else:
+ # Exponential backoff
+ backoff_time = base_delay * (2**attempt)
+ time.sleep(backoff_time)
+ attempt += 1
+ continue
+ else:
+ # Not in our "retry" category => just raise_for_status or return
+ response.raise_for_status()
+ logger.debug(f"HTTP inference response: {response.json()}")
+ return response.json()
+
+ except requests.Timeout:
+ # Treat timeouts similarly to 5xx => attempt a retry
+ logger.warning(
+ f"HTTP request timed out after {self.timeout} seconds during {self.model_interface.name()} "
+ f"inference. Attempt {attempt + 1} of {self.max_retries}."
+ )
+ if attempt == self.max_retries - 1:
+ logger.error("Max retries exceeded after repeated timeouts.")
+ raise TimeoutError(
+ f"Repeated timeouts for {self.model_interface.name()} after {attempt + 1} attempts."
+ )
+ # Exponential backoff
+ backoff_time = base_delay * (2**attempt)
+ time.sleep(backoff_time)
+ attempt += 1
+
+ except requests.HTTPError as http_err:
+ # If we ended up here, it's a non-retryable 4xx or final 5xx after final attempt
+ logger.error(f"HTTP request failed with status code {response.status_code}: {http_err}")
+ raise
+
+ except requests.RequestException as e:
+ # ConnectionError or other non-HTTPError
+ logger.error(f"HTTP request encountered a network issue: {e}")
+ if attempt == self.max_retries - 1:
+ raise
+ # Else retry on next loop iteration
+ backoff_time = base_delay * (2**attempt)
+ time.sleep(backoff_time)
+ attempt += 1
+
+ # If we exit the loop without returning, we've exhausted all attempts
+ logger.error(f"Failed to get a successful response after {self.max_retries} retries.")
+ raise Exception(f"Failed to get a successful response after {self.max_retries} retries.")
+
+ def close(self):
+ if self.protocol == "grpc" and hasattr(self.client, "close"):
+ self.client.close()
+
+
+def create_inference_client(
+ endpoints: Tuple[str, str],
+ model_interface: ModelInterface,
+ auth_token: Optional[str] = None,
+ infer_protocol: Optional[str] = None,
+) -> NimClient:
+ """
+ Create a NimClient for interfacing with a model inference server.
+
+ Parameters
+ ----------
+ endpoints : tuple
+ A tuple containing the gRPC and HTTP endpoints.
+ model_interface : ModelInterface
+ The model interface implementation to use.
+ auth_token : str, optional
+ Authorization token for HTTP requests (default: None).
+ infer_protocol : str, optional
+ The protocol to use ("grpc" or "http"). If not specified, it is inferred from the endpoints.
+
+ Returns
+ -------
+ NimClient
+ The initialized NimClient.
+
+ Raises
+ ------
+ ValueError
+ If an invalid infer_protocol is specified.
+ """
+
+ grpc_endpoint, http_endpoint = endpoints
+
+ if (infer_protocol is None) and (grpc_endpoint and grpc_endpoint.strip()):
+ infer_protocol = "grpc"
+ elif infer_protocol is None and http_endpoint:
+ infer_protocol = "http"
+
+ if infer_protocol not in ["grpc", "http"]:
+ raise ValueError("Invalid infer_protocol specified. Must be 'grpc' or 'http'.")
+
+ return NimClient(model_interface, infer_protocol, endpoints, auth_token)
+
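+# Illustrative usage sketch (not part of this module); the endpoint addresses and the
+# Triton model name "paddle" are assumptions, and `b64_png` is a base64-encoded PNG string:
+#
+#     from nv_ingest.util.nim.paddle import PaddleOCRModelInterface
+#
+#     client = create_inference_client(
+#         endpoints=("paddle:8001", "http://paddle:8000"),
+#         model_interface=PaddleOCRModelInterface(),
+#     )  # gRPC is selected because a gRPC endpoint is present
+#     results = client.infer({"base64_images": [b64_png]}, model_name="paddle")
+#     client.close()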
+
+def preprocess_image_for_paddle(array: np.ndarray, image_max_dimension: int = 960) -> Tuple[np.ndarray, dict]:
+ """
+ Preprocesses an input image to be suitable for use with PaddleOCR by resizing, normalizing, padding,
+ and transposing it into the required format.
+
+ This function is intended for preprocessing images to be passed as input to PaddleOCR using GRPC.
+ It is not necessary when using the HTTP endpoint.
+
+ Steps:
+ -----
+ 1. Resizes the image while maintaining aspect ratio such that its largest dimension is scaled to 960 pixels.
+ 2. Normalizes the image using the `normalize_image` function.
+ 3. Pads the image to ensure both its height and width are multiples of 32, as required by PaddleOCR.
+ 4. Transposes the image from (height, width, channel) to (channel, height, width), the format expected by PaddleOCR.
+
+ Parameters:
+ ----------
+ array : np.ndarray
+ The input image array of shape (height, width, channels). It should have pixel values in the range [0, 255].
+
+ Returns:
+ -------
+    tuple of (np.ndarray, dict)
+        A preprocessed image with shape (channels, height, width) and normalized pixel values,
+        padded so that both spatial dimensions are multiples of 32 (padding color 0), together
+        with a metadata dict describing the original size, scale factor, and padding applied.
+
+ Notes:
+ -----
+ - The image is resized so that its largest dimension becomes 960 pixels, maintaining the aspect ratio.
+ - After normalization, the image is padded to the nearest multiple of 32 in both dimensions, which is
+ a requirement for PaddleOCR.
+ - The normalized pixel values are scaled between 0 and 1 before padding and transposing the image.
+ """
+ height, width = array.shape[:2]
+ scale_factor = image_max_dimension / max(height, width)
+ new_height = int(height * scale_factor)
+ new_width = int(width * scale_factor)
+ resized = cv2.resize(array, (new_width, new_height))
+
+ normalized = normalize_image(resized)
+
+ # PaddleOCR NIM (GRPC) requires input shapes to be multiples of 32.
+ new_height = (normalized.shape[0] + 31) // 32 * 32
+ new_width = (normalized.shape[1] + 31) // 32 * 32
+ padded, (pad_width, pad_height) = pad_image(
+ normalized, target_height=new_height, target_width=new_width, background_color=0, dtype=np.float32
+ )
+
+ # PaddleOCR NIM (GRPC) requires input to be (channel, height, width).
+ transposed = padded.transpose((2, 0, 1))
+
+    # Metadata can be used for inverting transformations on the resulting bounding boxes.
+ metadata = {
+ "original_height": height,
+ "original_width": width,
+ "scale_factor": scale_factor,
+ "new_height": transposed.shape[1],
+ "new_width": transposed.shape[2],
+ "pad_height": pad_height,
+ "pad_width": pad_width,
+ }
+
+ return transposed, metadata
+
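+# Illustrative sketch (not part of this module): a 700x500 RGB image is scaled so its longest
+# side becomes 960, padded up to multiples of 32, and transposed to channels-first:
+#
+#     img = np.random.randint(0, 255, (700, 500, 3), dtype=np.uint8)
+#     arr, meta = preprocess_image_for_paddle(img)
+#     # arr.shape -> (3, 960, 704); meta["scale_factor"] -> 960 / 700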
+
+def remove_url_endpoints(url) -> str:
+ """Some configurations provide the full endpoint in the URL.
+ Ex: http://deplot:8000/v1/chat/completions. For hitting the
+ health endpoint we need to get just the hostname:port combo
+ that we can append the health/ready endpoint to so we attempt
+ to parse that information here.
+
+ Args:
+ url str: Incoming URL
+
+ Returns:
+ str: URL with just the hostname:port portion remaining
+ """
+ if "/v1" in url:
+ url = url.split("/v1")[0]
+
+ return url
+
+
+def generate_url(url) -> str:
+ """Examines the user defined URL for http*://. If that
+ pattern is detected the URL is used as provided by the user.
+ If that pattern does not exist then the assumption is made that
+ the endpoint is simply `http://` and that is prepended
+ to the user supplied endpoint.
+
+ Args:
+ url str: Endpoint where the Rest service is running
+
+ Returns:
+ str: Fully validated URL
+ """
+ if not re.match(r"^https?://", url):
+ # Add the default `http://` if it's not already present in the URL
+ url = f"http://{url}"
+
+ return url
+
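+# Illustrative sketch (not part of this module):
+#
+#     generate_url("deplot:8000")                                      # -> "http://deplot:8000"
+#     remove_url_endpoints("http://deplot:8000/v1/chat/completions")   # -> "http://deplot:8000"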
+
+def is_ready(http_endpoint: str, ready_endpoint: str) -> bool:
+ """
+ Check if the server at the given endpoint is ready.
+
+ Parameters
+ ----------
+ http_endpoint : str
+ The HTTP endpoint of the server.
+ ready_endpoint : str
+ The specific ready-check endpoint.
+
+ Returns
+ -------
+ bool
+ True if the server is ready, False otherwise.
+ """
+
+    # If the URL is empty or None, the service was not configured
+    # and is therefore automatically marked as "ready"
+ if http_endpoint is None or http_endpoint == "":
+ return True
+
+    # If the URL is an NVIDIA-hosted endpoint (ai.api.nvidia.com), it is automatically assumed "ready"
+ if "ai.api.nvidia.com" in http_endpoint:
+ return True
+
+ url = generate_url(http_endpoint)
+ url = remove_url_endpoints(url)
+
+ if not ready_endpoint.startswith("/") and not url.endswith("/"):
+ ready_endpoint = "/" + ready_endpoint
+
+ url = url + ready_endpoint
+
+ # Call the ready endpoint of the NIM
+ try:
+        # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
+ resp = requests.get(url, timeout=5)
+ if resp.status_code == 200:
+ # The NIM is saying it is ready to serve
+ return True
+ elif resp.status_code == 503:
+ # NIM is explicitly saying it is not ready.
+ return False
+ else:
+ # Any other code is confusing. We should log it with a warning
+ # as it could be something that might hold up ready state
+ logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.json()}")
+ return False
+ except requests.HTTPError as http_err:
+ logger.warning(f"'{url}' produced a HTTP error: {http_err}")
+ return False
+ except requests.Timeout:
+ logger.warning(f"'{url}' request timed out")
+ return False
+ except ConnectionError:
+ logger.warning(f"A connection error for '{url}' occurred")
+ return False
+ except requests.RequestException as err:
+ logger.warning(f"An error occurred: {err} for '{url}'")
+ return False
+ except Exception as ex:
+ # Don't let anything squeeze by
+ logger.warning(f"Exception: {ex}")
+ return False
+
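+# Illustrative sketch (not part of this module); the endpoint and health path are assumptions:
+#
+#     is_ready("http://paddle:8000", "/v1/health/ready")  # -> True once the NIM reports HTTP 200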
+
+@multiprocessing_cache(max_calls=100) # Cache results first to avoid redundant retries from backoff
+@backoff.on_predicate(backoff.expo, max_time=30)
+def get_version(http_endpoint: str, metadata_endpoint: str = "/v1/metadata", version_field: str = "version") -> str:
+ """
+ Get the version of the server from its metadata endpoint.
+
+ Parameters
+ ----------
+ http_endpoint : str
+ The HTTP endpoint of the server.
+ metadata_endpoint : str, optional
+ The metadata endpoint to query (default: "/v1/metadata").
+ version_field : str, optional
+ The field containing the version in the response (default: "version").
+
+ Returns
+ -------
+ str
+ The version of the server, or an empty string if unavailable.
+ """
+
+ if (http_endpoint is None) or (http_endpoint == ""):
+ return ""
+
+ # TODO: Need a way to match NIM versions to API versions.
+ if "ai.api.nvidia.com" in http_endpoint or "api.nvcf.nvidia.com" in http_endpoint:
+ return "1.0.0"
+
+ url = generate_url(http_endpoint)
+ url = remove_url_endpoints(url)
+
+ if not metadata_endpoint.startswith("/") and not url.endswith("/"):
+ metadata_endpoint = "/" + metadata_endpoint
+
+ url = url + metadata_endpoint
+
+ # Call the metadata endpoint of the NIM
+ try:
+ # Use a short timeout to prevent long hanging calls. 5 seconds seems reasonable
+ resp = requests.get(url, timeout=5)
+ if resp.status_code == 200:
+ version = resp.json().get(version_field, "")
+ if version:
+ return version
+ else:
+ # If version field is empty, retry
+ logger.warning(f"No version field in response from '{url}'. Retrying.")
+ return ""
+ else:
+ # Any other code is confusing. We should log it with a warning
+ logger.warning(f"'{url}' HTTP Status: {resp.status_code} - Response Payload: {resp.text}")
+ return ""
+ except requests.HTTPError as http_err:
+ logger.warning(f"'{url}' produced a HTTP error: {http_err}")
+ return ""
+ except requests.Timeout:
+ logger.warning(f"'{url}' request timed out")
+ return ""
+ except ConnectionError:
+ logger.warning(f"A connection error for '{url}' occurred")
+ return ""
+ except requests.RequestException as err:
+ logger.warning(f"An error occurred: {err} for '{url}'")
+ return ""
+ except Exception as ex:
+ # Don't let anything squeeze by
+ logger.warning(f"Exception: {ex}")
+ return ""
diff --git a/src/nv_ingest/util/nim/paddle.py b/src/nv_ingest/util/nim/paddle.py
new file mode 100644
index 00000000..98b99efd
--- /dev/null
+++ b/src/nv_ingest/util/nim/paddle.py
@@ -0,0 +1,446 @@
+import json
+import logging
+from typing import Any, List, Tuple
+from typing import Dict
+from typing import Optional
+
+import numpy as np
+
+from nv_ingest.util.image_processing.transforms import base64_to_numpy
+from nv_ingest.util.nim.helpers import ModelInterface
+from nv_ingest.util.nim.helpers import preprocess_image_for_paddle
+
+logger = logging.getLogger(__name__)
+
+
+class PaddleOCRModelInterface(ModelInterface):
+ """
+ An interface for handling inference with a PaddleOCR model, supporting both gRPC and HTTP protocols.
+ """
+
+ def name(self) -> str:
+ """
+ Get the name of the model interface.
+
+ Returns
+ -------
+ str
+ The name of the model interface.
+ """
+ return "PaddleOCR"
+
+ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Decode one or more base64-encoded images into NumPy arrays, storing them
+ alongside their dimensions in `data`.
+
+ Parameters
+ ----------
+ data : dict of str -> Any
+ The input data containing either:
+ - 'base64_image': a single base64-encoded image, or
+ - 'base64_images': a list of base64-encoded images.
+
+ Returns
+ -------
+ dict of str -> Any
+            The updated data dictionary with "image_arrays" added: a list of decoded
+            NumPy arrays of shape (H, W, C).
+
+ Raises
+ ------
+ KeyError
+ If neither 'base64_image' nor 'base64_images' is found in `data`.
+ ValueError
+ If 'base64_images' is present but is not a list.
+ """
+ if "base64_images" in data:
+ base64_list = data["base64_images"]
+ if not isinstance(base64_list, list):
+ raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
+
+ image_arrays: List[np.ndarray] = []
+ for b64 in base64_list:
+ img = base64_to_numpy(b64)
+ image_arrays.append(img)
+
+ data["image_arrays"] = image_arrays
+
+ elif "base64_image" in data:
+ # Single-image fallback
+ img = base64_to_numpy(data["base64_image"])
+ data["image_arrays"] = [img]
+
+ else:
+ raise KeyError("Input data must include 'base64_image' or 'base64_images'.")
+
+ return data
+
+ def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs) -> Any:
+ """
+ Format input data for the specified protocol ("grpc" or "http"), supporting batched data.
+
+ Parameters
+ ----------
+ data : dict of str -> Any
+            The input data dictionary, expected to contain "image_arrays" (a list of np.ndarray)
+            as produced by prepare_data_for_inference.
+ protocol : str
+ The inference protocol, either "grpc" or "http".
+ max_batch_size : int
+ The maximum batch size for batching.
+
+ Returns
+ -------
+ tuple
+ A tuple (formatted_batches, formatted_batch_data) where:
+ - formatted_batches is a list of batches ready for inference.
+ - formatted_batch_data is a list of scratch-pad dictionaries corresponding to each batch,
+ containing the keys "image_arrays" and "image_dims" for later post-processing.
+
+ Raises
+ ------
+ KeyError
+ If either "image_arrays" or "image_dims" is not found in `data`.
+ ValueError
+ If an invalid protocol is specified.
+ """
+
+        if "image_arrays" not in data:
+            raise KeyError("Expected 'image_arrays' in data. Call prepare_data_for_inference first.")
+
+        images = data["image_arrays"]
+
+        # Dimensions are collected per image and stored back on `data` so that
+        # parse_output can scale bounding boxes later.
+        dims: List[Dict[str, Any]] = []
+        data["image_dims"] = dims
+
+        # Helper function to split a list into chunks of size up to chunk_size.
+        def chunk_list(lst, chunk_size):
+            return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+ if protocol == "grpc":
+ logger.debug("Formatting input for gRPC PaddleOCR model (batched).")
+ processed: List[np.ndarray] = []
+ for img in images:
+ arr, _dims = preprocess_image_for_paddle(img)
+ dims.append(_dims)
+ arr = arr.astype(np.float32)
+                arr = np.expand_dims(arr, axis=0)  # => shape (1, C, H, W); paddle preprocessing is channels-first
+ processed.append(arr)
+
+ batches = []
+ batch_data_list = []
+ for proc_chunk, orig_chunk, dims_chunk in zip(
+ chunk_list(processed, max_batch_size),
+ chunk_list(images, max_batch_size),
+ chunk_list(dims, max_batch_size),
+ ):
+ batched_input = np.concatenate(proc_chunk, axis=0)
+ batches.append(batched_input)
+ batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
+ return batches, batch_data_list
+
+ elif protocol == "http":
+ logger.debug("Formatting input for HTTP PaddleOCR model (batched).")
+ if "base64_images" in data:
+ base64_list = data["base64_images"]
+ else:
+                base64_list = [data["base64_image"]]
+
+            # Record each image's dimensions so the chunked zips below stay aligned with
+            # the images; the HTTP parser does not use these for scaling.
+            for img in images:
+                dims.append({"new_height": img.shape[0], "new_width": img.shape[1]})
+
+ input_list: List[Dict[str, Any]] = []
+ for b64 in base64_list:
+ image_url = f"data:image/png;base64,{b64}"
+ image_obj = {"type": "image_url", "url": image_url}
+ input_list.append(image_obj)
+
+ batches = []
+ batch_data_list = []
+ for input_chunk, orig_chunk, dims_chunk in zip(
+ chunk_list(input_list, max_batch_size),
+ chunk_list(images, max_batch_size),
+ chunk_list(dims, max_batch_size),
+ ):
+ payload = {"input": input_chunk}
+ batches.append(payload)
+ batch_data_list.append({"image_arrays": orig_chunk, "image_dims": dims_chunk})
+ return batches, batch_data_list
+
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
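+    # Illustrative sketch (not part of this module); `b64_a`, `b64_b`, `b64_c` are assumed
+    # base64-encoded PNG strings. With three images and max_batch_size=2, the HTTP branch
+    # yields two payloads (2 images, then 1) plus aligned batch data:
+    #
+    #     iface = PaddleOCRModelInterface()
+    #     data = iface.prepare_data_for_inference({"base64_images": [b64_a, b64_b, b64_c]})
+    #     payloads, batch_data = iface.format_input(data, protocol="http", max_batch_size=2)
+    #     # len(payloads) == 2; payloads[0]["input"] has 2 entries, payloads[1]["input"] has 1
+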
+ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs: Any) -> Any:
+ """
+ Parse the model's inference response for the given protocol. The parsing
+ may handle batched outputs for multiple images.
+
+ Parameters
+ ----------
+ response : Any
+ The raw response from the PaddleOCR model.
+ protocol : str
+ The protocol used for inference, "grpc" or "http".
+ data : dict of str -> Any, optional
+ Additional data dictionary that may include "image_dims" for bounding box scaling.
+ **kwargs : Any
+ Additional keyword arguments, such as custom `table_content_format`.
+
+ Returns
+ -------
+ Any
+            The parsed output, typically a list of [bounding_boxes, text_predictions] pairs,
+            one per image.
+
+ Raises
+ ------
+ ValueError
+ If an invalid protocol is specified.
+ """
+ # Retrieve image dimensions if available
+        dims: Optional[List[Dict[str, Any]]] = data.get("image_dims") if data else None
+
+ if protocol == "grpc":
+ logger.debug("Parsing output from gRPC PaddleOCR model (batched).")
+ return self._extract_content_from_paddle_grpc_response(response, dims)
+
+ elif protocol == "http":
+ logger.debug("Parsing output from HTTP PaddleOCR model (batched).")
+ return self._extract_content_from_paddle_http_response(response)
+
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+ def process_inference_results(self, output: Any, **kwargs: Any) -> Any:
+ """
+ Process inference results for the PaddleOCR model.
+
+ Parameters
+ ----------
+ output : Any
+ The raw output parsed from the PaddleOCR model.
+ **kwargs : Any
+ Additional keyword arguments for customization.
+
+ Returns
+ -------
+ Any
+ The post-processed inference results. By default, this simply returns the output
+ as the table content (or content list).
+ """
+ return output
+
+ def _prepare_paddle_payload(self, base64_img: str) -> Dict[str, Any]:
+ """
+ DEPRECATED by batch logic in format_input. Kept here if you need single-image direct calls.
+
+ Parameters
+ ----------
+ base64_img : str
+ A single base64-encoded image string.
+
+ Returns
+ -------
+ dict of str -> Any
+ The payload in either legacy or new format for PaddleOCR's HTTP endpoint.
+ """
+ image_url = f"data:image/png;base64,{base64_img}"
+
+ image = {"type": "image_url", "url": image_url}
+ payload = {"input": [image]}
+
+ return payload
+
+ def _extract_content_from_paddle_http_response(
+ self,
+ json_response: Dict[str, Any],
+        table_content_format: Optional[str] = None,
+    ) -> List[List[Any]]:
+ """
+ Extract content from the JSON response of a PaddleOCR HTTP API request.
+
+ Parameters
+ ----------
+ json_response : dict of str -> Any
+ The JSON response returned by the PaddleOCR endpoint.
+ table_content_format : str or None
+            The specified format for table content (e.g., 'simple' or 'pseudo_markdown').
+            Currently unused by this parser.
+
+ Returns
+ -------
+        list
+            A list of [bounding_boxes, text_predictions] pairs, one for each image result.
+
+ Raises
+ ------
+ RuntimeError
+ If the response format is missing or invalid.
+ """
+ if "data" not in json_response or not json_response["data"]:
+ raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
+
+        results: List[List[Any]] = []
+ for item_idx, item in enumerate(json_response["data"]):
+ text_detections = item.get("text_detections", [])
+ text_predictions = []
+ bounding_boxes = []
+ for td in text_detections:
+ text_predictions.append(td["text_prediction"]["text"])
+ bounding_boxes.append([[pt["x"], pt["y"]] for pt in td["bounding_box"]["points"]])
+
+ results.append([bounding_boxes, text_predictions])
+
+ return results
+
+ def _extract_content_from_paddle_grpc_response(
+ self,
+ response: np.ndarray,
+ dimensions: List[Dict[str, Any]],
+    ) -> List[Any]:
+ """
+ Parse a gRPC response for one or more images. The response can have two possible shapes:
+ - (3,) for batch_size=1
+ - (3, n) for batch_size=n
+
+ In either case:
+ response[0, i]: byte string containing bounding box data
+ response[1, i]: byte string containing text prediction data
+ response[2, i]: (Optional) additional data/metadata (ignored here)
+
+ Parameters
+ ----------
+ response : np.ndarray
+ The raw NumPy array from gRPC. Expected shape: (3,) or (3, n).
+        dimensions : list of dict
+            A list of dimension dictionaries, one per image, used for bounding box scaling.
+
+        Returns
+        -------
+        list
+            A list with one entry per image, where each entry is
+            [bounding_boxes, text_predictions].
+
+ Raises
+ ------
+ ValueError
+            If the response is not a NumPy array or has an unexpected shape.
+ """
+ if not isinstance(response, np.ndarray):
+ raise ValueError("Unexpected response format: response is not a NumPy array.")
+
+ # If we have shape (3,), convert to (3, 1)
+ if response.ndim == 1 and response.shape == (3,):
+ response = response.reshape(3, 1)
+ elif response.ndim != 2 or response.shape[0] != 3:
+ raise ValueError(f"Unexpected response shape: {response.shape}. Expecting (3,) or (3, n).")
+
+ batch_size = response.shape[1]
+        results: List[Any] = []
+
+ for i in range(batch_size):
+ # 1) Parse bounding boxes
+ bboxes_bytestr: bytes = response[0, i]
+ bounding_boxes = json.loads(bboxes_bytestr.decode("utf8"))
+
+ # 2) Parse text predictions
+ texts_bytestr: bytes = response[1, i]
+ text_predictions = json.loads(texts_bytestr.decode("utf8"))
+
+ # 3) Log the third element (extra data/metadata) if needed
+ extra_data_bytestr: bytes = response[2, i]
+ logger.debug(f"Ignoring extra_data for image {i}: {extra_data_bytestr}")
+
+ # Some gRPC responses nest single-item lists; flatten them if needed
+ if isinstance(bounding_boxes, list) and len(bounding_boxes) == 1:
+ bounding_boxes = bounding_boxes[0]
+ if isinstance(text_predictions, list) and len(text_predictions) == 1:
+ text_predictions = text_predictions[0]
+
+ bounding_boxes, text_predictions = self._postprocess_paddle_response(
+ bounding_boxes,
+ text_predictions,
+ dimensions,
+ img_index=i,
+ )
+
+ results.append([bounding_boxes, text_predictions])
+
+ return results
+
+ @staticmethod
+ def _postprocess_paddle_response(
+ bounding_boxes: List[Any],
+ text_predictions: List[str],
+ dims: Optional[List[Dict[str, Any]]] = None,
+ img_index: int = 0,
+ ) -> Tuple[List[Any], List[str]]:
+ """
+        Convert bounding boxes with normalized coordinates to pixel coordinates using the image
+        dimensions, and shift the coordinates back if the inputs were padded. For multiple images,
+        the correct image dimensions (height, width) are retrieved from `dims[img_index]`.
+
+ Parameters
+ ----------
+ bounding_boxes : list of Any
+ A list (per line of text) of bounding boxes, each a list of (x, y) points.
+ text_predictions : list of str
+ A list of text predictions, one for each bounding box.
+        dims : list of dict, optional
+            A list of dictionaries, where each dictionary contains image-specific dimensions
+            and scaling information:
+            - "new_width" (int): The width of the image after processing.
+            - "new_height" (int): The height of the image after processing.
+            - "pad_width" (int, optional): The width of padding added to the image.
+            - "pad_height" (int, optional): The height of padding added to the image.
+            - "scale_factor" (float, optional): The scaling factor applied to the image.
+        img_index : int, optional
+            The index of the image for which bounding boxes are being converted. Default is 0.
+
+ Returns
+ -------
+ Tuple[List[Any], List[str]]
+            Bounding boxes scaled back to the original image dimensions, and the detected text lines.
+
+        Notes
+        -----
+        - A ValueError is raised if `dims` is None or empty.
+        - If `img_index` is out of range, the dimensions of the first image are used as a fallback.
+ """
+        # Image dimensions are required to map normalized coordinates back to pixels
+ if not dims:
+ raise ValueError("No image_dims provided.")
+ else:
+ if img_index >= len(dims):
+ logger.warning("Image index out of range for stored dimensions. Using first image dims by default.")
+ img_index = 0
+
+ max_width = dims[img_index]["new_width"]
+ max_height = dims[img_index]["new_height"]
+ pad_width = dims[img_index].get("pad_width", 0)
+ pad_height = dims[img_index].get("pad_height", 0)
+ scale_factor = dims[img_index].get("scale_factor", 1.0)
+
+ bboxes: List[List[float]] = []
+ texts: List[str] = []
+
+ # Convert normalized coords back to actual pixel coords
+ for box, txt in zip(bounding_boxes, text_predictions):
+ if box == "nan":
+ continue
+ points: List[List[float]] = []
+ for point in box:
+ # Convert normalized coords back to actual pixel coords,
+ # and shift them back to their original positions if padded.
+ x_pixels = float(point[0]) * max_width - pad_width
+ y_pixels = float(point[1]) * max_height - pad_height
+ x_original = x_pixels / scale_factor
+ y_original = y_pixels / scale_factor
+ points.append([x_original, y_original])
+ bboxes.append(points)
+ texts.append(txt)
+
+ return bboxes, texts
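+
+
+# Illustrative example (not executed; the dims values below are assumptions): how
+# _postprocess_paddle_response maps a normalized PaddleOCR point back to original pixel coordinates.
+#
+#   dims = [{"new_width": 1000, "new_height": 800, "pad_width": 40, "pad_height": 0, "scale_factor": 2.0}]
+#   point = (0.5, 0.25)                       # normalized (x, y)
+#   x_pixels = 0.5 * 1000 - 40 = 460          # scale to processed width, remove horizontal padding
+#   y_pixels = 0.25 * 800 - 0 = 200           # scale to processed height, remove vertical padding
+#   x_original = 460 / 2.0 = 230              # undo the resize scale factor
+#   y_original = 200 / 2.0 = 100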
diff --git a/src/nv_ingest/util/nim/text_embedding.py b/src/nv_ingest/util/nim/text_embedding.py
new file mode 100644
index 00000000..11a16789
--- /dev/null
+++ b/src/nv_ingest/util/nim/text_embedding.py
@@ -0,0 +1,128 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+from nv_ingest.util.nim.helpers import ModelInterface
+
+
+# EmbeddingModelInterface builds on the shared ModelInterface imported from nv_ingest.util.nim.helpers.
+class EmbeddingModelInterface(ModelInterface):
+ """
+ An interface for handling inference with an embedding model endpoint.
+ This implementation supports HTTP inference for generating embeddings from text prompts.
+ """
+
+ def name(self) -> str:
+ """
+ Return the name of this model interface.
+ """
+ return "Embedding"
+
+ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Prepare input data for embedding inference. Ensures that a 'prompts' key is provided
+ and that its value is a list.
+
+ Raises
+ ------
+ KeyError
+ If the 'prompts' key is missing.
+ """
+ if "prompts" not in data:
+ raise KeyError("Input data must include 'prompts'.")
+ # Ensure the prompts are in list format.
+ if not isinstance(data["prompts"], list):
+ data["prompts"] = [data["prompts"]]
+ return data
+
+ def format_input(
+ self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
+ ) -> Tuple[List[Any], List[Dict[str, Any]]]:
+ """
+ Format the input payload for the embedding endpoint. This method constructs one payload per batch,
+ where each payload includes a list of text prompts.
+ Additionally, it returns batch data that preserves the original order of prompts.
+
+ Parameters
+ ----------
+ data : dict
+ The input data containing "prompts" (a list of text prompts).
+ protocol : str
+ Only "http" is supported.
+ max_batch_size : int
+ Maximum number of prompts per payload.
+ kwargs : dict
+ Additional parameters including model_name, encoding_format, input_type, and truncate.
+
+ Returns
+ -------
+ tuple
+ A tuple (payloads, batch_data_list) where:
+ - payloads is a list of JSON-serializable payload dictionaries.
+ - batch_data_list is a list of dictionaries containing the key "prompts" corresponding to each batch.
+ """
+ if protocol != "http":
+ raise ValueError("EmbeddingModelInterface only supports HTTP protocol.")
+
+ prompts = data.get("prompts", [])
+
+ def chunk_list(lst, chunk_size):
+ return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+ batches = chunk_list(prompts, max_batch_size)
+ payloads = []
+ batch_data_list = []
+ for batch in batches:
+ payload = {
+ "model": kwargs.get("model_name"),
+ "input": batch,
+ "encoding_format": kwargs.get("encoding_format", "float"),
+ "extra_body": {
+ "input_type": kwargs.get("input_type", "query"),
+ "truncate": kwargs.get("truncate", "NONE"),
+ },
+ }
+ payloads.append(payload)
+ batch_data_list.append({"prompts": batch})
+ return payloads, batch_data_list
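+
+    # Illustrative sketch (assumed values): for prompts=["a", "b", "c"] and max_batch_size=2,
+    # format_input produces two payloads of the form
+    #   {"model": "<model_name>", "input": ["a", "b"], "encoding_format": "float",
+    #    "extra_body": {"input_type": "query", "truncate": "NONE"}}
+    # and the same structure with "input": ["c"], along with
+    #   batch_data_list = [{"prompts": ["a", "b"]}, {"prompts": ["c"]}].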
+
+ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+ """
+ Parse the HTTP response from the embedding endpoint. Expects a response structure with a "data" key.
+
+ Parameters
+ ----------
+ response : Any
+ The raw HTTP response (assumed to be already decoded as JSON).
+ protocol : str
+ Only "http" is supported.
+ data : dict, optional
+ The original input data.
+ kwargs : dict
+ Additional keyword arguments.
+
+ Returns
+ -------
+ list
+ A list of generated embeddings extracted from the response.
+ """
+ if protocol != "http":
+ raise ValueError("EmbeddingModelInterface only supports HTTP protocol.")
+ if isinstance(response, dict):
+ embeddings = response.get("data")
+ if not embeddings:
+ raise RuntimeError("Unexpected response format: 'data' key is missing or empty.")
+ # Each item in embeddings is expected to have an 'embedding' field.
+ return [item.get("embedding", None) for item in embeddings]
+ else:
+ return [str(response)]
+
+ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
+ """
+ Process inference results for the embedding model.
+ For this implementation, the output is expected to be a list of embeddings.
+
+ Returns
+ -------
+ list
+ The processed list of embeddings.
+ """
+ return output
diff --git a/src/nv_ingest/util/nim/vlm.py b/src/nv_ingest/util/nim/vlm.py
new file mode 100644
index 00000000..cffefb0f
--- /dev/null
+++ b/src/nv_ingest/util/nim/vlm.py
@@ -0,0 +1,148 @@
+from typing import Dict, Any, Optional, Tuple, List
+
+import logging
+
+from nv_ingest.util.nim.helpers import ModelInterface
+
+logger = logging.getLogger(__name__)
+
+
+class VLMModelInterface(ModelInterface):
+ """
+ An interface for handling inference with a VLM model endpoint (e.g., NVIDIA LLaMA-based VLM).
+ This implementation supports HTTP inference with one or more base64-encoded images and a caption prompt.
+ """
+
+ def name(self) -> str:
+ """
+ Return the name of this model interface.
+ """
+ return "VLM"
+
+ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Prepare input data for VLM inference. Accepts either a single base64 image or a list of images.
+ Ensures that a 'prompt' is provided.
+
+ Raises
+ ------
+ KeyError
+ If neither "base64_image" nor "base64_images" is provided or if "prompt" is missing.
+ ValueError
+ If "base64_images" exists but is not a list.
+ """
+ # Allow either a single image with "base64_image" or multiple images with "base64_images".
+ if "base64_images" in data:
+ if not isinstance(data["base64_images"], list):
+ raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.")
+ elif "base64_image" in data:
+ # Convert a single image into a list.
+ data["base64_images"] = [data["base64_image"]]
+ else:
+ raise KeyError("Input data must include 'base64_image' or 'base64_images'.")
+
+ if "prompt" not in data:
+ raise KeyError("Input data must include 'prompt'.")
+ return data
+
+ def format_input(
+ self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
+ ) -> Tuple[List[Any], List[Dict[str, Any]]]:
+ """
+ Format the input payload for the VLM endpoint. This method constructs one payload per batch,
+ where each payload includes one message per image in the batch.
+ Additionally, it returns batch data that preserves the original order of images by including
+ the list of base64 images and the prompt for each batch.
+
+ Parameters
+ ----------
+ data : dict
+ The input data containing "base64_images" (a list of base64-encoded images) and "prompt".
+ protocol : str
+ Only "http" is supported.
+ max_batch_size : int
+ Maximum number of images per payload.
+ kwargs : dict
+ Additional parameters including model_name, max_tokens, temperature, top_p, and stream.
+
+ Returns
+ -------
+ tuple
+ A tuple (payloads, batch_data_list) where:
+ - payloads is a list of JSON-serializable payload dictionaries.
+ - batch_data_list is a list of dictionaries containing the keys "base64_images" and "prompt"
+ corresponding to each batch.
+ """
+ if protocol != "http":
+ raise ValueError("VLMModelInterface only supports HTTP protocol.")
+
+ images = data.get("base64_images", [])
+ prompt = data["prompt"]
+
+ # Helper function to chunk the list into batches.
+ def chunk_list(lst, chunk_size):
+ return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+ batches = chunk_list(images, max_batch_size)
+ payloads = []
+ batch_data_list = []
+ for batch in batches:
+ # Create one message per image in the batch.
+            # Each message pairs the prompt with one image embedded inline as a base64 data URL.
+            messages = [
+                {"role": "user", "content": f'{prompt} <img src="data:image/png;base64,{img}" />'}
+                for img in batch
+            ]
+ payload = {
+ "model": kwargs.get("model_name"),
+ "messages": messages,
+ "max_tokens": kwargs.get("max_tokens", 512),
+ "temperature": kwargs.get("temperature", 1.0),
+ "top_p": kwargs.get("top_p", 1.0),
+ "stream": kwargs.get("stream", False),
+ }
+ payloads.append(payload)
+ batch_data_list.append({"base64_images": batch, "prompt": prompt})
+ return payloads, batch_data_list
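+
+    # Illustrative sketch (assumed values, and assuming the inline <img> embedding above): for two
+    # images and prompt="Describe the image.", each payload carries one user message per image:
+    #   {"model": "<model_name>",
+    #    "messages": [{"role": "user", "content": 'Describe the image. <img src="data:image/png;base64,..." />'},
+    #                 {"role": "user", "content": 'Describe the image. <img src="data:image/png;base64,..." />'}],
+    #    "max_tokens": 512, "temperature": 1.0, "top_p": 1.0, "stream": False}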
+
+ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+ """
+ Parse the HTTP response from the VLM endpoint. Expects a response structure with a "choices" key.
+
+ Parameters
+ ----------
+ response : Any
+ The raw HTTP response (assumed to be already decoded as JSON).
+ protocol : str
+ Only "http" is supported.
+ data : dict, optional
+ The original input data.
+ kwargs : dict
+ Additional keyword arguments.
+
+ Returns
+ -------
+ list
+ A list of generated captions extracted from the response.
+ """
+ if protocol != "http":
+ raise ValueError("VLMModelInterface only supports HTTP protocol.")
+ if isinstance(response, dict):
+ choices = response.get("choices", [])
+ if not choices:
+ raise RuntimeError("Unexpected response format: 'choices' key is missing or empty.")
+ # Return a list of captions, one per choice.
+ return [choice.get("message", {}).get("content", "No caption returned") for choice in choices]
+ else:
+ # If response is not a dict, return its string representation in a list.
+ return [str(response)]
+
+ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Any:
+ """
+ Process inference results for the VLM model.
+ For this implementation, the output is expected to be a list of captions.
+
+ Returns
+ -------
+ list
+ The processed list of captions.
+ """
+ return output
diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py
new file mode 100644
index 00000000..47df3569
--- /dev/null
+++ b/src/nv_ingest/util/nim/yolox.py
@@ -0,0 +1,1186 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+import base64
+import io
+import logging
+import warnings
+from math import log
+from typing import Any, Tuple
+from typing import Dict
+from typing import List
+from typing import Optional
+
+import cv2
+import numpy as np
+import pandas as pd
+import torch
+import torchvision
+from PIL import Image
+
+from nv_ingest.util.image_processing.transforms import scale_image_to_encoding_size
+from nv_ingest.util.nim.helpers import ModelInterface
+
+logger = logging.getLogger(__name__)
+
+# yolox-page-elements-v1 constants
+YOLOX_PAGE_NUM_CLASSES = 3
+YOLOX_PAGE_CONF_THRESHOLD = 0.01
+YOLOX_PAGE_IOU_THRESHOLD = 0.5
+YOLOX_PAGE_MIN_SCORE = 0.1
+YOLOX_PAGE_FINAL_SCORE = 0.48
+YOLOX_PAGE_NIM_MAX_IMAGE_SIZE = 512_000
+
+YOLOX_PAGE_IMAGE_PREPROC_HEIGHT = 1024
+YOLOX_PAGE_IMAGE_PREPROC_WIDTH = 1024
+
+YOLOX_PAGE_CLASS_LABELS = [
+ "table",
+ "chart",
+ "title",
+]
+
+# yolox-graphic-elements-v1 constants
+YOLOX_GRAPHIC_NUM_CLASSES = 10
+YOLOX_GRAPHIC_CONF_THRESHOLD = 0.01
+YOLOX_GRAPHIC_IOU_THRESHOLD = 0.25
+YOLOX_GRAPHIC_MIN_SCORE = 0.1
+YOLOX_GRAPHIC_FINAL_SCORE = 0.0
+YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE = 512_000
+
+YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 768
+YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 768
+
+YOLOX_GRAPHIC_CLASS_LABELS = [
+ "chart_title",
+ "x_title",
+ "y_title",
+ "xlabel",
+ "ylabel",
+ "other",
+ "legend_label",
+ "legend_title",
+ "mark_label",
+ "value_label",
+]
+
+
+# YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements
+class YoloxModelInterfaceBase(ModelInterface):
+ """
+ An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols.
+ """
+
+ def __init__(
+ self,
+ image_preproc_width: Optional[int] = None,
+ image_preproc_height: Optional[int] = None,
+ nim_max_image_size: Optional[int] = None,
+ num_classes: Optional[int] = None,
+ conf_threshold: Optional[float] = None,
+ iou_threshold: Optional[float] = None,
+ min_score: Optional[float] = None,
+ final_score: Optional[float] = None,
+ class_labels: Optional[List[str]] = None,
+ ):
+ """
+        Initialize the YOLOX model interface.
+
+        Parameters
+        ----------
+        image_preproc_width, image_preproc_height : int, optional
+            Size to which input images are resized before inference.
+        nim_max_image_size : int, optional
+            Maximum allowed size of the base64-encoded image payload for the HTTP endpoint.
+        num_classes, conf_threshold, iou_threshold, min_score, final_score : optional
+            Postprocessing parameters (class count and score/IoU thresholds).
+        class_labels : list of str, optional
+            Names of the classes predicted by the model.
+        """
+ self.image_preproc_width = image_preproc_width
+ self.image_preproc_height = image_preproc_height
+ self.nim_max_image_size = nim_max_image_size
+ self.num_classes = num_classes
+ self.conf_threshold = conf_threshold
+ self.iou_threshold = iou_threshold
+ self.min_score = min_score
+ self.final_score = final_score
+ self.class_labels = class_labels
+
+ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+        Prepare input data for inference by validating the image list and storing the original image shapes.
+
+ Parameters
+ ----------
+ data : dict
+ The input data containing a list of images.
+
+ Returns
+ -------
+        dict
+            The updated data dictionary with the original image shapes stored under "original_image_shapes".
+ """
+ if (not isinstance(data, dict)) or ("images" not in data):
+ raise KeyError("Input data must be a dictionary containing an 'images' key with a list of images.")
+
+ if not all(isinstance(x, np.ndarray) for x in data["images"]):
+ raise ValueError("All elements in the 'images' list must be numpy.ndarray objects.")
+
+ original_images = data["images"]
+ data["original_image_shapes"] = [image.shape for image in original_images]
+
+ return data
+
+ def format_input(
+ self, data: Dict[str, Any], protocol: str, max_batch_size: int, **kwargs
+ ) -> Tuple[List[Any], List[Dict[str, Any]]]:
+ """
+ Format input data for the specified protocol, returning a tuple of:
+ (formatted_batches, formatted_batch_data)
+ where:
+ - For gRPC: formatted_batches is a list of NumPy arrays, each of shape (B, H, W, C)
+ with B <= max_batch_size.
+ - For HTTP: formatted_batches is a list of JSON-serializable dict payloads.
+ - In both cases, formatted_batch_data is a list of dicts that coalesce the original
+ images and their original shapes in the same order as provided.
+
+ Parameters
+ ----------
+ data : dict
+ The input data to format. Must include:
+ - "images": a list of numpy.ndarray images.
+ - "original_image_shapes": a list of tuples with each image's (height, width),
+ as set by prepare_data_for_inference.
+ protocol : str
+ The protocol to use ("grpc" or "http").
+ max_batch_size : int
+ The maximum number of images per batch.
+
+ Returns
+ -------
+ tuple
+ A tuple (formatted_batches, formatted_batch_data).
+
+ Raises
+ ------
+ ValueError
+ If the protocol is invalid.
+ """
+
+ # Helper functions to chunk a list into sublists of length up to chunk_size.
+ def chunk_list(lst: list, chunk_size: int) -> List[list]:
+ return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+ def chunk_list_geometrically(lst: list, max_size: int) -> List[list]:
+            # The TRT engine in the Yolox NIM (gRPC) only accepts batch sizes that are powers of 2.
+ chunks = []
+ i = 0
+ while i < len(lst):
+ chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size)
+ chunks.append(lst[i : i + chunk_size])
+ i += chunk_size
+ return chunks
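+
+        # Illustrative example: with 7 images and max_size=8, chunk_list_geometrically yields chunks
+        # of sizes [4, 2, 1]; each chunk is the largest power of 2 that fits the remaining images
+        # without exceeding max_size.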
+
+ if protocol == "grpc":
+ logger.debug("Formatting input for gRPC Yolox model")
+            # Resize images to the model's configured preprocessing size (e.g., 1024x1024 for page elements).
+ resized_images = [
+ resize_image(image, (self.image_preproc_width, self.image_preproc_height)) for image in data["images"]
+ ]
+ # Chunk the resized images, the original images, and their shapes.
+ resized_chunks = chunk_list_geometrically(resized_images, max_batch_size)
+ original_chunks = chunk_list_geometrically(data["images"], max_batch_size)
+ shape_chunks = chunk_list_geometrically(data["original_image_shapes"], max_batch_size)
+
+ batched_inputs = []
+ formatted_batch_data = []
+ for r_chunk, orig_chunk, shapes in zip(resized_chunks, original_chunks, shape_chunks):
+ # Reorder axes from (B, H, W, C) to (B, C, H, W) as expected by the model.
+ input_array = np.einsum("bijk->bkij", r_chunk).astype(np.float32)
+ batched_inputs.append(input_array)
+ formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes})
+ return batched_inputs, formatted_batch_data
+
+ elif protocol == "http":
+ logger.debug("Formatting input for HTTP Yolox model")
+ content_list: List[Dict[str, Any]] = []
+ for image in data["images"]:
+ # Convert the numpy array to a PIL Image. Assume images are in [0,1].
+ image_pil = Image.fromarray((image * 255).astype(np.uint8))
+ original_size = image_pil.size
+
+ # Save the image to a buffer and encode to base64.
+ buffered = io.BytesIO()
+ image_pil.save(buffered, format="PNG")
+ image_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+ # Scale the image if necessary.
+ scaled_image_b64, new_size = scale_image_to_encoding_size(
+ image_b64, max_base64_size=self.nim_max_image_size
+ )
+ if new_size != original_size:
+ logger.debug(f"Image was scaled from {original_size} to {new_size}.")
+
+ content_list.append({"type": "image_url", "url": f"data:image/png;base64,{scaled_image_b64}"})
+
+ # Chunk the payload content, the original images, and their shapes.
+ content_chunks = chunk_list(content_list, max_batch_size)
+ original_chunks = chunk_list(data["images"], max_batch_size)
+ shape_chunks = chunk_list(data["original_image_shapes"], max_batch_size)
+
+ payload_batches = []
+ formatted_batch_data = []
+ for chunk, orig_chunk, shapes in zip(content_chunks, original_chunks, shape_chunks):
+ payload = {"input": chunk}
+ payload_batches.append(payload)
+ formatted_batch_data.append({"images": orig_chunk, "original_image_shapes": shapes})
+ return payload_batches, formatted_batch_data
+
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, Any]] = None, **kwargs) -> Any:
+ """
+ Parse the output from the model's inference response.
+
+ Parameters
+ ----------
+ response : Any
+ The response from the model inference.
+ protocol : str
+ The protocol used ("grpc" or "http").
+ data : dict, optional
+ Additional input data passed to the function.
+
+ Returns
+ -------
+ Any
+ The parsed output data.
+
+ Raises
+ ------
+ ValueError
+ If an invalid protocol is specified or the response format is unexpected.
+ """
+
+ if protocol == "grpc":
+ logger.debug("Parsing output from gRPC Yolox model")
+ return response # For gRPC, response is already a numpy array
+ elif protocol == "http":
+ logger.debug("Parsing output from HTTP Yolox model")
+
+ processed_outputs = []
+
+ batch_results = response.get("data", [])
+ for detections in batch_results:
+ new_bounding_boxes = {label: [] for label in self.class_labels}
+
+ bounding_boxes = detections.get("bounding_boxes", [])
+ for obj_type, bboxes in bounding_boxes.items():
+ for bbox in bboxes:
+ xmin = bbox["x_min"]
+ ymin = bbox["y_min"]
+ xmax = bbox["x_max"]
+ ymax = bbox["y_max"]
+ confidence = bbox["confidence"]
+
+ new_bounding_boxes[obj_type].append([xmin, ymin, xmax, ymax, confidence])
+
+ processed_outputs.append(new_bounding_boxes)
+
+ return processed_outputs
+ else:
+ raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.")
+
+ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> List[Dict[str, Any]]:
+ """
+ Process the results of the Yolox model inference and return the final annotations.
+
+ Parameters
+ ----------
+        output : Any
+            The parsed output from the Yolox model (raw gRPC array or HTTP-parsed detections).
+        protocol : str
+            The protocol used for inference ("grpc" or "http").
+        kwargs : dict
+            Additional parameters for processing, including the original image shapes.
+
+ Returns
+ -------
+ list[dict]
+ A list of annotation dictionaries for each image in the batch.
+ """
+ original_image_shapes = kwargs.get("original_image_shapes", [])
+
+ if protocol == "http":
+ # For http, the output already has postprocessing applied. Skip to table/chart expansion.
+ results = output
+
+ elif protocol == "grpc":
+ # For grpc, apply the same NIM postprocessing.
+ pred = postprocess_model_prediction(
+ output, self.num_classes, self.conf_threshold, self.iou_threshold, class_agnostic=True
+ )
+ results = postprocess_results(
+ pred,
+ original_image_shapes,
+ self.image_preproc_width,
+ self.image_preproc_height,
+ self.class_labels,
+ min_score=self.min_score,
+ )
+
+ inference_results = self.postprocess_annotations(results, **kwargs)
+
+ return inference_results
+
+ def postprocess_annotations(self, annotation_dicts, **kwargs):
+ raise NotImplementedError()
+
+ def transform_normalized_coordinates_to_original(self, results, original_image_shapes):
+        """
+        Convert normalized bounding box coordinates back to pixel coordinates in the original images.
+        """
+ transformed_results = []
+
+ for annotation_dict, shape in zip(results, original_image_shapes):
+ new_dict = {}
+ for label, bboxes_and_scores in annotation_dict.items():
+ new_dict[label] = []
+ for bbox_and_score in bboxes_and_scores:
+ bbox = bbox_and_score[:4]
+ transformed_bbox = [
+ bbox[0] * shape[1],
+ bbox[1] * shape[0],
+ bbox[2] * shape[1],
+ bbox[3] * shape[0],
+ ]
+ transformed_bbox += bbox_and_score[4:]
+ new_dict[label].append(transformed_bbox)
+ transformed_results.append(new_dict)
+
+ return transformed_results
+
+
+class YoloxPageElementsModelInterface(YoloxModelInterfaceBase):
+ """
+ An interface for handling inference with yolox-page-elements model, supporting both gRPC and HTTP protocols.
+ """
+
+ def __init__(self):
+ """
+ Initialize the yolox-page-elements model interface.
+ """
+ super().__init__(
+            image_preproc_width=YOLOX_PAGE_IMAGE_PREPROC_WIDTH,
+ image_preproc_height=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT,
+ nim_max_image_size=YOLOX_PAGE_NIM_MAX_IMAGE_SIZE,
+ num_classes=YOLOX_PAGE_NUM_CLASSES,
+ conf_threshold=YOLOX_PAGE_CONF_THRESHOLD,
+ iou_threshold=YOLOX_PAGE_IOU_THRESHOLD,
+ min_score=YOLOX_PAGE_MIN_SCORE,
+ final_score=YOLOX_PAGE_FINAL_SCORE,
+ class_labels=YOLOX_PAGE_CLASS_LABELS,
+ )
+
+ def name(
+ self,
+ ) -> str:
+ """
+ Returns the name of the Yolox model interface.
+
+ Returns
+ -------
+ str
+ The name of the model interface.
+ """
+
+ return "yolox-page-elements"
+
+ def postprocess_annotations(self, annotation_dicts, **kwargs):
+ original_image_shapes = kwargs.get("original_image_shapes", [])
+
+ # Table/chart expansion is "business logic" specific to nv-ingest
+ annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in annotation_dicts]
+ annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in annotation_dicts]
+ inference_results = []
+
+ # Filter out bounding boxes below the final threshold
+ # This final thresholding is "business logic" specific to nv-ingest
+ for annotation_dict in annotation_dicts:
+ new_dict = {}
+ if "table" in annotation_dict:
+ new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= self.final_score]
+ if "chart" in annotation_dict:
+ new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= self.final_score]
+ if "title" in annotation_dict:
+ new_dict["title"] = annotation_dict["title"]
+ inference_results.append(new_dict)
+
+ inference_results = self.transform_normalized_coordinates_to_original(inference_results, original_image_shapes)
+
+ return inference_results
+
+
+class YoloxGraphicElementsModelInterface(YoloxModelInterfaceBase):
+ """
+    An interface for handling inference with the yolox-graphic-elements model, supporting both gRPC and HTTP protocols.
+ """
+
+ def __init__(self):
+ """
+ Initialize the yolox-graphic-elements model interface.
+ """
+ super().__init__(
+            image_preproc_width=YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH,
+ image_preproc_height=YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT,
+ nim_max_image_size=YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE,
+ num_classes=YOLOX_GRAPHIC_NUM_CLASSES,
+ conf_threshold=YOLOX_GRAPHIC_CONF_THRESHOLD,
+ iou_threshold=YOLOX_GRAPHIC_IOU_THRESHOLD,
+ min_score=YOLOX_GRAPHIC_MIN_SCORE,
+ final_score=YOLOX_GRAPHIC_FINAL_SCORE,
+ class_labels=YOLOX_GRAPHIC_CLASS_LABELS,
+ )
+
+ def name(
+ self,
+ ) -> str:
+ """
+ Returns the name of the Yolox model interface.
+
+ Returns
+ -------
+ str
+ The name of the model interface.
+ """
+
+ return "yolox-graphic-elements"
+
+ def postprocess_annotations(self, annotation_dicts, **kwargs):
+ original_image_shapes = kwargs.get("original_image_shapes", [])
+
+ annotation_dicts = self.transform_normalized_coordinates_to_original(annotation_dicts, original_image_shapes)
+
+ inference_results = []
+
+        # bbox extraction: additional postprocessing specific to nv-ingest
+ for pred, shape in zip(annotation_dicts, original_image_shapes):
+ bbox_dict = get_bbox_dict_yolox_graphic(
+ pred,
+ shape,
+ self.class_labels,
+ self.min_score,
+ )
+ # convert numpy arrays to list
+ bbox_dict = {
+ label: array.tolist() if isinstance(array, np.ndarray) else array for label, array in bbox_dict.items()
+ }
+ inference_results.append(bbox_dict)
+
+ return inference_results
+
+
+def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False):
+ # Convert numpy array to torch tensor
+ prediction = torch.from_numpy(prediction.copy())
+
+ # Compute box corners
+ box_corner = prediction.new(prediction.shape)
+ box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
+ box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
+ box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
+ box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
+ prediction[:, :, :4] = box_corner[:, :, :4]
+
+ output = [None for _ in range(len(prediction))]
+
+ for i, image_pred in enumerate(prediction):
+ # If no detections, continue to the next image
+ if not image_pred.size(0):
+ continue
+
+ # Ensure image_pred is 2D
+ if image_pred.ndim == 1:
+ image_pred = image_pred.unsqueeze(0)
+
+ # Get score and class with highest confidence
+ class_conf, class_pred = torch.max(image_pred[:, 5 : 5 + num_classes], 1, keepdim=True)
+
+ # Confidence mask
+ squeezed_conf = class_conf.squeeze(dim=1)
+ conf_mask = image_pred[:, 4] * squeezed_conf >= conf_thre
+
+ # Apply confidence mask
+ detections = torch.cat((image_pred[:, :5], class_conf, class_pred.float()), 1)
+ detections = detections[conf_mask]
+
+ if not detections.size(0):
+ continue
+
+ # Apply Non-Maximum Suppression (NMS)
+ if class_agnostic:
+ nms_out_index = torchvision.ops.nms(
+ detections[:, :4],
+ detections[:, 4] * detections[:, 5],
+ nms_thre,
+ )
+ else:
+ nms_out_index = torchvision.ops.batched_nms(
+ detections[:, :4],
+ detections[:, 4] * detections[:, 5],
+ detections[:, 6],
+ nms_thre,
+ )
+ detections = detections[nms_out_index]
+
+ # Append detections to output
+ output[i] = detections
+
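+    # Note: each entry of `output` is either None (no detections kept) or a tensor of shape
+    # (num_detections, 7) with columns [x1, y1, x2, y2, objectness, class_confidence, class_index].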
+ return output
+
+
+def postprocess_results(
+ results, original_image_shapes, image_preproc_width, image_preproc_height, class_labels, min_score=0.0
+):
+ """
+    For each item (i.e., each image) in results, computes annotations in the form
+
+ {"table": [[0.0107, 0.0859, 0.7537, 0.1219, 0.9861], ...],
+     "chart": [...],
+ "title": [...]
+ }
+ where each list of 5 floats represents a bounding box in the format [x1, y1, x2, y2, confidence]
+
+ Keep only bboxes with high enough confidence.
+ """
+ out = []
+
+ for original_image_shape, result in zip(original_image_shapes, results):
+ annotation_dict = {label: [] for label in class_labels}
+
+ if result is None:
+ out.append(annotation_dict)
+ continue
+
+ try:
+ result = result.cpu().numpy()
+ scores = result[:, 4] * result[:, 5]
+ result = result[scores > min_score]
+
+ # ratio is used when image was padded
+ ratio = min(
+ image_preproc_width / original_image_shape[0],
+ image_preproc_height / original_image_shape[1],
+ )
+ bboxes = result[:, :4] / ratio
+
+ bboxes[:, [0, 2]] /= original_image_shape[1]
+ bboxes[:, [1, 3]] /= original_image_shape[0]
+ bboxes = np.clip(bboxes, 0.0, 1.0)
+
+ labels = result[:, 6]
+ scores = scores[scores > min_score]
+ except Exception as e:
+ raise ValueError(f"Error in postprocessing {result.shape} and {original_image_shape}: {e}")
+
+ for box, score, label in zip(bboxes, scores, labels):
+ class_name = class_labels[int(label)]
+ annotation_dict[class_name].append([round(float(x), 4) for x in np.concatenate((box, [score]))])
+
+ out.append(annotation_dict)
+
+ return out
+
+
+def resize_image(image, target_img_size):
+ w, h, _ = np.array(image).shape
+
+ if target_img_size is not None: # Resize + Pad
+ r = min(target_img_size[0] / w, target_img_size[1] / h)
+ image = cv2.resize(
+ image,
+ (int(h * r), int(w * r)),
+ interpolation=cv2.INTER_LINEAR,
+ ).astype(np.uint8)
+ image = np.pad(
+ image,
+ ((0, target_img_size[0] - image.shape[0]), (0, target_img_size[1] - image.shape[1]), (0, 0)),
+ mode="constant",
+ constant_values=114,
+ )
+
+ return image
+
+
+def expand_table_bboxes(annotation_dict, labels=None):
+ """
+ Additional preprocessing for tables: extend the upper bounds to capture titles if any.
+ Args:
+        annotation_dict: output of postprocess_results, a dictionary with keys "table", "chart", "title"
+
+ Returns:
+        annotation_dict: same as input, with expanded bboxes for tables
+
+ """
+ if not labels:
+ labels = ["table", "chart", "title"]
+
+ if not annotation_dict or len(annotation_dict["table"]) == 0:
+ return annotation_dict
+
+ new_annotation_dict = {label: [] for label in labels}
+
+ for label, bboxes in annotation_dict.items():
+ for bbox_and_score in bboxes:
+ bbox, score = bbox_and_score[:4], bbox_and_score[4]
+
+ if label == "table":
+ height = bbox[3] - bbox[1]
+ bbox[1] = max(0.0, min(1.0, bbox[1] - height * 0.2))
+
+ new_annotation_dict[label].append([round(float(x), 4) for x in bbox + [score]])
+
+ return new_annotation_dict
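+
+# Illustrative example: a table bbox [x1=0.10, y1=0.30, x2=0.90, y2=0.70] has height 0.40, so its
+# upper bound is raised by 0.40 * 0.2 = 0.08, giving y1 = 0.22 (clamped to the [0, 1] range).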
+
+
+def expand_chart_bboxes(annotation_dict, labels=None):
+ """
+ Expand bounding boxes of charts and titles based on the bounding boxes of the other class.
+ Args:
+        annotation_dict: output of postprocess_results, a dictionary with keys "table", "chart", "title"
+
+ Returns:
+ annotation_dict: same as input, with expanded bboxes for charts
+
+ """
+ if not labels:
+ labels = ["table", "chart", "title"]
+
+ if not annotation_dict or len(annotation_dict["chart"]) == 0:
+ return annotation_dict
+
+ bboxes = []
+ confidences = []
+ label_idxs = []
+ for i, label in enumerate(labels):
+ label_annotations = np.array(annotation_dict[label])
+
+ if len(label_annotations) > 0:
+ bboxes.append(label_annotations[:, :4])
+ confidences.append(label_annotations[:, 4])
+ label_idxs.append(np.full(len(label_annotations), i))
+ bboxes = np.concatenate(bboxes)
+ confidences = np.concatenate(confidences)
+ label_idxs = np.concatenate(label_idxs)
+
+ pred_wbf, confidences_wbf, labels_wbf = weighted_boxes_fusion(
+ bboxes[:, None],
+ confidences[:, None],
+ label_idxs[:, None],
+ merge_type="biggest",
+ conf_type="max",
+ iou_thr=0.01,
+ class_agnostic=False,
+ )
+ chart_bboxes = pred_wbf[labels_wbf == 1]
+ chart_confidences = confidences_wbf[labels_wbf == 1]
+ title_bboxes = pred_wbf[labels_wbf == 2]
+
+ found_title_idxs, no_found_title_idxs = [], []
+ for i in range(len(chart_bboxes)):
+ match = match_with_title(chart_bboxes[i], title_bboxes, iou_th=0.01)
+ if match is not None:
+ chart_bboxes[i] = match[0]
+ title_bboxes = match[1]
+ found_title_idxs.append(i)
+ else:
+ no_found_title_idxs.append(i)
+
+ chart_bboxes[found_title_idxs] = expand_boxes(chart_bboxes[found_title_idxs], r_x=1.05, r_y=1.1)
+ chart_bboxes[no_found_title_idxs] = expand_boxes(chart_bboxes[no_found_title_idxs], r_x=1.1, r_y=1.25)
+
+ annotation_dict = {
+ "table": annotation_dict["table"],
+ "chart": np.concatenate([chart_bboxes, chart_confidences[:, None]], axis=1).tolist(),
+ "title": annotation_dict["title"],
+ }
+ return annotation_dict
+
+
+def weighted_boxes_fusion(
+ boxes_list,
+ scores_list,
+ labels_list,
+ iou_thr=0.5,
+ skip_box_thr=0.0,
+ conf_type="avg",
+ merge_type="weighted",
+ class_agnostic=False,
+):
+ """
+ Custom wbf implementation that supports a class_agnostic mode and a biggest box fusion.
+ Boxes are expected to be in normalized (x0, y0, x1, y1) format.
+
+ Args:
+ boxes_list (list[np array[n x 4]]): List of boxes. One list per model.
+ scores_list (list[np array[n]]): List of confidences.
+ labels_list (list[np array[n]]): List of labels
+        iou_thr (float, optional): IoU threshold for matching. Defaults to 0.5.
+ skip_box_thr (float, optional): Exclude boxes with score < skip_box_thr. Defaults to 0.0.
+ conf_type (str, optional): Confidence merging type. Defaults to "avg".
+ merge_type (str, optional): Merge type "weighted" or "biggest". Defaults to "weighted".
+ class_agnostic (bool, optional): If True, merge boxes from different classes. Defaults to False.
+
+ Returns:
+ np array[N x 4]: Merged boxes,
+ np array[N]: Merged confidences,
+ np array[N]: Merged labels.
+ """
+ weights = np.ones(len(boxes_list))
+
+ assert conf_type in ["avg", "max"], 'Conf type must be "avg" or "max"'
+ assert merge_type in [
+ "weighted",
+ "biggest",
+    ], 'Merge type must be "weighted" or "biggest"'
+
+ filtered_boxes = prefilter_boxes(
+ boxes_list,
+ scores_list,
+ labels_list,
+ weights,
+ skip_box_thr,
+ class_agnostic=class_agnostic,
+ )
+ if len(filtered_boxes) == 0:
+ return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))
+
+ overall_boxes = []
+ for label in filtered_boxes:
+ boxes = filtered_boxes[label]
+
+ clusters = []
+
+ # Clusterize boxes
+ for j in range(len(boxes)):
+ ids = [i for i in range(len(boxes)) if i != j]
+ index, best_iou = find_matching_box_fast(boxes[ids], boxes[j], iou_thr)
+
+ if index != -1:
+ index = ids[index]
+ cluster_idx = [clust_idx for clust_idx, clust in enumerate(clusters) if (j in clust or index in clust)]
+ if len(cluster_idx):
+ cluster_idx = cluster_idx[0]
+ clusters[cluster_idx] = list(set(clusters[cluster_idx] + [index, j]))
+ else:
+ clusters.append([index, j])
+ else:
+ clusters.append([j])
+
+ for j, c in enumerate(clusters):
+ if merge_type == "weighted":
+ weighted_box = get_weighted_box(boxes[c], conf_type)
+ elif merge_type == "biggest":
+ weighted_box = get_biggest_box(boxes[c], conf_type)
+
+ if conf_type == "max":
+ weighted_box[1] = weighted_box[1] / weights.max()
+ else: # avg
+ weighted_box[1] = weighted_box[1] * len(c) / weights.sum()
+ overall_boxes.append(weighted_box)
+
+ overall_boxes = np.array(overall_boxes)
+ overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
+ boxes = overall_boxes[:, 4:]
+ scores = overall_boxes[:, 1]
+ labels = overall_boxes[:, 0]
+ return boxes, scores, labels
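+
+# Minimal usage sketch (assumed values): merging two overlapping same-class detections.
+#   boxes = np.array([[0.1, 0.1, 0.5, 0.5], [0.12, 0.1, 0.52, 0.5]])
+#   scores = np.array([0.9, 0.6])
+#   labels = np.array([1.0, 1.0])
+#   merged, confs, lbls = weighted_boxes_fusion(
+#       boxes[:, None], scores[:, None], labels[:, None], iou_thr=0.5, conf_type="max", merge_type="biggest"
+#   )
+#   # With "biggest"/"max", the result is the union box [0.1, 0.1, 0.52, 0.5] with confidence 0.9.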
+
+
+def prefilter_boxes(boxes, scores, labels, weights, thr, class_agnostic=False):
+ """
+ Reformats and filters boxes.
+ Output is a dict of boxes to merge separately.
+
+ Args:
+ boxes (list[np array[n x 4]]): List of boxes. One list per model.
+ scores (list[np array[n]]): List of confidences.
+ labels (list[np array[n]]): List of labels.
+ weights (list): Model weights.
+ thr (float): Confidence threshold
+ class_agnostic (bool, optional): If True, merge boxes from different classes. Defaults to False.
+
+ Returns:
+ dict[np array [? x 8]]: Filtered boxes.
+ """
+ # Create dict with boxes stored by its label
+ new_boxes = dict()
+
+ for t in range(len(boxes)):
+        if len(boxes[t]) != len(scores[t]):
+            raise ValueError(
+                "Length of boxes array not equal to length of scores array: {} != {}".format(
+                    len(boxes[t]), len(scores[t])
+                )
+            )
+
+        if len(boxes[t]) != len(labels[t]):
+            raise ValueError(
+                "Length of boxes array not equal to length of labels array: {} != {}".format(
+                    len(boxes[t]), len(labels[t])
+                )
+            )
+
+ for j in range(len(boxes[t])):
+ score = scores[t][j]
+ if score < thr:
+ continue
+ label = int(labels[t][j])
+ box_part = boxes[t][j]
+ x1 = float(box_part[0])
+ y1 = float(box_part[1])
+ x2 = float(box_part[2])
+ y2 = float(box_part[3])
+
+ # Box data checks
+ if x2 < x1:
+ warnings.warn("X2 < X1 value in box. Swap them.")
+ x1, x2 = x2, x1
+ if y2 < y1:
+ warnings.warn("Y2 < Y1 value in box. Swap them.")
+ y1, y2 = y2, y1
+ if x1 < 0:
+ warnings.warn("X1 < 0 in box. Set it to 0.")
+ x1 = 0
+ if x1 > 1:
+ warnings.warn("X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+ x1 = 1
+ if x2 < 0:
+ warnings.warn("X2 < 0 in box. Set it to 0.")
+ x2 = 0
+ if x2 > 1:
+ warnings.warn("X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+ x2 = 1
+ if y1 < 0:
+ warnings.warn("Y1 < 0 in box. Set it to 0.")
+ y1 = 0
+ if y1 > 1:
+ warnings.warn("Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+ y1 = 1
+ if y2 < 0:
+ warnings.warn("Y2 < 0 in box. Set it to 0.")
+ y2 = 0
+ if y2 > 1:
+ warnings.warn("Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.")
+ y2 = 1
+ if (x2 - x1) * (y2 - y1) == 0.0:
+ warnings.warn("Zero area box skipped: {}.".format(box_part))
+ continue
+
+ # [label, score, weight, model index, x1, y1, x2, y2]
+ b = [int(label), float(score) * weights[t], weights[t], t, x1, y1, x2, y2]
+
+ label_k = "*" if class_agnostic else label
+ if label_k not in new_boxes:
+ new_boxes[label_k] = []
+ new_boxes[label_k].append(b)
+
+ # Sort each list in dict by score and transform it to numpy array
+ for k in new_boxes:
+ current_boxes = np.array(new_boxes[k])
+ new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
+
+ return new_boxes
+
+
+def find_matching_box_fast(boxes_list, new_box, match_iou):
+ """
+ Reimplementation of find_matching_box with numpy instead of loops. Gives significant speed up for larger arrays
+ (~100x). This was previously the bottleneck since the function is called for every entry in the array.
+ """
+
+ def bb_iou_array(boxes, new_box):
+        # bounding box intersection over union
+ xA = np.maximum(boxes[:, 0], new_box[0])
+ yA = np.maximum(boxes[:, 1], new_box[1])
+ xB = np.minimum(boxes[:, 2], new_box[2])
+ yB = np.minimum(boxes[:, 3], new_box[3])
+
+ interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
+
+ # compute the area of both the prediction and ground-truth rectangles
+ boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+ boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
+
+ iou = interArea / (boxAArea + boxBArea - interArea)
+
+ return iou
+
+ if boxes_list.shape[0] == 0:
+ return -1, match_iou
+
+ ious = bb_iou_array(boxes_list[:, 4:], new_box[4:])
+ # ious[boxes[:, 0] != new_box[0]] = -1
+
+ best_idx = np.argmax(ious)
+ best_iou = ious[best_idx]
+
+ if best_iou <= match_iou:
+ best_iou = match_iou
+ best_idx = -1
+
+ return best_idx, best_iou
+
+
+def get_biggest_box(boxes, conf_type="avg"):
+ """
+ Merges boxes by using the biggest box.
+
+ Args:
+ boxes (np array [n x 8]): Boxes to merge.
+ conf_type (str, optional): Confidence merging type. Defaults to "avg".
+
+ Returns:
+ np array [8]: Merged box.
+ """
+ box = np.zeros(8, dtype=np.float32)
+ box[4:] = boxes[0][4:]
+ conf_list = []
+ w = 0
+ for b in boxes:
+ box[4] = min(box[4], b[4])
+ box[5] = min(box[5], b[5])
+ box[6] = max(box[6], b[6])
+ box[7] = max(box[7], b[7])
+ conf_list.append(b[1])
+ w += b[2]
+
+ box[0] = merge_labels(np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes]))
+
+ box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
+ box[2] = w
+ box[3] = -1 # model index field is retained for consistency but is not used.
+ return box
+
+
+def merge_labels(labels, confs):
+ """
+ Custom function for merging labels.
+ If all labels are the same, return the unique value.
+ Else, return the label of the most confident non-title (class 2) box.
+
+ Args:
+ labels (np array [n]): Labels.
+ confs (np array [n]): Confidence.
+
+ Returns:
+ int: Label.
+ """
+ if len(np.unique(labels)) == 1:
+ return labels[0]
+ else: # Most confident and not a title
+        confs = confs[labels != 2]
+ labels = labels[labels != 2]
+ return labels[np.argmax(confs)]
+
+
+def match_with_title(chart_bbox, title_bboxes, iou_th=0.01):
+ if not len(title_bboxes):
+ return None
+
+ dist_above = np.abs(title_bboxes[:, 3] - chart_bbox[1])
+ dist_below = np.abs(chart_bbox[3] - title_bboxes[:, 1])
+
+ dist_left = np.abs(title_bboxes[:, 0] - chart_bbox[0])
+
+ ious = bb_iou_array(title_bboxes, chart_bbox)
+
+ matches = None
+ if np.max(ious) > iou_th:
+ matches = np.where(ious > iou_th)[0]
+ else:
+ dists = np.min([dist_above, dist_below], 0)
+ dists += dist_left
+ if np.min(dists) < 0.1:
+ matches = [np.argmin(dists)]
+
+ if matches is not None:
+ new_bbox = chart_bbox
+ for match in matches:
+ new_bbox = merge_boxes(new_bbox, title_bboxes[match])
+ title_bboxes = title_bboxes[[i for i in range(len(title_bboxes)) if i not in matches]]
+ return new_bbox, title_bboxes
+
+ else:
+ return None
+
+
+def bb_iou_array(boxes, new_box):
+    # bounding box intersection over union
+ xA = np.maximum(boxes[:, 0], new_box[0])
+ yA = np.maximum(boxes[:, 1], new_box[1])
+ xB = np.minimum(boxes[:, 2], new_box[2])
+ yB = np.minimum(boxes[:, 3], new_box[3])
+
+ interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
+
+ # compute the area of both the prediction and ground-truth rectangles
+ boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+ boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
+
+ iou = interArea / (boxAArea + boxBArea - interArea)
+
+ return iou
+
+
+def merge_boxes(b1, b2):
+ b = b1.copy()
+ b[0] = min(b1[0], b2[0])
+ b[1] = min(b1[1], b2[1])
+ b[2] = max(b1[2], b2[2])
+ b[3] = max(b1[3], b2[3])
+ return b
+
+
+def expand_boxes(boxes, r_x=1, r_y=1):
+ dw = (boxes[:, 2] - boxes[:, 0]) / 2 * (r_x - 1)
+ boxes[:, 0] -= dw
+ boxes[:, 2] += dw
+
+ dh = (boxes[:, 3] - boxes[:, 1]) / 2 * (r_y - 1)
+ boxes[:, 1] -= dh
+ boxes[:, 3] += dh
+
+ boxes = np.clip(boxes, 0, 1)
+ return boxes
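+
+# Illustrative example: expand_boxes(np.array([[0.2, 0.2, 0.4, 0.6]]), r_x=1.1, r_y=1.25) widens the
+# box by 0.01 on each side in x and 0.05 in y, giving [0.19, 0.15, 0.41, 0.65] (clipped to [0, 1]).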
+
+
+def get_weighted_box(boxes, conf_type="avg"):
+ """
+ Merges boxes by using the weighted fusion.
+
+ Args:
+ boxes (np array [n x 8]): Boxes to merge.
+ conf_type (str, optional): Confidence merging type. Defaults to "avg".
+
+ Returns:
+ np array [8]: Merged box.
+ """
+ box = np.zeros(8, dtype=np.float32)
+ conf = 0
+ conf_list = []
+ w = 0
+ for b in boxes:
+ box[4:] += b[1] * b[4:]
+ conf += b[1]
+ conf_list.append(b[1])
+ w += b[2]
+
+ box[0] = merge_labels(np.array([b[0] for b in boxes]), np.array([b[1] for b in boxes]))
+
+ box[1] = np.max(conf_list) if conf_type == "max" else np.mean(conf_list)
+ box[2] = w
+ box[3] = -1 # model index field is retained for consistency but is not used.
+ box[4:] /= conf
+ return box
+
+
+def batched_overlaps(A, B):
+ """
+    Calculate the overlap between two sets of bounding boxes in a batched manner.
+    Unlike a true IoU, normalization uses only the area of the A boxes, so each value measures
+    how much of an A box is covered by a B box.
+ Args:
+ A (ndarray): Array of bounding boxes of shape (N, 4) in format [x1, y1, x2, y2].
+ B (ndarray): Array of bounding boxes of shape (M, 4) in format [x1, y1, x2, y2].
+    Returns:
+        ndarray: Array of overlap values of shape (M, N), where entry (i, j) is the intersection
+            of A[j] and B[i] divided by the area of A[j].
+ """
+ A = A.copy()
+ B = B.copy()
+
+ A = A[None].repeat(B.shape[0], 0)
+ B = B[:, None].repeat(A.shape[1], 1)
+
+ low = np.s_[..., :2]
+ high = np.s_[..., 2:]
+
+ A, B = A.copy(), B.copy()
+ A[high] += 1
+ B[high] += 1
+
+ intrs = (np.maximum(0, np.minimum(A[high], B[high]) - np.maximum(A[low], B[low]))).prod(-1)
+ ious = intrs / (A[high] - A[low]).prod(-1)
+
+ return ious
+
+
+def find_boxes_inside(boxes, boxes_to_check, threshold=0.9):
+ """
+    Remove boxes from `boxes_to_check` that are contained inside another box, where containment
+    is measured as the intersection area divided by the area of the checked box.
+ """
+ overlaps = batched_overlaps(boxes_to_check, boxes)
+ to_keep = (overlaps >= threshold).sum(0) <= 1
+ return boxes_to_check[to_keep]
+
+
+def get_bbox_dict_yolox_graphic(preds, shape, class_labels, threshold_=0.1) -> Dict[str, np.ndarray]:
+ """
+ Extracts bounding boxes from YOLOX model predictions:
+ - Applies thresholding
+ - Reformats boxes
+ - Cleans the `other` detections: removes the ones that are included in other detections.
+ - If no title is found, the biggest `other` box is used if it is larger than 0.3*img_w.
+ Args:
+ preds (np.ndarray): YOLOX model predictions including bounding boxes, scores, and labels.
+        shape (tuple): Original image shape.
+        class_labels (list of str): Class labels predicted by the model.
+        threshold_ (float): Score threshold to filter bounding boxes.
+ Returns:
+ Dict[str, np.ndarray]: Dictionary of bounding boxes, organized by class.
+ """
+ bbox_dict = {label: np.array([]) for label in class_labels}
+
+ for i, label in enumerate(class_labels):
+ bboxes_class = np.array(preds[label])
+
+ if bboxes_class.size == 0:
+ continue
+
+ # Try to find a chart_title box
+ threshold = threshold_ if label != "chart_title" else min(threshold_, bboxes_class[:, -1].max())
+ bboxes_class = bboxes_class[bboxes_class[:, -1] >= threshold][:, :4].astype(int)
+
+ sort = ["x0", "y0"] if label != "ylabel" else ["y0", "x0"]
+ idxs = (
+ pd.DataFrame(
+ {
+ "y0": bboxes_class[:, 1],
+ "x0": bboxes_class[:, 0],
+ }
+ )
+ .sort_values(sort, ascending=label != "ylabel")
+ .index
+ )
+ bboxes_class = bboxes_class[idxs]
+ bbox_dict[label] = bboxes_class
+
+ # Remove other included
+ if len(bbox_dict.get("other", [])):
+ other = find_boxes_inside(
+ np.concatenate(list([v for v in bbox_dict.values() if len(v)])), bbox_dict["other"], threshold=0.7
+ )
+ del bbox_dict["other"]
+ if len(other):
+ bbox_dict["other"] = other
+
+ # Biggest other is title if no title
+ if not len(bbox_dict.get("chart_title", [])) and len(bbox_dict.get("other", [])):
+ boxes = bbox_dict["other"]
+ ws = boxes[:, 2] - boxes[:, 0]
+ if np.max(ws) > shape[1] * 0.3:
+ bbox_dict["chart_title"] = boxes[np.argmax(ws)][None].copy()
+ bbox_dict["other"] = np.delete(boxes, (np.argmax(ws)), axis=0)
+
+ # Make sure other key not lost
+ bbox_dict["other"] = bbox_dict.get("other", [])
+
+ return bbox_dict