From 82a8e81459f43a73559994faca0fa4e31a54984f Mon Sep 17 00:00:00 2001 From: edknv Date: Sat, 18 Jan 2025 22:05:49 -0800 Subject: [PATCH 01/28] parent class for yolox base model interface --- src/nv_ingest/util/nim/yolox.py | 150 +++++++++++++++++++++++--------- 1 file changed, 109 insertions(+), 41 deletions(-) diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py index 831c4e62..bf116138 100644 --- a/src/nv_ingest/util/nim/yolox.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -23,39 +23,58 @@ logger = logging.getLogger(__name__) -YOLOX_MAX_BATCH_SIZE = 8 -YOLOX_MAX_WIDTH = 1536 -YOLOX_MAX_HEIGHT = 1536 -YOLOX_NUM_CLASSES = 3 -YOLOX_CONF_THRESHOLD = 0.01 -YOLOX_IOU_THRESHOLD = 0.5 -YOLOX_MIN_SCORE = 0.1 -YOLOX_FINAL_SCORE = 0.48 -YOLOX_NIM_MAX_IMAGE_SIZE = 360_000 - -YOLOX_IMAGE_PREPROC_HEIGHT = 1024 -YOLOX_IMAGE_PREPROC_WIDTH = 1024 - - -# Implementing YoloxPageElemenetsModelInterface with required methods -class YoloxPageElementsModelInterface(ModelInterface): +YOLOX_PAGE_MAX_BATCH_SIZE = 8 +YOLOX_PAGE_MAX_WIDTH = 1536 +YOLOX_PAGE_MAX_HEIGHT = 1536 +YOLOX_PAGE_NUM_CLASSES = 3 +YOLOX_PAGE_CONF_THRESHOLD = 0.01 +YOLOX_PAGE_IOU_THRESHOLD = 0.5 +YOLOX_PAGE_MIN_SCORE = 0.1 +YOLOX_PAGE_FINAL_SCORE = 0.48 +YOLOX_PAGE_NIM_MAX_IMAGE_SIZE = 360_000 + +YOLOX_PAGE_IMAGE_PREPROC_HEIGHT = 1024 +YOLOX_PAGE_IMAGE_PREPROC_WIDTH = 1024 + +YOLOX_PAGE_CLASS_LABELS = [ + "table", + "chart", + "title", +] + + +# YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements +class YoloxModelInterfaceBase(ModelInterface): """ An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols. """ - def name( + def __init__( self, - ) -> str: + image_preproc_width: Optional[int] = None, + image_preproc_height: Optional[int] = None, + nim_max_image_size: Optional[int] = None, + num_classes: Optional[int] = None, + conf_threshold: Optional[float] = None, + iou_threshold: Optional[float] = None, + min_score: Optional[float] = None, + final_score: Optional[float] = None, + class_labels: Optional[List[str]] = None, + ): """ - Returns the name of the Yolox model interface. - - Returns - ------- - str - The name of the model interface. + Initialize the YOLOX model interface. 
+ Parameters + ---------- """ - - return "yolox-page-elements" + self.image_preproc_width = image_preproc_width + self.image_preproc_height = image_preproc_height + self.nim_max_image_size = nim_max_image_size + self.num_classes = num_classes + self.conf_threshold = conf_threshold + self.iou_threshold = iou_threshold + self.min_score = min_score + self.final_score = final_score + self.class_labels = class_labels def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: """ @@ -107,8 +126,9 @@ def format_input(self, data: Dict[str, Any], protocol: str) -> Any: if protocol == "grpc": logger.debug("Formatting input for gRPC Yolox model") # Our yolox-page-elements model (grPC) expects images to be resized to 1024x1024 + # Our yolox-graphic-elements model (grPC) expects images to be resized to 768x768 resized_images = [ - resize_image(image, (YOLOX_IMAGE_PREPROC_WIDTH, YOLOX_IMAGE_PREPROC_HEIGHT)) for image in data["images"] + resize_image(image, (self.image_preproc_width, self.image_preproc_height)) for image in data["images"] ] # Reorder axes to match model input (batch, channels, height, width) input_array = np.einsum("bijk->bkij", resized_images).astype(np.float32) @@ -129,7 +149,7 @@ def format_input(self, data: Dict[str, Any], protocol: str) -> Any: # Now scale the image if necessary scaled_image_b64, new_size = scale_image_to_encoding_size( - image_b64, max_base64_size=YOLOX_NIM_MAX_IMAGE_SIZE + image_b64, max_base64_size=self.nim_max_image_size ) if new_size != original_size: @@ -216,11 +236,6 @@ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Lis A list of annotation dictionaries for each image in the batch. """ original_image_shapes = kwargs.get("original_image_shapes", []) - num_classes = kwargs.get("num_classes", YOLOX_NUM_CLASSES) - conf_thresh = kwargs.get("conf_thresh", YOLOX_CONF_THRESHOLD) - iou_thresh = kwargs.get("iou_thresh", YOLOX_IOU_THRESHOLD) - min_score = kwargs.get("min_score", YOLOX_MIN_SCORE) - final_thresh = kwargs.get("final_thresh", YOLOX_FINAL_SCORE) if protocol == "http": # For http, the output already has postprocessing applied. Skip to table/chart expansion. @@ -228,11 +243,64 @@ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Lis elif protocol == "grpc": # For grpc, apply the same NIM postprocessing. - pred = postprocess_model_prediction(output, num_classes, conf_thresh, iou_thresh, class_agnostic=True) - results = postprocess_results(pred, original_image_shapes, min_score=min_score) + pred = postprocess_model_prediction( + output, self.num_classes, self.conf_threshold, self.iou_threshold, class_agnostic=True + ) + results = postprocess_results( + pred, + original_image_shapes, + self.image_preproc_width, + self.image_preproc_height, + min_score=self.min_score, + ) + + inference_results = self.postprocess_annotations(results, **kwargs) + + return inference_results + + def postprocess_annotations(self, annotation_dicts: Dict[str, Any], **kwargs) -> Dict[str, Any]: + raise NotImplementedError() + + +# Implementing YoloxPageElemenetsModelInterface with required methods +class YoloxPageElementsModelInterface(YoloxModelInterfaceBase): + """ + An interface for handling inference with a Yolox object detection model, supporting both gRPC and HTTP protocols. + """ + + def __init__(self): + """ + Initialize the yolox-page-elements model interface. 
+ """ + super().__init__( + image_preproc_width=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT, + image_preproc_height=YOLOX_PAGE_IMAGE_PREPROC_HEIGHT, + nim_max_image_size=YOLOX_PAGE_NIM_MAX_IMAGE_SIZE, + num_classes=YOLOX_PAGE_NUM_CLASSES, + conf_threshold=YOLOX_PAGE_CONF_THRESHOLD, + iou_threshold=YOLOX_PAGE_IOU_THRESHOLD, + min_score=YOLOX_PAGE_MIN_SCORE, + final_score=YOLOX_PAGE_FINAL_SCORE, + class_labels=YOLOX_PAGE_CLASS_LABELS, + ) + + def name( + self, + ) -> str: + """ + Returns the name of the Yolox model interface. + + Returns + ------- + str + The name of the model interface. + """ + + return "yolox-page-elements" + def postprocess_annotations(self, annotation_dicts: Dict[str, Any], **kwargs) -> Dict[str, Any]: # Table/chart expansion is "business logic" specific to nv-ingest - annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in results] + annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in annotation_dicts] annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in annotation_dicts] inference_results = [] @@ -241,9 +309,9 @@ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Lis for annotation_dict in annotation_dicts: new_dict = {} if "table" in annotation_dict: - new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= final_thresh] + new_dict["table"] = [bb for bb in annotation_dict["table"] if bb[4] >= self.final_score] if "chart" in annotation_dict: - new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= final_thresh] + new_dict["chart"] = [bb for bb in annotation_dict["chart"] if bb[4] >= self.final_score] if "title" in annotation_dict: new_dict["title"] = annotation_dict["title"] inference_results.append(new_dict) @@ -310,7 +378,7 @@ def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thr return output -def postprocess_results(results, original_image_shapes, min_score=0.0): +def postprocess_results(results, original_image_shapes, image_preproc_width, image_preproc_height, min_score=0.0): """ For each item (==image) in results, computes annotations in the form @@ -339,8 +407,8 @@ def postprocess_results(results, original_image_shapes, min_score=0.0): # ratio is used when image was padded ratio = min( - YOLOX_IMAGE_PREPROC_WIDTH / original_image_shape[0], - YOLOX_IMAGE_PREPROC_HEIGHT / original_image_shape[1], + image_preproc_width / original_image_shape[0], + image_preproc_height / original_image_shape[1], ) bboxes = result[:, :4] / ratio From 0699844a8053282f3e954c4c191b43707416ed54 Mon Sep 17 00:00:00 2001 From: edknv Date: Sat, 18 Jan 2025 22:42:07 -0800 Subject: [PATCH 02/28] add model interface for yolox-graphic-elements --- .../image/image_handlers.py | 2 +- .../extraction_workflows/pdf/pdfium_helper.py | 6 +- src/nv_ingest/util/nim/yolox.py | 216 +++++++++++++++++- 3 files changed, 213 insertions(+), 11 deletions(-) diff --git a/src/nv_ingest/extraction_workflows/image/image_handlers.py b/src/nv_ingest/extraction_workflows/image/image_handlers.py index f7b12982..66252ef4 100644 --- a/src/nv_ingest/extraction_workflows/image/image_handlers.py +++ b/src/nv_ingest/extraction_workflows/image/image_handlers.py @@ -159,7 +159,7 @@ def extract_table_and_chart_images( objects = annotation_dict[label] for idx, bboxes in enumerate(objects): *bbox, _ = bboxes - h1, w1, h2, w2 = np.array(bbox) * np.array([height, width, height, width]) + h1, w1, h2, w2 = bbox base64_img = crop_image(original_image, (int(h1), int(w1), int(h2), 
int(w2))) diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py index ad4de2d6..a717442e 100644 --- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py @@ -170,15 +170,15 @@ def extract_table_and_chart_images( objects = annotation_dict[label] for idx, bboxes in enumerate(objects): *bbox, _ = bboxes - h1, w1, h2, w2 = bbox * np.array([height, width, height, width]) + h1, w1, h2, w2 = bbox - cropped = crop_image(original_image, (h1, w1, h2, w2)) + cropped = crop_image(original_image, (int(h1), int(w1), int(h2), int(w2))) base64_img = numpy_to_base64(cropped) table_data = CroppedImageWithContent( content="", image=base64_img, - bbox=(w1, h1, w2, h2), + bbox=(int(w1), int(h1), int(w2), int(h2)), max_width=width, max_height=height, type_string=label, diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py index bf116138..ebfdd31a 100644 --- a/src/nv_ingest/util/nim/yolox.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -14,6 +14,7 @@ import cv2 import numpy as np +import pandas as pd import torch import torchvision from PIL import Image @@ -23,9 +24,7 @@ logger = logging.getLogger(__name__) -YOLOX_PAGE_MAX_BATCH_SIZE = 8 -YOLOX_PAGE_MAX_WIDTH = 1536 -YOLOX_PAGE_MAX_HEIGHT = 1536 +# yolox-page-elements-v1 contants YOLOX_PAGE_NUM_CLASSES = 3 YOLOX_PAGE_CONF_THRESHOLD = 0.01 YOLOX_PAGE_IOU_THRESHOLD = 0.5 @@ -42,6 +41,30 @@ "title", ] +# yolox-graphic-elements-v1 contants +YOLOX_GRAPHIC_NUM_CLASSES = 10 +YOLOX_GRAPHIC_CONF_THRESHOLD = 0.01 +YOLOX_GRAPHIC_IOU_THRESHOLD = 0.25 +YOLOX_GRAPHIC_MIN_SCORE = 0.1 +YOLOX_GRAPHIC_FINAL_SCORE = 0.0 +YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE = 360_000 + +YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT = 768 +YOLOX_GRAPHIC_IMAGE_PREPROC_WIDTH = 768 + +YOLOX_GRAPHIC_CLASS_LABELS = [ + "chart_title", + "x_title", + "y_title", + "xlabel", + "ylabel", + "other", + "legend_label", + "legend_title", + "mark_label", + "value_label", +] + # YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements class YoloxModelInterfaceBase(ModelInterface): @@ -255,17 +278,37 @@ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Lis ) inference_results = self.postprocess_annotations(results, **kwargs) + inference_results = self.transform_normalized_coordinates_to_original(inference_results, original_image_shapes) return inference_results - def postprocess_annotations(self, annotation_dicts: Dict[str, Any], **kwargs) -> Dict[str, Any]: + def postprocess_annotations(self, annotation_dicts: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: raise NotImplementedError() + def transform_normalized_coordinates_to_original(self, results, original_image_shapes): + """ """ + transformed_results = [] + + for annotation_dict, shape in zip(results, original_image_shapes): + new_dict = {} + for label, bboxes_and_scores in annotation_dict.items(): + new_dict[label] = [] + for *bbox, score in bboxes_and_scores: + transformed_bbox = [ + bbox[0] * shape[1], + bbox[1] * shape[0], + bbox[2] * shape[1], + bbox[3] * shape[0], + ] + new_dict[label].append(transformed_bbox + [score]) + transformed_results.append(new_dict) + + return transformed_results + -# Implementing YoloxPageElemenetsModelInterface with required methods class YoloxPageElementsModelInterface(YoloxModelInterfaceBase): """ - An interface for handling inference with a Yolox object detection model, supporting both gRPC and 
HTTP protocols. + An interface for handling inference with yolox-page-elements model, supporting both gRPC and HTTP protocols. """ def __init__(self): @@ -298,7 +341,7 @@ def name( return "yolox-page-elements" - def postprocess_annotations(self, annotation_dicts: Dict[str, Any], **kwargs) -> Dict[str, Any]: + def postprocess_annotations(self, annotation_dicts: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: # Table/chart expansion is "business logic" specific to nv-ingest annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in annotation_dicts] annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in annotation_dicts] @@ -319,6 +362,60 @@ def postprocess_annotations(self, annotation_dicts: Dict[str, Any], **kwargs) -> return inference_results +class YoloxGraphicElementsModelInterface(YoloxModelInterfaceBase): + """ + An interface for handling inference with yolox-graphic-elemenents model, supporting both gRPC and HTTP protocols. + """ + + def __init__(self): + """ + Initialize the yolox-graphic-elements model interface. + """ + super().__init__( + image_preproc_width=YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT, + image_preproc_height=YOLOX_GRAPHIC_IMAGE_PREPROC_HEIGHT, + nim_max_image_size=YOLOX_GRAPHIC_NIM_MAX_IMAGE_SIZE, + num_classes=YOLOX_GRAPHIC_NUM_CLASSES, + conf_threshold=YOLOX_GRAPHIC_CONF_THRESHOLD, + iou_threshold=YOLOX_GRAPHIC_IOU_THRESHOLD, + min_score=YOLOX_GRAPHIC_MIN_SCORE, + final_score=YOLOX_GRAPHIC_FINAL_SCORE, + class_labels=YOLOX_GRAPHIC_CLASS_LABELS, + ) + + def name( + self, + ) -> str: + """ + Returns the name of the Yolox model interface. + + Returns + ------- + str + The name of the model interface. + """ + + return "yolox-graphic-elements" + + def postprocess_annotations(self, annotation_dicts: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: + original_image_shapes = kwargs.get("original_image_shapes", []) + + inference_results = [] + + # bbox extraction: additional postprocessing speicifc to nv-ingest + for pred, shape in zip(annotation_dicts, original_image_shapes): + inference_results.append( + get_bbox_dict_yolox_graphic( + pred, + shape, + self.class_labels, + self.min_score, + ) + ) + + return inference_results + + def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, class_agnostic=False): # Convert numpy array to torch tensor prediction = torch.from_numpy(prediction.copy()) @@ -926,3 +1023,108 @@ def get_weighted_box(boxes, conf_type="avg"): box[3] = -1 # model index field is retained for consistency but is not used. box[4:] /= conf return box + + +def batched_overlaps(A, B): + """ + Calculate the Intersection over Union (IoU) between + two sets of bounding boxes in a batched manner. + Normalization is modified to only use the area of A boxes, hence computing the overlaps. + Args: + A (ndarray): Array of bounding boxes of shape (N, 4) in format [x1, y1, x2, y2]. + B (ndarray): Array of bounding boxes of shape (M, 4) in format [x1, y1, x2, y2]. + Returns: + ndarray: Array of IoU values of shape (N, M) representing the overlaps + between each pair of bounding boxes. 
+ """ + A = A.copy() + B = B.copy() + + A = A[None].repeat(B.shape[0], 0) + B = B[:, None].repeat(A.shape[1], 1) + + low = np.s_[..., :2] + high = np.s_[..., 2:] + + A, B = A.copy(), B.copy() + A[high] += 1 + B[high] += 1 + + intrs = (np.maximum(0, np.minimum(A[high], B[high]) - np.maximum(A[low], B[low]))).prod(-1) + ious = intrs / (A[high] - A[low]).prod(-1) + + return ious + + +def find_boxes_inside(boxes, boxes_to_check, threshold=0.9): + """ + Find all boxes that are inside another box based on + the intersection area divided by the area of the smaller box, + and removes them. + """ + overlaps = batched_overlaps(boxes_to_check, boxes) + to_keep = (overlaps >= threshold).sum(0) <= 1 + return boxes_to_check[to_keep] + + +def get_bbox_dict_yolox_graphic(preds, shape, class_labels, threshold_=0.1) -> Dict[str, np.ndarray]: + """ + Extracts bounding boxes from YOLOX model predictions: + - Applies thresholding + - Reformats boxes + - Cleans the `other` detections: removes the ones that are included in other detections. + - If no title is found, the biggest `other` box is used if it is larger than 0.3*img_w. + Args: + preds (np.ndarray): YOLOX model predictions including bounding boxes, scores, and labels. + shape (tuple): Original image shape. + threshold_ (float): Score threshold to filter bounding boxes. + Returns: + Dict[str, np.ndarray]: Dictionary of bounding boxes, organized by class. + """ + bbox_dict = {label: np.array([]) for label in class_labels} + + for i, label in enumerate(class_labels): + bboxes_class = np.array(preds[label]) + + if bboxes_class.size == 0: + continue + + # Try to find a chart_title box + threshold = threshold_ if label != "chart_title" else min(threshold_, bboxes_class[:, -1].max()) + bboxes_class = bboxes_class[bboxes_class[:, -1] >= threshold][:, :4].astype(int) + + sort = ["x0", "y0"] if label != "ylabel" else ["y0", "x0"] + idxs = ( + pd.DataFrame( + { + "y0": bboxes_class[:, 1], + "x0": bboxes_class[:, 0], + } + ) + .sort_values(sort, ascending=label != "ylabel") + .index + ) + bboxes_class = bboxes_class[idxs] + bbox_dict[label] = bboxes_class + + # Remove other included + if len(bbox_dict.get("other", [])): + other = find_boxes_inside( + np.concatenate(list([v for v in bbox_dict.values() if len(v)])), bbox_dict["other"], threshold=0.7 + ) + del bbox_dict["other"] + if len(other): + bbox_dict["other"] = other + + # Biggest other is title if no title + if not len(bbox_dict.get("chart_title", [])) and len(bbox_dict.get("other", [])): + boxes = bbox_dict["other"] + ws = boxes[:, 2] - boxes[:, 0] + if np.max(ws) > shape[1] * 0.3: + bbox_dict["chart_title"] = boxes[np.argmax(ws)][None].copy() + bbox_dict["other"] = np.delete(boxes, (np.argmax(ws)), axis=0) + + # Make sure other key not lost + bbox_dict["other"] = bbox_dict.get("other", []) + + return bbox_dict From a4a931787ab95baef63909e1d2acad625cf879fd Mon Sep 17 00:00:00 2001 From: edknv Date: Sun, 19 Jan 2025 01:51:50 -0800 Subject: [PATCH 03/28] replace cached with yolox --- docker-compose.yaml | 44 +--- src/nv_ingest/api/v1/health.py | 20 +- .../schemas/chart_extractor_schema.py | 20 +- src/nv_ingest/stages/nim/chart_extraction.py | 110 +++++---- src/nv_ingest/stages/nim/table_extraction.py | 36 ++- .../util/image_processing/table_and_chart.py | 217 +++++++++++++----- src/nv_ingest/util/nim/helpers.py | 23 +- src/nv_ingest/util/nim/paddle.py | 155 ++++--------- src/nv_ingest/util/nim/yolox.py | 42 ++-- src/nv_ingest/util/pipeline/stage_builders.py | 12 +- 10 files changed, 366 insertions(+), 313 
deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 8d9c307a..13ec904a 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -38,8 +38,8 @@ services: capabilities: [gpu] runtime: nvidia - deplot: - image: ${DEPLOT_IMAGE:-nvcr.io/ohlfw0olaadg/ea-participants/deplot}:${DEPLOT_TAG:-1.0.0} + yolox-graphic-elements: + image: ${YOLOX_GRAPHIC_ELEMENTS_IMAGE:-nvcr.io/ohlfw0olaadg/ea-participants/nv-yolox-graphic-elements-v1}:${YOLOX_GRAPHIC_ELEMENTS_TAG:-1.1.0} ports: - "8003:8000" - "8004:8001" @@ -48,7 +48,7 @@ services: environment: - NIM_HTTP_API_PORT=8000 - NIM_TRITON_LOG_VERBOSE=1 - - NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}} + - NGC_API_KEY=${STAGING_NIM_NGC_API_KEY} - CUDA_VISIBLE_DEVICES=0 deploy: resources: @@ -59,28 +59,6 @@ services: capabilities: [gpu] runtime: nvidia - cached: - image: ${CACHED_IMAGE:-nvcr.io/ohlfw0olaadg/ea-participants/cached}:${CACHED_TAG:-0.2.1} - shm_size: 2gb - ports: - - "8006:8000" - - "8007:8001" - - "8008:8002" - user: root - environment: - - NIM_HTTP_API_PORT=8000 - - NIM_TRITON_LOG_VERBOSE=1 - - NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}} - - CUDA_VISIBLE_DEVICES=0 - deploy: - resources: - reservations: - devices: - - driver: nvidia - device_ids: ["1"] - capabilities: [gpu] - runtime: nvidia - paddle: image: ${PADDLE_IMAGE:-nvcr.io/ohlfw0olaadg/ea-participants/paddleocr}:${PADDLE_TAG:-1.0.0} shm_size: 2gb @@ -141,20 +119,7 @@ services: cap_add: - sys_nice environment: - # Self-hosted cached endpoints. - - CACHED_GRPC_ENDPOINT=cached:8001 - - CACHED_HTTP_ENDPOINT=http://cached:8000/v1/infer - - CACHED_INFER_PROTOCOL=grpc - # build.nvidia.com hosted cached endpoints. - #- CACHED_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/cv/university-at-buffalo/cached - #- CACHED_INFER_PROTOCOL=http - CUDA_VISIBLE_DEVICES=0 - #- DEPLOT_GRPC_ENDPOINT="" - # Self-hosted deplot endpoints. - - DEPLOT_HTTP_ENDPOINT=http://deplot:8000/v1/chat/completions - # build.nvidia.com hosted deplot - #- DEPLOT_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/vlm/google/deplot - - DEPLOT_INFER_PROTOCOL=http - DOUGHNUT_GRPC_TRITON=triton-doughnut:8001 - EMBEDDING_NIM_MODEL_NAME=${EMBEDDING_NIM_MODEL_NAME:-nvidia/nv-embedqa-e5-v5} - INGEST_LOG_LEVEL=DEFAULT @@ -187,6 +152,9 @@ services: # build.nvidia.com hosted yolox endpoints. 
#- YOLOX_HTTP_ENDPOINT=https://ai.api.nvidia.com/v1/cv/nvidia/nv-yolox-page-elements-v1 #- YOLOX_INFER_PROTOCOL=http + - YOLOX_GRAPHIC_ELEMENTS_GRPC_ENDPOINT=yolox-graphic-elements:8001 + - YOLOX_GRAPHIC_ELEMENTS_HTTP_ENDPOINT=http://yolox-graphic-elements:8000/v1/infer + - YOLOX_GRAPHIC_ELEMENTS_INFER_PROTOCOL=grpc - VLM_CAPTION_ENDPOINT=https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct/chat/completions healthcheck: test: curl --fail http://nv-ingest-ms-runtime:7670/v1/health/ready || exit 1 diff --git a/src/nv_ingest/api/v1/health.py b/src/nv_ingest/api/v1/health.py index 7a50dba2..5af30dc3 100644 --- a/src/nv_ingest/api/v1/health.py +++ b/src/nv_ingest/api/v1/health.py @@ -70,20 +70,26 @@ async def get_ready_state() -> dict: # We give the users an option to disable checking all distributed services for "readiness" check_all_components = os.getenv("READY_CHECK_ALL_COMPONENTS", "True").lower() if check_all_components in ["1", "true", "yes"]: - yolox_ready = is_ready(os.getenv("YOLOX_HTTP_ENDPOINT", None), "/v1/health/ready") - deplot_ready = is_ready(os.getenv("DEPLOT_HTTP_ENDPOINT", None), "/v1/health/ready") - cached_ready = is_ready(os.getenv("CACHED_HTTP_ENDPOINT", None), "/v1/health/ready") + yolox_page_elements_ready = is_ready(os.getenv("YOLOX_HTTP_ENDPOINT", None), "/v1/health/ready") + yolox_graphic_elements_ready = is_ready( + os.getenv("YOLOX_GRAPHIC_ELEMENTS_HTTP_ENDPOINT", None), "/v1/health/ready" + ) paddle_ready = is_ready(os.getenv("PADDLE_HTTP_ENDPOINT", None), "/v1/health/ready") - if ingest_ready and morpheus_pipeline_ready and yolox_ready and deplot_ready and cached_ready and paddle_ready: + if ( + ingest_ready + and morpheus_pipeline_ready + and yolox_page_elements_ready + and yolox_graphic_elements_ready + and paddle_ready + ): return JSONResponse(content={"ready": True}, status_code=200) else: ready_statuses = { "ingest_ready": ingest_ready, "morpheus_pipeline_ready": morpheus_pipeline_ready, - "yolox_ready": yolox_ready, - "deplot_ready": deplot_ready, - "cached_ready": cached_ready, + "yolox_page_elemenst_ready": yolox_page_elements_ready, + "yolox_graphic_elements_ready": yolox_graphic_elements_ready, "paddle_ready": paddle_ready, } logger.debug(f"Ready Statuses: {ready_statuses}") diff --git a/src/nv_ingest/schemas/chart_extractor_schema.py b/src/nv_ingest/schemas/chart_extractor_schema.py index 05714bbc..4cf8b2b3 100644 --- a/src/nv_ingest/schemas/chart_extractor_schema.py +++ b/src/nv_ingest/schemas/chart_extractor_schema.py @@ -20,12 +20,8 @@ class ChartExtractorConfigSchema(BaseModel): auth_token : Optional[str], default=None Authentication token required for secure services. - cached_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) - A tuple containing the gRPC and HTTP services for the cached endpoint. - Either the gRPC or HTTP service can be empty, but not both. - - deplot_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) - A tuple containing the gRPC and HTTP services for the deplot endpoint. + yolox_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) + A tuple containing the gRPC and HTTP services for the yolox endpoint. Either the gRPC or HTTP service can be empty, but not both. 
paddle_endpoints : Tuple[Optional[str], Optional[str]], default=(None, None) @@ -50,13 +46,9 @@ class ChartExtractorConfigSchema(BaseModel): auth_token: Optional[str] = None - cached_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) - cached_infer_protocol: str = "" - - deplot_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) - deplot_infer_protocol: str = "" + yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + yolox_infer_protocol: str = "" - ## NOTE: Paddle isn't currently called independently of the cached NIM, but will be in the future. paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) paddle_infer_protocol: str = "" @@ -91,7 +83,7 @@ def clean_service(service): return None return service - for endpoint_name in ["cached_endpoints", "deplot_endpoints", "paddle_endpoints"]: + for endpoint_name in ["yolox_endpoints", "paddle_endpoints"]: grpc_service, http_service = values.get(endpoint_name, (None, None)) grpc_service = clean_service(grpc_service) http_service = clean_service(http_service) @@ -122,7 +114,7 @@ class ChartExtractorSchema(BaseModel): A flag indicating whether to raise an exception if a failure occurs during chart extraction. stage_config : Optional[ChartExtractorConfigSchema], default=None - Configuration for the chart extraction stage, including cached, deplot, and paddle service endpoints. + Configuration for the chart extraction stage, including yolox and paddle service endpoints. """ max_queue_size: int = 1 diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index 3890c62e..0447f145 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -14,17 +14,19 @@ from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage -from nv_ingest.util.image_processing.table_and_chart import join_cached_and_deplot_output -from nv_ingest.util.nim.cached import CachedModelInterface -from nv_ingest.util.nim.deplot import DeplotModelInterface -from nv_ingest.util.nim.helpers import create_inference_client +from nv_ingest.util.image_processing.table_and_chart import join_yolox_and_paddle_output +from nv_ingest.util.image_processing.table_and_chart import process_yolox_graphic_elements +from nv_ingest.util.image_processing.transforms import base64_to_numpy from nv_ingest.util.nim.helpers import NimClient +from nv_ingest.util.nim.helpers import create_inference_client +from nv_ingest.util.nim.paddle import PaddleOCRModelInterface +from nv_ingest.util.nim.yolox import YoloxGraphicElementsModelInterface logger = logging.getLogger(f"morpheus.{__name__}") # Modify the _update_metadata function -def _update_metadata(row: pd.Series, cached_client: NimClient, deplot_client: NimClient, trace_info: Dict) -> Dict: +def _update_metadata(row: pd.Series, yolox_client: NimClient, paddle_client: NimClient, trace_info: Dict) -> Dict: """ Modifies the metadata of a row if the conditions for chart extraction are met. @@ -33,11 +35,11 @@ def _update_metadata(row: pd.Series, cached_client: NimClient, deplot_client: Ni row : pd.Series A row from the DataFrame containing metadata for the chart extraction. - cached_client : NimClient - The client used to call the cached inference model. + yolox_client : NimClient + The client used to call the yolox inference model. - deplot_client : NimClient - The client used to call the deplot inference model. 
+ paddle_client : NimClient + The client used to call the paddle inference model. trace_info : Dict Trace information used for logging or debugging. @@ -72,23 +74,47 @@ def _update_metadata(row: pd.Series, cached_client: NimClient, deplot_client: Ni # Modify chart metadata with the result from the inference models try: - data = {"base64_image": base64_image} + base64_data = {"base64_image": base64_image} + array_data = {"images": [base64_to_numpy(base64_image)]} # Perform inference using the NimClients - deplot_result = deplot_client.infer( - data, - model_name="deplot", - trace_info=trace_info, # traceable_func arg + yolox_result = yolox_client.infer( + array_data, + model_name="yolox", stage_name="chart_data_extraction", # traceable_func arg + trace_info=trace_info, # traceable_func arg ) - cached_result = cached_client.infer( - data, - model_name="cached", + paddle_result = paddle_client.infer( + base64_data, + model_name="paddle", stage_name="chart_data_extraction", # traceable_func arg trace_info=trace_info, # traceable_func arg ) - chart_content = join_cached_and_deplot_output(cached_result, deplot_result) + ### + source_name = metadata.get("source_metadata", {}).get("source_name").split("/")[-1] + page_number = metadata.get("content_metadata", {}).get("page_number") + + out_path = f"/workspace/data/tmp/{source_name}.{page_number}.yolox.pickle" + + import os + import pickle + + if os.path.exists(out_path): + with open(out_path, "rb") as f: + data = pickle.load(f) + else: + data = {"data": []} + + data["data"].append({"base64_image": base64_image, "yolox": yolox_result[0], "paddle": paddle_result}) + + with open(out_path, "wb") as f: + pickle.dump(data, f) + ### + + text_predictions, bounding_boxes = paddle_result + yolox_elements = join_yolox_and_paddle_output(yolox_result[0], text_predictions, bounding_boxes) + chart_content = process_yolox_graphic_elements(yolox_elements) chart_metadata["table_content"] = chart_content except Exception as e: @@ -99,32 +125,32 @@ def _update_metadata(row: pd.Series, cached_client: NimClient, deplot_client: Ni def _create_clients( - cached_endpoints: Tuple[str, str], - cached_protocol: str, - deplot_endpoints: Tuple[str, str], - deplot_protocol: str, + yolox_endpoints: Tuple[str, str], + yolox_protocol: str, + paddle_endpoints: Tuple[str, str], + paddle_protocol: str, auth_token: str, ) -> Tuple[NimClient, NimClient]: - cached_model_interface = CachedModelInterface() - deplot_model_interface = DeplotModelInterface() + yolox_model_interface = YoloxGraphicElementsModelInterface() + paddle_model_interface = PaddleOCRModelInterface() - logger.debug(f"Inference protocols: cached={cached_protocol}, deplot={deplot_protocol}") + logger.debug(f"Inference protocols: yolox={yolox_protocol}, paddle={paddle_protocol}") - cached_client = create_inference_client( - endpoints=cached_endpoints, - model_interface=cached_model_interface, + yolox_client = create_inference_client( + endpoints=yolox_endpoints, + model_interface=yolox_model_interface, auth_token=auth_token, - infer_protocol=cached_protocol, + infer_protocol=yolox_protocol, ) - deplot_client = create_inference_client( - endpoints=deplot_endpoints, - model_interface=deplot_model_interface, + paddle_client = create_inference_client( + endpoints=paddle_endpoints, + model_interface=paddle_model_interface, auth_token=auth_token, - infer_protocol=deplot_protocol, + infer_protocol=paddle_protocol, ) - return cached_client, deplot_client + return yolox_client, paddle_client def _extract_chart_data( @@ -168,11 
+194,11 @@ def _extract_chart_data( return df, trace_info stage_config = validated_config.stage_config - cached_client, deplot_client = _create_clients( - stage_config.cached_endpoints, - stage_config.cached_infer_protocol, - stage_config.deplot_endpoints, - stage_config.deplot_infer_protocol, + yolox_client, paddle_client = _create_clients( + stage_config.yolox_endpoints, + stage_config.yolox_infer_protocol, + stage_config.paddle_endpoints, + stage_config.paddle_infer_protocol, stage_config.auth_token, ) @@ -182,7 +208,7 @@ def _extract_chart_data( try: # Apply the _update_metadata function to each row in the DataFrame - df["metadata"] = df.apply(_update_metadata, axis=1, args=(cached_client, deplot_client, trace_info)) + df["metadata"] = df.apply(_update_metadata, axis=1, args=(yolox_client, paddle_client, trace_info)) return df, {"trace_info": trace_info} @@ -190,8 +216,8 @@ def _extract_chart_data( logger.error("Error occurred while extracting chart data.", exc_info=True) raise finally: - cached_client.close() - deplot_client.close() + yolox_client.close() + paddle_client.close() def generate_chart_extractor_stage( diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index dd803af1..32554601 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -10,16 +10,16 @@ from typing import Tuple import pandas as pd - from morpheus.config import Config +from nv_ingest.schemas.metadata_schema import TableFormatEnum from nv_ingest.schemas.table_extractor_schema import TableExtractorSchema from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage +from nv_ingest.util.image_processing.table_and_chart import convert_paddle_response_to_psuedo_markdown from nv_ingest.util.image_processing.transforms import base64_to_numpy from nv_ingest.util.image_processing.transforms import check_numpy_image_size -from nv_ingest.util.nim.helpers import create_inference_client from nv_ingest.util.nim.helpers import NimClient -from nv_ingest.util.nim.helpers import get_version +from nv_ingest.util.nim.helpers import create_inference_client from nv_ingest.util.nim.paddle import PaddleOCRModelInterface logger = logging.getLogger(f"morpheus.{__name__}") @@ -71,6 +71,8 @@ def _update_metadata(row: pd.Series, paddle_client: NimClient, trace_info: Dict) ): return metadata + table_content_format = table_metadata.get("table_content_format") + # Modify table metadata with the result from the inference model try: data = {"base64_image": base64_image} @@ -83,12 +85,20 @@ def _update_metadata(row: pd.Series, paddle_client: NimClient, trace_info: Dict) paddle_result = paddle_client.infer( data, model_name="paddle", - table_content_format=table_metadata.get("table_content_format"), + table_content_format=table_content_format, trace_info=trace_info, # traceable_func arg stage_name="table_data_extraction", # traceable_func arg ) - table_content, table_content_format = paddle_result + text_predictions, bounding_boxes = paddle_result + + if table_content_format == TableFormatEnum.SIMPLE: + table_content = " ".join(text_predictions) + elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN: + table_content = convert_paddle_response_to_psuedo_markdown(text_predictions, bounding_boxes) + else: + raise ValueError(f"Unexpected table format: {table_content_format}") + table_metadata["table_content"] = table_content table_metadata["table_content_format"] = table_content_format except Exception as e: @@ -140,20 
+150,8 @@ def _extract_table_data( stage_config = validated_config.stage_config - # Obtain paddle_version - # Assuming that the grpc endpoint is at index 0 - paddle_endpoint = stage_config.paddle_endpoints[1] - try: - paddle_version = get_version(paddle_endpoint) - if not paddle_version: - logger.warning("Failed to obtain PaddleOCR version from the endpoint. Falling back to the latest version.") - paddle_version = None # Default to the latest version - except Exception: - logger.warning("Failed to get PaddleOCR version after 30 seconds. Falling back to the latest verrsion.") - paddle_version = None # Default to the latest version - - # Create the PaddleOCRModelInterface with paddle_version - paddle_model_interface = PaddleOCRModelInterface(paddle_version=paddle_version) + # Create the PaddleOCRModelInterface + paddle_model_interface = PaddleOCRModelInterface() # Create the NimClient for PaddleOCR paddle_client = create_inference_client( diff --git a/src/nv_ingest/util/image_processing/table_and_chart.py b/src/nv_ingest/util/image_processing/table_and_chart.py index 3a016aab..e6996c54 100644 --- a/src/nv_ingest/util/image_processing/table_and_chart.py +++ b/src/nv_ingest/util/image_processing/table_and_chart.py @@ -3,77 +3,184 @@ # SPDX-License-Identifier: Apache-2.0 -import json import logging +import re + +import numpy as np +import pandas as pd +from sklearn.cluster import DBSCAN + logger = logging.getLogger(__name__) -def join_cached_and_deplot_output(cached_text, deplot_text): +def process_yolox_graphic_elements(yolox_text_dict): """ - Process the inference results from cached and deplot models. + Process the inference results from yolox-graphic-elements model. Parameters ---------- - cached_text : str - The result from the cached model inference, expected to be a JSON string or plain text. - deplot_text : str - The result from the deplot model inference, expected to be plain text. + yolox_text : str + The result from the yolox model inference. Returns ------- str The concatenated and processed chart content as a string. - - Notes - ----- - This function attempts to parse the `cached_text` as JSON to extract specific fields. - If parsing fails, it falls back to using the raw `cached_text`. The `deplot_text` is then - appended to this content. - - Examples - -------- - >>> cached_text = '{"chart_title": "Sales Over Time"}' - >>> deplot_text = "This chart shows the sales over time." - >>> result = join_cached_and_deplot_output(cached_text, deplot_text) - >>> print(result) - "Sales Over Time This chart shows the sales over time." 
""" chart_content = "" - if cached_text is not None: - try: - if isinstance(cached_text, str): - cached_text_dict = json.loads(cached_text) - elif isinstance(cached_text, dict): - cached_text_dict = cached_text - else: - cached_text_dict = {} - - chart_content += cached_text_dict.get("chart_title", "") - - if deplot_text is not None: - chart_content += f" {deplot_text}" - - chart_content += " " + cached_text_dict.get("caption", "") - chart_content += " " + cached_text_dict.get("info_deplot", "") - chart_content += " " + cached_text_dict.get("x_title", "") - chart_content += " " + cached_text_dict.get("xlabel", "") - chart_content += " " + cached_text_dict.get("y_title", "") - chart_content += " " + cached_text_dict.get("ylabel", "") - chart_content += " " + cached_text_dict.get("legend_label", "") - chart_content += " " + cached_text_dict.get("legend_title", "") - chart_content += " " + cached_text_dict.get("mark_label", "") - chart_content += " " + cached_text_dict.get("value_label", "") - chart_content += " " + cached_text_dict.get("other", "") - except json.JSONDecodeError: - chart_content += cached_text - - if deplot_text is not None: - chart_content += f" {deplot_text}" - - else: - if deplot_text is not None: - chart_content += f" {deplot_text}" + chart_content += yolox_text_dict.get("chart_title", "") + + chart_content += " " + yolox_text_dict.get("caption", "") + chart_content += " " + yolox_text_dict.get("x_title", "") + chart_content += " " + yolox_text_dict.get("xlabel", "") + chart_content += " " + yolox_text_dict.get("y_title", "") + chart_content += " " + yolox_text_dict.get("ylabel", "") + chart_content += " " + yolox_text_dict.get("legend_label", "") + chart_content += " " + yolox_text_dict.get("legend_title", "") + chart_content += " " + yolox_text_dict.get("mark_label", "") + chart_content += " " + yolox_text_dict.get("value_label", "") + chart_content += " " + yolox_text_dict.get("other", "") return chart_content + + +def match_bboxes(yolox_box, paddle_ocr_boxes, already_matched=None, delta=2.0): + """ + Associates a yolox-graphic-elements box to PaddleOCR bboxes, by taking overlapping boxes. + Criterion is iou > max_iou / delta where max_iou is the biggest found overlap. + Boxes are expeceted in format (x0, y0, x1, y1) + Args: + yolox_box (np array [4]): Cached Bbox. + paddle_ocr_boxes (np array [n x 4]): PaddleOCR boxes + already_matched (list or None, Optional): Already matched ids to ignore. + delta (float, Optional): IoU delta for considering several boxes. Defaults to 2.. 
+ Returns: + np array or list: Indices of the match bboxes + """ + x0_1, y0_1, x1_1, y1_1 = yolox_box + x0_2, y0_2, x1_2, y1_2 = ( + paddle_ocr_boxes[:, 0], + paddle_ocr_boxes[:, 1], + paddle_ocr_boxes[:, 2], + paddle_ocr_boxes[:, 3], + ) + + # Intersection + inter_y0 = np.maximum(y0_1, y0_2) + inter_y1 = np.minimum(y1_1, y1_2) + inter_x0 = np.maximum(x0_1, x0_2) + inter_x1 = np.minimum(x1_1, x1_2) + inter_area = np.maximum(0, inter_y1 - inter_y0) * np.maximum(0, inter_x1 - inter_x0) + + # Union + area_1 = (y1_1 - y0_1) * (x1_1 - x0_1) + area_2 = (y1_2 - y0_2) * (x1_2 - x0_2) + union_area = area_1 + area_2 - inter_area + + # IoU + ious = inter_area / union_area + + max_iou = np.max(ious) + if max_iou <= 0.01: + return [] + + matches = np.where(ious > (max_iou / delta))[0] + if already_matched is not None: + matches = np.array([m for m in matches if m not in already_matched]) + return matches + + +def join_yolox_and_paddle_output(yolox_output, paddle_txts, paddle_boxes): + """ + Matching boxes + We need to associate a text to the paddle detections. + For each class and for each CACHED detections, we look for overlapping text bboxes + with IoU > max_iou / delta where max_iou is the biggest found overlap. + Found texts are added to the class representation, and removed from the texts to match + """ + KEPT_CLASSES = [ # Used CACHED classes, corresponds to YoloX classes + "chart_title", + "x_title", + "y_title", + "xlabel", + "ylabel", + "other", + "legend_label", + "legend_title", + "mark_label", + "value_label", + ] + + paddle_txts = np.array(paddle_txts) + paddle_boxes = np.array(paddle_boxes) + + if (paddle_txts.size == 0) or (paddle_boxes.size == 0): + return {} + + paddle_boxes = np.array( + [ + paddle_boxes[:, :, 0].min(-1), + paddle_boxes[:, :, 1].min(-1), + paddle_boxes[:, :, 0].max(-1), + paddle_boxes[:, :, 1].max(-1), + ] + ).T + + already_matched = [] + results = {} + + for k in KEPT_CLASSES: + if not len(yolox_output.get(k, [])): # No bounding boxes + continue + + texts = [] + for yolox_box in yolox_output[k]: + # if there's a score at the end, drop the score. + yolox_box = yolox_box[:4] + paddle_ids = match_bboxes(yolox_box, paddle_boxes, already_matched=already_matched, delta=4) + + if len(paddle_ids) > 0: + text = " ".join(paddle_txts[paddle_ids].tolist()) + texts.append(text) + + processed_texts = [] + for t in texts: + t = re.sub(r"\s+", " ", t) + t = re.sub(r"\.+", ".", t) + processed_texts.append(t) + + if "title" in k: + processed_texts = " ".join(processed_texts) + else: + processed_texts = " - ".join(processed_texts) # Space ? 
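NOTE (editorial sketch, not part of the commit): match_bboxes above keeps every
PaddleOCR box whose IoU with the yolox box lies within max_iou / delta of the
best overlap, and returns no matches when the best overlap is at most 0.01.
With the delta=4 that this caller passes, hypothetical IoU values play out as:

    import numpy as np

    ious = np.array([0.02, 0.40, 0.15, 0.35])   # IoUs of four paddle boxes vs. one yolox box
    max_iou = ious.max()                        # 0.40, above the 0.01 floor, so matching proceeds
    matches = np.where(ious > max_iou / 4)[0]   # threshold 0.10 -> keeps indices 1, 2, 3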
+        results[k] = processed_texts
+
+    return results
+
+
+def convert_paddle_response_to_psuedo_markdown(texts, bboxes):
+    if (not bboxes) or (not texts):
+        return ""
+
+    bboxes = np.array(bboxes).astype(int)
+    bboxes = bboxes.reshape(-1, 8)[:, [0, 1, 2, -1]]
+
+    preds_df = pd.DataFrame(
+        {"x0": bboxes[:, 0], "y0": bboxes[:, 1], "x1": bboxes[:, 2], "y1": bboxes[:, 3], "text": texts}
+    )
+    preds_df = preds_df.sort_values("y0")
+
+    dbscan = DBSCAN(eps=10, min_samples=1)
+    dbscan.fit(preds_df["y0"].values[:, None])
+
+    preds_df["cluster"] = dbscan.labels_
+    preds_df = preds_df.sort_values(["cluster", "x0"])
+
+    results = ""
+    for _, dfg in preds_df.groupby("cluster"):
+        results += "| " + " | ".join(dfg["text"].values.tolist()) + " |\n"
+
+    return results
diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py
index bfbb37a0..2390e5e3 100644
--- a/src/nv_ingest/util/nim/helpers.py
+++ b/src/nv_ingest/util/nim/helpers.py
@@ -14,7 +14,6 @@
 import numpy as np
 import requests
 import tritonclient.grpc as grpcclient
-from packaging import version as pkgversion
 
 from nv_ingest.util.image_processing.transforms import normalize_image
 from nv_ingest.util.image_processing.transforms import pad_image
@@ -362,7 +361,7 @@ def create_inference_client(
     return NimClient(model_interface, infer_protocol, endpoints, auth_token)
 
 
-def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str] = None) -> np.ndarray:
+def preprocess_image_for_paddle(array: np.ndarray, image_max_dimension: int = 960) -> np.ndarray:
     """
     Preprocesses an input image to be suitable for use with PaddleOCR by resizing, normalizing, padding,
     and transposing it into the required format.
@@ -395,11 +394,8 @@ def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str]
     a requirement for PaddleOCR.
     - The normalized pixel values are scaled between 0 and 1 before padding and transposing the image.
     """
-    if (not paddle_version) or (pkgversion.parse(paddle_version) < pkgversion.parse("0.2.0-rc1")):
-        return array
-
     height, width = array.shape[:2]
-    scale_factor = 960 / max(height, width)
+    scale_factor = image_max_dimension / max(height, width)
     new_height = int(height * scale_factor)
     new_width = int(width * scale_factor)
     resized = cv2.resize(array, (new_width, new_height))
@@ -409,14 +405,25 @@ def preprocess_image_for_paddle(array: np.ndarray, paddle_version: Optional[str]
     # PaddleOCR NIM (GRPC) requires input shapes to be multiples of 32.
     new_height = (normalized.shape[0] + 31) // 32 * 32
     new_width = (normalized.shape[1] + 31) // 32 * 32
-    padded, _ = pad_image(
+    padded, (pad_width, pad_height) = pad_image(
         normalized, target_height=new_height, target_width=new_width, background_color=0, dtype=np.float32
     )
 
     # PaddleOCR NIM (GRPC) requires input to be (channel, height, width).
     transposed = padded.transpose((2, 0, 1))
 
-    return transposed
+    # Metadata can be used for inverting transformations on the resulting bounding boxes.
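NOTE (editorial sketch, not part of the commit): the metadata assembled just
below is consumed by PaddleOCRModelInterface._postprocess_paddle_response,
which maps the normalized PaddleOCR coordinates back onto the original image.
The inversion amounts to the following, where to_original_coords is a
hypothetical helper named only for illustration:

    def to_original_coords(x_norm, y_norm, meta):
        # Undo the padding offset, then undo the resize scale.
        x = (x_norm * meta["new_width"] - meta["pad_width"]) / meta["scale_factor"]
        y = (y_norm * meta["new_height"] - meta["pad_height"]) / meta["scale_factor"]
        return x, y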
+ metadata = { + "original_height": height, + "original_width": width, + "scale_factor": scale_factor, + "new_height": transposed.shape[1], + "new_width": transposed.shape[2], + "pad_height": pad_height, + "pad_width": pad_width, + } + + return transposed, metadata def remove_url_endpoints(url) -> str: diff --git a/src/nv_ingest/util/nim/paddle.py b/src/nv_ingest/util/nim/paddle.py index a7d3af79..bd466eec 100644 --- a/src/nv_ingest/util/nim/paddle.py +++ b/src/nv_ingest/util/nim/paddle.py @@ -5,11 +5,7 @@ from typing import Optional import numpy as np -import pandas as pd -from packaging import version as pkgversion -from sklearn.cluster import DBSCAN -from nv_ingest.schemas.metadata_schema import TableFormatEnum from nv_ingest.util.image_processing.transforms import base64_to_numpy from nv_ingest.util.nim.helpers import ModelInterface from nv_ingest.util.nim.helpers import preprocess_image_for_paddle @@ -22,20 +18,6 @@ class PaddleOCRModelInterface(ModelInterface): An interface for handling inference with a PaddleOCR model, supporting both gRPC and HTTP protocols. """ - def __init__( - self, - paddle_version: Optional[str] = None, - ): - """ - Initialize the PaddleOCR model interface. - - Parameters - ---------- - paddle_version : str, optional - The version of the PaddleOCR model (default: None). - """ - self.paddle_version = paddle_version - def name(self) -> str: """ Get the name of the model interface. @@ -43,9 +25,9 @@ def name(self) -> str: Returns ------- str - The name of the model interface, including the PaddleOCR version. + The name of the model interface. """ - return f"PaddleOCR - {self.paddle_version}" + return "PaddleOCR" def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: """ @@ -67,9 +49,6 @@ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: image_array = base64_to_numpy(base64_image) data["image_array"] = image_array - # Cache image dimensions for computing bounding boxes. - self._width, self._height = image_array.shape[:2] - return data def format_input(self, data: Dict[str, Any], protocol: str, **kwargs) -> Any: @@ -99,7 +78,18 @@ def format_input(self, data: Dict[str, Any], protocol: str, **kwargs) -> Any: if protocol == "grpc": logger.debug("Formatting input for gRPC PaddleOCR model") image_data = data["image_array"] - image_data = preprocess_image_for_paddle(image_data, self.paddle_version) + + image_data, metadata = preprocess_image_for_paddle(image_data) + + # Cache image dimensions for computing bounding boxes. + self._orig_height = metadata["original_height"] + self._orig_width = metadata["original_width"] + self._scale_factor = metadata["scale_factor"] + self._max_height = metadata["new_height"] + self._max_width = metadata["new_width"] + self._pad_height = metadata["pad_height"] + self._pad_width = metadata["pad_width"] + image_data = image_data.astype(np.float32) image_data = np.expand_dims(image_data, axis=0) @@ -137,24 +127,13 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, An ValueError If an invalid protocol is specified or the response format is unexpected. 
""" - default_table_content_format = ( - TableFormatEnum.SIMPLE if self._is_version_early_access_legacy_api() else TableFormatEnum.PSEUDO_MARKDOWN - ) - table_content_format = kwargs.get("table_content_format", default_table_content_format) - - if self._is_version_early_access_legacy_api() and (table_content_format != TableFormatEnum.SIMPLE): - logger.warning( - f"Paddle version {self.paddle_version} does not support {table_content_format} format. " - "The table content will be in `simple` format." - ) - table_content_format = TableFormatEnum.SIMPLE - if protocol == "grpc": logger.debug("Parsing output from gRPC PaddleOCR model") - return self._extract_content_from_paddle_grpc_response(response, table_content_format) + + return self._extract_content_from_paddle_grpc_response(response, data) elif protocol == "http": logger.debug("Parsing output from HTTP PaddleOCR model") - return self._extract_content_from_paddle_http_response(response, table_content_format) + return self._extract_content_from_paddle_http_response(response) else: raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") @@ -178,9 +157,6 @@ def process_inference_results(self, output: Any, **kwargs) -> Any: # For PaddleOCR, the output is the table content as a string return output - def _is_version_early_access_legacy_api(self): - return self.paddle_version and (pkgversion.parse(self.paddle_version) < pkgversion.parse("0.2.1-rc2")) - def _prepare_paddle_payload(self, base64_img: str) -> Dict[str, Any]: """ Prepare a payload for the PaddleOCR HTTP API using a base64-encoded image. @@ -198,18 +174,14 @@ def _prepare_paddle_payload(self, base64_img: str) -> Dict[str, Any]: image_url = f"data:image/png;base64,{base64_img}" - if self._is_version_early_access_legacy_api(): - image = {"type": "image_url", "image_url": {"url": image_url}} - message = {"content": [image]} - payload = {"messages": [message]} - else: - image = {"type": "image_url", "url": image_url} - payload = {"input": [image]} + image = {"type": "image_url", "url": image_url} + payload = {"input": [image]} return payload def _extract_content_from_paddle_http_response( - self, json_response: Dict[str, Any], table_content_format: Optional[str] + self, + json_response: Dict[str, Any], ) -> Any: """ Extract content from the JSON response of a PaddleOCR HTTP API request. 
@@ -233,80 +205,51 @@ def _extract_content_from_paddle_http_response( if "data" not in json_response or not json_response["data"]: raise RuntimeError("Unexpected response format: 'data' key is missing or empty.") - if self._is_version_early_access_legacy_api(): - content = json_response["data"][0]["content"] - else: - text_detections = json_response["data"][0]["text_detections"] + text_detections = json_response["data"][0]["text_detections"] - text_predictions = [] - bounding_boxes = [] - for text_detection in text_detections: - text_predictions.append(text_detection["text_prediction"]["text"]) - bounding_boxes.append([(point["x"], point["y"]) for point in text_detection["bounding_box"]["points"]]) + text_predictions = [] + bounding_boxes = [] + for text_detection in text_detections: + text_predictions.append(text_detection["text_prediction"]["text"]) + bounding_boxes.append([(point["x"], point["y"]) for point in text_detection["bounding_box"]["points"]]) - if table_content_format == TableFormatEnum.SIMPLE: - content = " ".join(text_predictions) - elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN: - content = self._convert_paddle_response_to_psuedo_markdown(bounding_boxes, text_predictions) - else: - raise ValueError(f"Unexpected table format: {table_content_format}") + text_predictions, bounding_boxes = self._postprocess_paddle_response( + text_predictions, bounding_boxes, scale_factor=1.0 + ) - return content, table_content_format + return text_predictions, bounding_boxes - def _extract_content_from_paddle_grpc_response(self, response, table_content_format): + def _extract_content_from_paddle_grpc_response(self, response, data): if not isinstance(response, np.ndarray): raise ValueError("Unexpected response format: response is not a NumPy array.") - if self._is_version_early_access_legacy_api(): - content = " ".join([output[0].decode("utf-8") for output in response]) - else: - bboxes_bytestr, texts_bytestr, _ = response - bounding_boxes = json.loads(bboxes_bytestr.decode("utf8"))[0] - text_predictions = json.loads(texts_bytestr.decode("utf8"))[0] + bboxes_bytestr, texts_bytestr, _ = response + bounding_boxes = json.loads(bboxes_bytestr.decode("utf8"))[0] + text_predictions = json.loads(texts_bytestr.decode("utf8"))[0] - if table_content_format == TableFormatEnum.SIMPLE: - content = " ".join(text_predictions) - elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN: - content = self._convert_paddle_response_to_psuedo_markdown(bounding_boxes, text_predictions) - else: - raise ValueError(f"Unexpected table format: {table_content_format}") + text_predictions, bounding_boxes = self._postprocess_paddle_response( + text_predictions, bounding_boxes, scale_factor=self._scale_factor + ) - return content, table_content_format + return text_predictions, bounding_boxes - def _convert_paddle_response_to_psuedo_markdown(self, bounding_boxes, text_predictions): + def _postprocess_paddle_response(self, text_predictions, bounding_boxes, scale_factor): bboxes = [] texts = [] + for box, txt in zip(bounding_boxes, text_predictions): if box == "nan": continue points = [] for point in box: - # The coordinates from Paddle are normlized. Convert them back to integers for DBSCAN. - x = float(point[0]) * self._width - y = float(point[1]) * self._height - points.append([x, y]) + # The coordinates from Paddle are normlized. Convert them back to integers, + # and scale or shift them back to their original positions if padded or scaled. 
+ x_pixels = float(point[0]) * self._max_width - self._pad_width + y_pixels = float(point[1]) * self._max_height - self._pad_height + x_original = x_pixels / scale_factor + y_original = y_pixels / scale_factor + points.append([x_original, y_original]) bboxes.append(points) texts.append(txt) - if (not bboxes) or (not texts): - return "" - - bboxes = np.array(bboxes).astype(int) - bboxes = bboxes.reshape(-1, 8)[:, [0, 1, 2, -1]] - - preds_df = pd.DataFrame( - {"x0": bboxes[:, 0], "y0": bboxes[:, 1], "x1": bboxes[:, 2], "y1": bboxes[:, 3], "text": texts} - ) - preds_df = preds_df.sort_values("y0") - - dbscan = DBSCAN(eps=10, min_samples=1) - dbscan.fit(preds_df["y0"].values[:, None]) - - preds_df["cluster"] = dbscan.labels_ - preds_df = preds_df.sort_values(["cluster", "x0"]) - - results = "" - for _, dfg in preds_df.groupby("cluster"): - results += "| " + " | ".join(dfg["text"].values.tolist()) + " |\n" - - return results + return texts, bboxes diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py index ebfdd31a..33fbe2a3 100644 --- a/src/nv_ingest/util/nim/yolox.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -274,15 +274,15 @@ def process_inference_results(self, output: Any, protocol: str, **kwargs) -> Lis original_image_shapes, self.image_preproc_width, self.image_preproc_height, + self.class_labels, min_score=self.min_score, ) inference_results = self.postprocess_annotations(results, **kwargs) - inference_results = self.transform_normalized_coordinates_to_original(inference_results, original_image_shapes) return inference_results - def postprocess_annotations(self, annotation_dicts: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: + def postprocess_annotations(self, annotation_dicts, **kwargs): raise NotImplementedError() def transform_normalized_coordinates_to_original(self, results, original_image_shapes): @@ -293,14 +293,16 @@ def transform_normalized_coordinates_to_original(self, results, original_image_s new_dict = {} for label, bboxes_and_scores in annotation_dict.items(): new_dict[label] = [] - for *bbox, score in bboxes_and_scores: + for bbox_and_score in bboxes_and_scores: + bbox = bbox_and_score[:4] transformed_bbox = [ bbox[0] * shape[1], bbox[1] * shape[0], bbox[2] * shape[1], bbox[3] * shape[0], ] - new_dict[label].append(transformed_bbox + [score]) + transformed_bbox += bbox_and_score[4:] + new_dict[label].append(transformed_bbox) transformed_results.append(new_dict) return transformed_results @@ -341,7 +343,9 @@ def name( return "yolox-page-elements" - def postprocess_annotations(self, annotation_dicts: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: + def postprocess_annotations(self, annotation_dicts, **kwargs): + original_image_shapes = kwargs.get("original_image_shapes", []) + # Table/chart expansion is "business logic" specific to nv-ingest annotation_dicts = [expand_table_bboxes(annotation_dict) for annotation_dict in annotation_dicts] annotation_dicts = [expand_chart_bboxes(annotation_dict) for annotation_dict in annotation_dicts] @@ -359,6 +363,8 @@ def postprocess_annotations(self, annotation_dicts: List[Dict[str, Any]], **kwar new_dict["title"] = annotation_dict["title"] inference_results.append(new_dict) + inference_results = self.transform_normalized_coordinates_to_original(inference_results, original_image_shapes) + return inference_results @@ -397,21 +403,26 @@ def name( return "yolox-graphic-elements" - def postprocess_annotations(self, annotation_dicts: List[Dict[str, Any]], **kwargs) -> List[Dict[str, Any]]: + def 
postprocess_annotations(self, annotation_dicts, **kwargs):
         original_image_shapes = kwargs.get("original_image_shapes", [])
 
+        annotation_dicts = self.transform_normalized_coordinates_to_original(annotation_dicts, original_image_shapes)
+
         inference_results = []
 
         # bbox extraction: additional postprocessing specific to nv-ingest
         for pred, shape in zip(annotation_dicts, original_image_shapes):
-            inference_results.append(
-                get_bbox_dict_yolox_graphic(
-                    pred,
-                    shape,
-                    self.class_labels,
-                    self.min_score,
-                )
+            bbox_dict = get_bbox_dict_yolox_graphic(
+                pred,
+                shape,
+                self.class_labels,
+                self.min_score,
             )
+            # convert numpy arrays to list
+            bbox_dict = {
+                label: array.tolist() if isinstance(array, np.ndarray) else array for label, array in bbox_dict.items()
+            }
+            inference_results.append(bbox_dict)
 
         return inference_results
 
@@ -475,7 +486,9 @@ def postprocess_model_prediction(prediction, num_classes, conf_thre=0.7, nms_thr
     return output
 
-def postprocess_results(results, original_image_shapes, image_preproc_width, image_preproc_height, min_score=0.0):
+def postprocess_results(
+    results, original_image_shapes, image_preproc_width, image_preproc_height, class_labels, min_score=0.0
+):
     """
     For each item (==image) in results, computes annotations in the form
 
@@ -487,7 +500,6 @@ def postprocess_results(results, original_image_shapes, image_preproc_width, ima
     Keep only bboxes with high enough confidence.
     """
-    class_labels = ["table", "chart", "title"]
     out = []
 
     for original_image_shape, result in zip(original_image_shapes, results):
diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py
index 7dd80f3a..ecfedbd6 100644
--- a/src/nv_ingest/util/pipeline/stage_builders.py
+++ b/src/nv_ingest/util/pipeline/stage_builders.py
@@ -225,21 +225,15 @@ def add_table_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, def
 
 def add_chart_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count):
-    _, _, yolox_auth, _ = get_table_detection_service("yolox")
-
-    deplot_grpc, deplot_http, deplot_auth, deplot_protocol = get_table_detection_service("deplot")
-    cached_grpc, cached_http, cached_auth, cached_protocol = get_table_detection_service("cached")
-    # NOTE: Paddle isn't currently used directly by the chart extraction stage, but will be in the future.
+ yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox_graphic_elements") paddle_grpc, paddle_http, paddle_auth, paddle_protocol = get_table_detection_service("paddle") table_content_extractor_config = ingest_config.get( "table_content_extraction_module", { "stage_config": { - "cached_endpoints": (cached_grpc, cached_http), - "cached_infer_protocol": cached_protocol, - "deplot_endpoints": (deplot_grpc, deplot_http), - "deplot_infer_protocol": deplot_protocol, + "yolox_endpoints": (yolox_grpc, yolox_http), + "yolox_infer_protocol": yolox_protocol, "paddle_endpoints": (paddle_grpc, paddle_http), "paddle_infer_protocol": paddle_protocol, "auth_token": yolox_auth, From 822a83e7403447754fe6d0dbfc7198f33256c468 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 21 Jan 2025 09:59:27 -0800 Subject: [PATCH 04/28] use graphic elements labels in http mode --- src/nv_ingest/util/nim/yolox.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py index 33fbe2a3..ee1b4377 100644 --- a/src/nv_ingest/util/nim/yolox.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -223,7 +223,7 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, An batch_results = response.get("data", []) for detections in batch_results: - new_bounding_boxes = {"table": [], "chart": [], "title": []} + new_bounding_boxes = {label: [] for label in self.class_labels} bounding_boxes = detections.get("bounding_boxes", []) for obj_type, bboxes in bounding_boxes.items(): From b1b4414cbe429aa324db9d8b7e854cb95638ee4f Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 21 Jan 2025 10:06:04 -0800 Subject: [PATCH 05/28] remove experimental code --- src/nv_ingest/stages/nim/chart_extraction.py | 21 -------------------- 1 file changed, 21 deletions(-) diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index 0447f145..4e83d86a 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -91,27 +91,6 @@ def _update_metadata(row: pd.Series, yolox_client: NimClient, paddle_client: Nim trace_info=trace_info, # traceable_func arg ) - ### - source_name = metadata.get("source_metadata", {}).get("source_name").split("/")[-1] - page_number = metadata.get("content_metadata", {}).get("page_number") - - out_path = f"/workspace/data/tmp/{source_name}.{page_number}.yolox.pickle" - - import os - import pickle - - if os.path.exists(out_path): - with open(out_path, "rb") as f: - data = pickle.load(f) - else: - data = {"data": []} - - data["data"].append({"base64_image": base64_image, "yolox": yolox_result[0], "paddle": paddle_result}) - - with open(out_path, "wb") as f: - pickle.dump(data, f) - ### - text_predictions, bounding_boxes = paddle_result yolox_elements = join_yolox_and_paddle_output(yolox_result[0], text_predictions, bounding_boxes) chart_content = process_yolox_graphic_elements(yolox_elements) From b6a82d1c3cb4829c9951ea377607a31eda10d61a Mon Sep 17 00:00:00 2001 From: edknv Date: Wed, 22 Jan 2025 00:20:36 -0800 Subject: [PATCH 06/28] include threshold parameters in http request --- src/nv_ingest/util/nim/yolox.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py index ee1b4377..fa9f7595 100644 --- a/src/nv_ingest/util/nim/yolox.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -183,7 +183,11 @@ def format_input(self, data: Dict[str, Any], protocol: 
str) -> Any: content_list.append(content) - payload = {"input": content_list} + payload = { + "input": content_list, + "confidence_threshold": self.conf_threshold, + "nms_threshold": self.iou_threshold, + } return payload else: From be2e30faaf3bccf0c375f8f4b2985ab7dc2fed9b Mon Sep 17 00:00:00 2001 From: edknv Date: Fri, 31 Jan 2025 09:25:12 -0800 Subject: [PATCH 07/28] create a new metadata object for thread safety --- src/nv_ingest/util/nim/paddle.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/nv_ingest/util/nim/paddle.py b/src/nv_ingest/util/nim/paddle.py index 4eb7793f..0e08b15a 100644 --- a/src/nv_ingest/util/nim/paddle.py +++ b/src/nv_ingest/util/nim/paddle.py @@ -76,6 +76,8 @@ def format_input(self, data: Dict[str, Any], protocol: str, **kwargs) -> Any: images = data["image_arrays"] + data["_dimensions"] = [] # Cache image dimensions information for scale/shift. + if protocol == "grpc": logger.debug("Formatting input for gRPC PaddleOCR model (batched).") @@ -86,8 +88,8 @@ def format_input(self, data: Dict[str, Any], protocol: str, **kwargs) -> Any: processed = [] self._dims = [] for img in images: - arr, metadata = preprocess_image_for_paddle(img) - self._dims.append(metadata) + arr, _dims = preprocess_image_for_paddle(img) + data["_dimensions"].append(_dims) arr = arr.astype(np.float32) arr = np.expand_dims(arr, axis=0) # => shape (1, H, W, C) processed.append(arr) @@ -130,7 +132,7 @@ def parse_output(self, response: Any, protocol: str, data: Optional[Dict[str, An """ if protocol == "grpc": logger.debug("Parsing output from gRPC PaddleOCR model (batched).") - return self._extract_content_from_paddle_grpc_response(response) + return self._extract_content_from_paddle_grpc_response(response, data["_dimensions"]) elif protocol == "http": logger.debug("Parsing output from HTTP PaddleOCR model (batched).") @@ -198,6 +200,7 @@ def _extract_content_from_paddle_http_response( def _extract_content_from_paddle_grpc_response( self, response: np.ndarray, + dimensions: List[Dict[str, Any]], ) -> List[Tuple[str, str]]: """ Parses a gRPC response for one or more images. @@ -246,6 +249,7 @@ def _extract_content_from_paddle_grpc_response( bounding_boxes, text_predictions = self._postprocess_paddle_response( bounding_boxes, text_predictions, + dimensions, img_index=i, ) @@ -254,17 +258,21 @@ def _extract_content_from_paddle_grpc_response( return results def _postprocess_paddle_response( - self, bounding_boxes: List[Any], text_predictions: List[str], img_index: int = 0 + self, + bounding_boxes: List[Any], + text_predictions: List[str], + dimensions: List[Dict[str, Any]], + img_index: int = 0, ) -> str: """ Convert bounding boxes & text to pseudo-markdown. For multiple images, we use self._dims[img_index] to recover the correct height/width. 
""" - max_width = self._dims[img_index]["new_width"] - max_height = self._dims[img_index]["new_height"] - pad_width = self._dims[img_index]["pad_width"] - pad_height = self._dims[img_index]["pad_height"] - scale_factor = self._dims[img_index]["scale_factor"] + max_width = dimensions[img_index]["new_width"] + max_height = dimensions[img_index]["new_height"] + pad_width = dimensions[img_index]["pad_width"] + pad_height = dimensions[img_index]["pad_height"] + scale_factor = dimensions[img_index]["scale_factor"] bboxes = [] texts = [] From 553caa9ec2b85e32374f820f4e95b674031430d1 Mon Sep 17 00:00:00 2001 From: edknv Date: Fri, 31 Jan 2025 12:47:21 -0800 Subject: [PATCH 08/28] update tests --- src/nv_ingest/stages/nim/table_extraction.py | 5 +- .../util/image_processing/table_and_chart.py | 2 +- src/nv_ingest/util/nim/paddle.py | 2 +- .../image/test_image_handlers.py | 6 +- .../schemas/test_chart_extractor_schema.py | 31 +-- .../stages/nims/test_chart_extraction.py | 211 +++++++++--------- .../stages/nims/test_table_extraction.py | 30 +-- tests/nv_ingest/util/nim/test_helpers.py | 94 +++----- tests/nv_ingest/util/nim/test_paddle.py | 142 ++---------- tests/nv_ingest/util/nim/test_yolox.py | 20 +- 10 files changed, 208 insertions(+), 335 deletions(-) diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index fba42126..06ed0d64 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -254,13 +254,14 @@ def meets_criteria(row): ) # 4) Write the results (bounding_boxes, text_predictions) back - table_content_format = df.at[valid_indices[0], "metadata"]["table_metadata"].get("table_content_format") + table_content_format = df.at[valid_indices[0], "metadata"]["table_metadata"].get( + "table_content_format", TableFormatEnum.PSEUDO_MARKDOWN + ) for row_id, idx in enumerate(valid_indices): # unpack (base64_image, (bounding boxes, text_predictions)) _, (bounding_boxes, text_predictions) = bulk_results[row_id] - table_content_format = TableFormatEnum.SIMPLE if table_content_format == TableFormatEnum.SIMPLE: table_content = " ".join(text_predictions) elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN: diff --git a/src/nv_ingest/util/image_processing/table_and_chart.py b/src/nv_ingest/util/image_processing/table_and_chart.py index 3364345b..4e3d6e91 100644 --- a/src/nv_ingest/util/image_processing/table_and_chart.py +++ b/src/nv_ingest/util/image_processing/table_and_chart.py @@ -43,7 +43,7 @@ def process_yolox_graphic_elements(yolox_text_dict): chart_content += " " + yolox_text_dict.get("value_label", "") chart_content += " " + yolox_text_dict.get("other", "") - return chart_content + return chart_content.strip() def match_bboxes(yolox_box, paddle_ocr_boxes, already_matched=None, delta=2.0): diff --git a/src/nv_ingest/util/nim/paddle.py b/src/nv_ingest/util/nim/paddle.py index 0e08b15a..c9a979e2 100644 --- a/src/nv_ingest/util/nim/paddle.py +++ b/src/nv_ingest/util/nim/paddle.py @@ -191,7 +191,7 @@ def _extract_content_from_paddle_http_response( bounding_boxes = [] for td in text_detections: text_predictions.append(td["text_prediction"]["text"]) - bounding_boxes.append([(pt["x"], pt["y"]) for pt in td["bounding_box"]["points"]]) + bounding_boxes.append([[pt["x"], pt["y"]] for pt in td["bounding_box"]["points"]]) results.append([bounding_boxes, text_predictions]) diff --git a/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py 
b/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py index da0c358e..eb2f047a 100644 --- a/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py +++ b/tests/nv_ingest/extraction_workflows/image/test_image_handlers.py @@ -138,7 +138,7 @@ def test_extract_table_and_chart_images_empty_annotations(): def test_extract_table_and_chart_images_single_table(): """Test extraction with a single table bounding box.""" - annotation_dict = {"table": [[0.1, 0.1, 0.3, 0.3, 0.8]], "chart": []} + annotation_dict = {"table": [[64, 64, 192, 192, 0.8]], "chart": []} original_image = np.random.rand(640, 640, 3) tables_and_charts = [] @@ -161,7 +161,7 @@ def test_extract_table_and_chart_images_single_table(): def test_extract_table_and_chart_images_single_chart(): """Test extraction with a single chart bounding box.""" - annotation_dict = {"table": [], "chart": [[0.4, 0.4, 0.6, 0.6, 0.9]]} + annotation_dict = {"table": [], "chart": [[256, 256, 384, 384, 0.9]]} original_image = np.random.rand(640, 640, 3) tables_and_charts = [] @@ -198,7 +198,7 @@ def test_extract_table_and_chart_images_multiple_objects(): def test_extract_table_and_chart_images_invalid_bounding_box(): """Test with an invalid bounding box to check handling of incorrect coordinates.""" - annotation_dict = {"table": [[1.1, 1.1, 1.5, 1.5, 0.9]], "chart": []} # Out of bounds + annotation_dict = {"table": [[704, 704, 960, 960, 0.9]], "chart": []} # Out of bounds original_image = np.random.rand(640, 640, 3) tables_and_charts = [] diff --git a/tests/nv_ingest/schemas/test_chart_extractor_schema.py b/tests/nv_ingest/schemas/test_chart_extractor_schema.py index 64d2f0b2..917fcafc 100644 --- a/tests/nv_ingest/schemas/test_chart_extractor_schema.py +++ b/tests/nv_ingest/schemas/test_chart_extractor_schema.py @@ -11,54 +11,45 @@ def test_valid_config_with_grpc_only(): config = ChartExtractorConfigSchema( auth_token="valid_token", - cached_endpoints=("grpc://cached_service", None), - deplot_endpoints=("grpc://deplot_service", None), + yolox_endpoints=("grpc://yolox_service", None), paddle_endpoints=("grpc://paddle_service", None), ) assert config.auth_token == "valid_token" - assert config.cached_endpoints == ("grpc://cached_service", None) - assert config.deplot_endpoints == ("grpc://deplot_service", None) + assert config.yolox_endpoints == ("grpc://yolox_service", None) assert config.paddle_endpoints == ("grpc://paddle_service", None) def test_valid_config_with_http_only(): config = ChartExtractorConfigSchema( auth_token="valid_token", - cached_endpoints=(None, "http://cached_service"), - deplot_endpoints=(None, "http://deplot_service"), + yolox_endpoints=(None, "http://yolox_service"), paddle_endpoints=(None, "http://paddle_service"), ) assert config.auth_token == "valid_token" - assert config.cached_endpoints == (None, "http://cached_service") - assert config.deplot_endpoints == (None, "http://deplot_service") + assert config.yolox_endpoints == (None, "http://yolox_service") assert config.paddle_endpoints == (None, "http://paddle_service") def test_invalid_config_with_empty_services(): with pytest.raises(ValidationError) as excinfo: - ChartExtractorConfigSchema( - cached_endpoints=(None, None), deplot_endpoints=(None, None), paddle_endpoints=(None, None) - ) + ChartExtractorConfigSchema(yolox_endpoints=(None, None), paddle_endpoints=(None, None)) assert "Both gRPC and HTTP services cannot be empty" in str(excinfo.value) def test_valid_config_with_both_grpc_and_http(): config = ChartExtractorConfigSchema( 
auth_token="another_token", - cached_endpoints=("grpc://cached_service", "http://cached_service"), - deplot_endpoints=("grpc://deplot_service", "http://deplot_service"), + yolox_endpoints=("grpc://yolox_service", "http://yolox_service"), paddle_endpoints=("grpc://paddle_service", "http://paddle_service"), ) assert config.auth_token == "another_token" - assert config.cached_endpoints == ("grpc://cached_service", "http://cached_service") - assert config.deplot_endpoints == ("grpc://deplot_service", "http://deplot_service") + assert config.yolox_endpoints == ("grpc://yolox_service", "http://yolox_service") assert config.paddle_endpoints == ("grpc://paddle_service", "http://paddle_service") def test_invalid_auth_token_none(): config = ChartExtractorConfigSchema( - cached_endpoints=("grpc://cached_service", None), - deplot_endpoints=("grpc://deplot_service", None), + yolox_endpoints=("grpc://yolox_service", None), paddle_endpoints=("grpc://paddle_service", None), ) assert config.auth_token is None @@ -67,7 +58,8 @@ def test_invalid_auth_token_none(): def test_invalid_endpoint_format(): with pytest.raises(ValidationError): ChartExtractorConfigSchema( - cached_endpoints=("invalid_endpoint", None), deplot_endpoints=(None, "invalid_endpoint") + yolox_endpoints=("invalid_endpoint", None), + deplot_endpoints=(None, "invalid_endpoint"), ) @@ -82,8 +74,7 @@ def test_chart_extractor_schema_defaults(): def test_chart_extractor_schema_with_custom_values(): stage_config = ChartExtractorConfigSchema( - cached_endpoints=("grpc://cached_service", "http://cached_service"), - deplot_endpoints=("grpc://deplot_service", None), + yolox_endpoints=("grpc://yolox_service", "http://yolox_service"), paddle_endpoints=(None, "http://paddle_service"), ) config = ChartExtractorSchema(max_queue_size=10, n_workers=5, raise_on_failure=True, stage_config=stage_config) diff --git a/tests/nv_ingest/stages/nims/test_chart_extraction.py b/tests/nv_ingest/stages/nims/test_chart_extraction.py index 8d637a15..6198a871 100644 --- a/tests/nv_ingest/stages/nims/test_chart_extraction.py +++ b/tests/nv_ingest/stages/nims/test_chart_extraction.py @@ -2,10 +2,12 @@ import pytest import pandas as pd +import numpy as np from nv_ingest.schemas.chart_extractor_schema import ChartExtractorConfigSchema from nv_ingest.stages.nim.chart_extraction import _update_metadata, _create_clients from nv_ingest.stages.nim.chart_extraction import _extract_chart_data +from nv_ingest.util.image_processing.transforms import base64_to_numpy MODULE_UNDER_TEST = "nv_ingest.stages.nim.chart_extraction" @@ -18,10 +20,8 @@ def valid_chart_extractor_config(): """ return ChartExtractorConfigSchema( auth_token="fake_token", - cached_endpoints=("cached_grpc_url", "cached_http_url"), - cached_infer_protocol="grpc", - deplot_endpoints=("deplot_grpc_url", "deplot_http_url"), - deplot_infer_protocol="http", + yolox_endpoints=("yolox_grpc_url", "yolox_http_url"), + yolox_infer_protocol="grpc", paddle_endpoints=("paddle_grpc_url", "paddle_http_url"), paddle_infer_protocol="grpc", nim_batch_size=2, @@ -29,6 +29,11 @@ def valid_chart_extractor_config(): ) +@pytest.fixture +def base64_image(): + return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=" + + @pytest.fixture def validated_config(valid_chart_extractor_config): """ @@ -48,184 +53,190 @@ def test_update_metadata_empty_list(): If the base64_images list is empty, _update_metadata should skip all logic and return an empty list. 
""" - cached_mock = MagicMock() - deplot_mock = MagicMock() + yolox_mock = MagicMock() + paddle_mock = MagicMock() trace_info = {} result = _update_metadata( base64_images=[], - cached_client=cached_mock, - deplot_client=deplot_mock, + yolox_client=yolox_mock, + paddle_client=paddle_mock, trace_info=trace_info, batch_size=1, worker_pool_size=1, ) assert result == [] - cached_mock.infer.assert_not_called() - deplot_mock.infer.assert_not_called() + yolox_mock.infer.assert_not_called() -def test_update_metadata_single_batch_single_worker(mocker): +def test_update_metadata_single_batch_single_worker(mocker, base64_image): """ Test a simple scenario with a small list of base64_images, batch_size=2, - worker_pool_size=1. We verify that cached is called once per batch, - and deplot is called once per image in that batch. + worker_pool_size=1. We verify that yolox is called once per batch, + and paddle is called once per image in that batch. """ # Mock out the clients - cached_mock = MagicMock() - deplot_mock = MagicMock() + yolox_mock = MagicMock() + paddle_mock = MagicMock() - # Suppose cached returns ["cached_res1", "cached_res2"] for 2 images - cached_mock.infer.return_value = ["cached_res1", "cached_res2"] + # Suppose yolox returns ["yolox_res1", "yolox_res2"] for 2 images + yolox_mock.infer.return_value = ["yolox_res1", "yolox_res2"] - # Suppose deplot is called for each image, returning single string - deplot_mock.infer.side_effect = [["deplot_res1"], ["deplot_res2"]] + # Suppose paddle returns ["paddle_res1", "paddle_res2"] for 2 images + paddle_mock.infer.return_value = [[(), "paddle_res1"], [(), "paddle_res2"]] - mock_join = mocker.patch(f"{MODULE_UNDER_TEST}.join_cached_and_deplot_output", side_effect=["joined_1", "joined_2"]) + mock_join = mocker.patch( + f"{MODULE_UNDER_TEST}.join_yolox_and_paddle_output", + side_effect=[{"chart_title": "joined_1"}, {"chart_title": "joined_2"}], + ) - base64_images = ["img1", "img2"] + base64_images = [ + base64_image, + base64_image, + ] trace_info = {} - result = _update_metadata(base64_images, cached_mock, deplot_mock, trace_info, batch_size=2, worker_pool_size=1) + result = _update_metadata(base64_images, yolox_mock, paddle_mock, trace_info, batch_size=2, worker_pool_size=1) # We expect result => [("img1", "joined_1"), ("img2", "joined_2")] assert len(result) == 2 - assert result[0] == ("img1", "joined_1") - assert result[1] == ("img2", "joined_2") + assert result[0] == (base64_image, "joined_1") + assert result[1] == (base64_image, "joined_2") # Check calls - cached_mock.infer.assert_called_once_with( - data={"base64_images": ["img1", "img2"]}, - model_name="cached", + assert yolox_mock.infer.call_count == 1 + assert np.all(yolox_mock.infer.call_args.kwargs["data"]["images"][0] == base64_to_numpy(base64_image)) + assert np.all(yolox_mock.infer.call_args.kwargs["data"]["images"][1] == base64_to_numpy(base64_image)) + assert yolox_mock.infer.call_args.kwargs["model_name"] == "yolox" + assert yolox_mock.infer.call_args.kwargs["stage_name"] == "chart_data_extraction" + assert yolox_mock.infer.call_args.kwargs["trace_info"] == trace_info + + # paddle.infer called once per image + assert paddle_mock.infer.call_count == 1 + paddle_mock.infer.assert_any_call( + data={"base64_images": [base64_image, base64_image]}, + model_name="paddle", stage_name="chart_data_extraction", trace_info=trace_info, ) - # deplot.infer called once per image - assert deplot_mock.infer.call_count == 2 - deplot_mock.infer.assert_any_call( - data={"base64_image": "img1"}, 
model_name="deplot", stage_name="chart_data_extraction", trace_info=trace_info - ) - deplot_mock.infer.assert_any_call( - data={"base64_image": "img2"}, model_name="deplot", stage_name="chart_data_extraction", trace_info=trace_info - ) - - # join_cached_and_deplot_output called twice + # join_yolox_and_paddle_output called twice assert mock_join.call_count == 2 -def test_update_metadata_multiple_batches_multi_worker(mocker): +def test_update_metadata_multiple_batches_multi_worker(mocker, base64_image): """ If batch_size=1 but we have multiple images, each image forms its own batch. We also can use worker_pool_size=2 for parallel calls. """ - cached_mock = MagicMock() - deplot_mock = MagicMock() + yolox_mock = MagicMock() + paddle_mock = MagicMock() mock_join = mocker.patch( - f"{MODULE_UNDER_TEST}.join_cached_and_deplot_output", side_effect=["joined_1", "joined_2", "joined_3"] + f"{MODULE_UNDER_TEST}.join_yolox_and_paddle_output", + side_effect=[{"chart_title": "joined_1"}, {"chart_title": "joined_2"}, {"chart_title": "joined_3"}], ) - # Suppose every cached.infer call returns a 1-element list - def cached_side_effect(**kwargs): - images = kwargs["data"]["base64_images"] - return [f"cached_{images[0]}"] + # Suppose every yolox.infer call returns a 1-element list + def yolox_side_effect(**kwargs): + images = kwargs["data"]["images"] + return [f"yolox_{images[0]}"] - cached_mock.infer.side_effect = cached_side_effect + yolox_mock.infer.side_effect = yolox_side_effect - # Suppose deplot.infer returns e.g. ["deplot_img1"], etc. - def deplot_side_effect(**kwargs): - img = kwargs["data"]["base64_image"] - return [f"deplot_{img}"] + # Suppose paddle.infer returns e.g. ["paddle_img1"], etc. + def paddle_side_effect(**kwargs): + img = kwargs["data"]["base64_images"] + return [([], f"paddle_{img}")] - deplot_mock.infer.side_effect = deplot_side_effect + paddle_mock.infer.side_effect = paddle_side_effect - base64_images = ["imgA", "imgB", "imgC"] + base64_images = [base64_image, base64_image, base64_image] trace_info = {} result = _update_metadata( base64_images, - cached_mock, - deplot_mock, + yolox_mock, + paddle_mock, trace_info, batch_size=1, # each image in its own batch worker_pool_size=2, ) # Expect 3 results: [("imgA", "joined_1"), ("imgB", "joined_2"), ("imgC", "joined_3")] - assert result == [("imgA", "joined_1"), ("imgB", "joined_2"), ("imgC", "joined_3")] + assert result == [(base64_image, "joined_1"), (base64_image, "joined_2"), (base64_image, "joined_3")] - # We should have 3 calls to cached.infer, each with one image - assert cached_mock.infer.call_count == 3 - # Also 3 calls to deplot.infer - assert deplot_mock.infer.call_count == 3 + # We should have 3 calls to yolox.infer, each with one image + assert yolox_mock.infer.call_count == 3 + # Also 3 calls to paddle.infer + assert paddle_mock.infer.call_count == 3 # 3 calls to join assert mock_join.call_count == 3 -def test_update_metadata_exception_in_cached_call(caplog): +def test_update_metadata_exception_in_yolox_call(base64_image, caplog): """ - If the cached call fails for a batch, we expect an exception to bubble up + If the yolox call fails for a batch, we expect an exception to bubble up and the error logged. 
""" - cached_mock = MagicMock() - deplot_mock = MagicMock() - cached_mock.infer.side_effect = Exception("Cached call error") + yolox_mock = MagicMock() + paddle_mock = MagicMock() + yolox_mock.infer.side_effect = Exception("Yolox call error") - with pytest.raises(Exception, match="Cached call error"): - _update_metadata(["some_img"], cached_mock, deplot_mock, trace_info={}, batch_size=1, worker_pool_size=1) + with pytest.raises(Exception, match="Yolox call error"): + _update_metadata([base64_image], yolox_mock, paddle_mock, trace_info={}, batch_size=1, worker_pool_size=1) # Check log - assert "Error processing batch: ['some_img']" in caplog.text + assert f"Error processing batch: ['{base64_image}']" in caplog.text -def test_update_metadata_exception_in_deplot_call(caplog): +def test_update_metadata_exception_in_paddle_call(base64_image, caplog): """ - If any deplot call fails for one of the images, the entire process fails + If any paddle call fails for one of the images, the entire process fails and logs the error. """ - cached_mock = MagicMock() - cached_mock.infer.return_value = ["cached_result"] # 1-element list - deplot_mock = MagicMock() - deplot_mock.infer.side_effect = Exception("Deplot error") + yolox_mock = MagicMock() + yolox_mock.infer.return_value = ["yolox_result"] # 1-element list + paddle_mock = MagicMock() + paddle_mock.infer.side_effect = Exception("Paddle error") - with pytest.raises(Exception, match="Deplot error"): - _update_metadata(["some_img"], cached_mock, deplot_mock, trace_info={}, batch_size=1, worker_pool_size=2) + with pytest.raises(Exception, match="Paddle error"): + _update_metadata([base64_image], yolox_mock, paddle_mock, trace_info={}, batch_size=1, worker_pool_size=2) - assert "Error processing batch: ['some_img']" in caplog.text + assert f"Error processing batch: ['{base64_image}']" in caplog.text def test_create_clients(mocker): """ Verify that _create_clients calls create_inference_client for - both cached and deplot endpoints, returning the pair of NimClient mocks. + both yolox and paddle endpoints, returning the pair of NimClient mocks. 
""" mock_create_inference_client = mocker.patch(f"{MODULE_UNDER_TEST}.create_inference_client") # Suppose it returns different mocks each time - cached_mock = MagicMock() - deplot_mock = MagicMock() - mock_create_inference_client.side_effect = [cached_mock, deplot_mock] + yolox_mock = MagicMock() + paddle_mock = MagicMock() + mock_create_inference_client.side_effect = [yolox_mock, paddle_mock] result = _create_clients( - cached_endpoints=("cached_grpc", "cached_http"), - cached_protocol="grpc", - deplot_endpoints=("deplot_grpc", "deplot_http"), - deplot_protocol="http", + yolox_endpoints=("yolox_grpc", "yolox_http"), + yolox_protocol="grpc", + paddle_endpoints=("paddle_grpc", "paddle_http"), + paddle_protocol="http", auth_token="xyz", ) - # result => (cached_mock, deplot_mock) - assert result == (cached_mock, deplot_mock) + # result => (yolox_mock, paddle_mock) + assert result == (yolox_mock, paddle_mock) # Check calls assert mock_create_inference_client.call_count == 2 mock_create_inference_client.assert_any_call( - endpoints=("cached_grpc", "cached_http"), model_interface=mocker.ANY, auth_token="xyz", infer_protocol="grpc" + endpoints=("yolox_grpc", "yolox_http"), model_interface=mocker.ANY, auth_token="xyz", infer_protocol="grpc" ) mock_create_inference_client.assert_any_call( - endpoints=("deplot_grpc", "deplot_http"), model_interface=mocker.ANY, auth_token="xyz", infer_protocol="http" + endpoints=("paddle_grpc", "paddle_http"), model_interface=mocker.ANY, auth_token="xyz", infer_protocol="http" ) @@ -280,8 +291,8 @@ def test_extract_chart_data_all_valid(validated_config, mocker): All rows meet criteria => pass them all to _update_metadata in order. """ # Mock out clients - cached_mock, deplot_mock = MagicMock(), MagicMock() - mock_create_clients = mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=(cached_mock, deplot_mock)) + yolox_mock, paddle_mock = MagicMock(), MagicMock() + mock_create_clients = mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=(yolox_mock, paddle_mock)) # Suppose _update_metadata returns chart content for each image mock_update_metadata = mocker.patch( @@ -315,18 +326,18 @@ def test_extract_chart_data_all_valid(validated_config, mocker): assert df_out.at[1, "metadata"]["table_metadata"]["table_content"] == {"joined": "contentB"} mock_create_clients.assert_called_once_with( - validated_config.stage_config.cached_endpoints, - validated_config.stage_config.cached_infer_protocol, - validated_config.stage_config.deplot_endpoints, - validated_config.stage_config.deplot_infer_protocol, + validated_config.stage_config.yolox_endpoints, + validated_config.stage_config.yolox_infer_protocol, + validated_config.stage_config.paddle_endpoints, + validated_config.stage_config.paddle_infer_protocol, validated_config.stage_config.auth_token, ) # Check _update_metadata call mock_update_metadata.assert_called_once_with( base64_images=["imgA", "imgB"], - cached_client=cached_mock, - deplot_client=deplot_mock, + yolox_client=yolox_mock, + paddle_client=paddle_mock, batch_size=validated_config.stage_config.nim_batch_size, worker_pool_size=validated_config.stage_config.workers_per_progress_engine, trace_info=ti.get("trace_info"), @@ -338,8 +349,8 @@ def test_extract_chart_data_mixed_rows(validated_config, mocker): Some rows are valid, some not. We only pass valid images to _update_metadata, and only those rows get updated. 
""" - cached_mock, deplot_mock = MagicMock(), MagicMock() - mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=(cached_mock, deplot_mock)) + yolox_mock, paddle_mock = MagicMock(), MagicMock() + mocker.patch(f"{MODULE_UNDER_TEST}._create_clients", return_value=(yolox_mock, paddle_mock)) mock_update = mocker.patch( f"{MODULE_UNDER_TEST}._update_metadata", @@ -384,8 +395,8 @@ def test_extract_chart_data_mixed_rows(validated_config, mocker): mock_update.assert_called_once_with( base64_images=["base64img1", "base64img2"], - cached_client=cached_mock, - deplot_client=deplot_mock, + yolox_client=yolox_mock, + paddle_client=paddle_mock, batch_size=validated_config.stage_config.nim_batch_size, worker_pool_size=validated_config.stage_config.workers_per_progress_engine, trace_info=trace.get("trace_info"), diff --git a/tests/nv_ingest/stages/nims/test_table_extraction.py b/tests/nv_ingest/stages/nims/test_table_extraction.py index 6e1e2f6f..7249d2a7 100644 --- a/tests/nv_ingest/stages/nims/test_table_extraction.py +++ b/tests/nv_ingest/stages/nims/test_table_extraction.py @@ -106,8 +106,8 @@ def test_extract_table_data_all_valid(mocker, validated_config): mock_update_metadata = mocker.patch( f"{MODULE_UNDER_TEST}._update_metadata", return_value=[ - ("imgA", ("tableA", "fmtA")), - ("imgB", ("tableB", "fmtB")), + ("imgA", [[], ["tableA"]]), + ("imgB", [[], ["tableB"]]), ], ) @@ -116,14 +116,14 @@ def test_extract_table_data_all_valid(mocker, validated_config): { "metadata": { "content_metadata": {"type": "structured", "subtype": "table"}, - "table_metadata": {}, + "table_metadata": {"table_content_format": "simple"}, "content": "imgA", } }, { "metadata": { "content_metadata": {"type": "structured", "subtype": "table"}, - "table_metadata": {}, + "table_metadata": {"table_content_format": "simple"}, "content": "imgB", } }, @@ -134,9 +134,9 @@ def test_extract_table_data_all_valid(mocker, validated_config): # Each valid row updated assert df_out.at[0, "metadata"]["table_metadata"]["table_content"] == "tableA" - assert df_out.at[0, "metadata"]["table_metadata"]["table_content_format"] == "fmtA" + assert df_out.at[0, "metadata"]["table_metadata"]["table_content_format"] == "simple" assert df_out.at[1, "metadata"]["table_metadata"]["table_content"] == "tableB" - assert df_out.at[1, "metadata"]["table_metadata"]["table_content_format"] == "fmtB" + assert df_out.at[1, "metadata"]["table_metadata"]["table_content_format"] == "simple" # Check calls mock_create_client.assert_called_once() @@ -158,7 +158,7 @@ def test_extract_table_data_mixed_rows(mocker, validated_config): mock_create_client = mocker.patch(f"{MODULE_UNDER_TEST}._create_paddle_client", return_value=mock_client) mock_update_metadata = mocker.patch( f"{MODULE_UNDER_TEST}._update_metadata", - return_value=[("good1", ("table1", "fmt1")), ("good2", ("table2", "fmt2"))], + return_value=[("good1", [[], ["table1"]]), ("good2", [[], ["table2"]])], ) df_in = pd.DataFrame( @@ -166,7 +166,7 @@ def test_extract_table_data_mixed_rows(mocker, validated_config): { "metadata": { "content_metadata": {"type": "structured", "subtype": "table"}, - "table_metadata": {}, + "table_metadata": {"table_content_format": "simple"}, "content": "good1", } }, @@ -190,14 +190,14 @@ def test_extract_table_data_mixed_rows(mocker, validated_config): df_out, trace_info = _extract_table_data(df_in, {}, validated_config) - # row0 => updated with table1/fmt1 + # row0 => updated with table1/txt1 assert df_out.at[0, "metadata"]["table_metadata"]["table_content"] == "table1" 
- assert df_out.at[0, "metadata"]["table_metadata"]["table_content_format"] == "fmt1" + assert df_out.at[0, "metadata"]["table_metadata"]["table_content_format"] == "simple" # row1 => invalid => no table_content assert "table_content" not in df_out.at[1, "metadata"]["table_metadata"] - # row2 => updated => table2/fmt2 + # row2 => updated => table2/txt2 assert df_out.at[2, "metadata"]["table_metadata"]["table_content"] == "table2" - assert df_out.at[2, "metadata"]["table_metadata"]["table_content_format"] == "fmt2" + assert df_out.at[2, "metadata"]["table_metadata"]["table_content_format"] == "simple" mock_update_metadata.assert_called_once_with( base64_images=["good1", "good2"], @@ -312,7 +312,7 @@ def test_update_metadata_skip_small(mocker, paddle_mock): res = _update_metadata(imgs, paddle_mock, batch_size=2) assert len(res) == 2 # First was too small => ("", "") - assert res[0] == ("imgSmall", ("", "")) + assert res[0] == ("imgSmall", (None, None)) assert res[1] == ("imgBig", ("valid_table", "valid_fmt")) paddle_mock.infer.assert_called_once_with( @@ -415,8 +415,8 @@ def test_update_metadata_all_small(mocker, paddle_mock): mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_HEIGHT", 30) res = _update_metadata(imgs, paddle_mock, batch_size=2) - assert res[0] == ("imgA", ("", "")) - assert res[1] == ("imgB", ("", "")) + assert res[0] == ("imgA", (None, None)) + assert res[1] == ("imgB", (None, None)) # No calls to infer paddle_mock.infer.assert_not_called() diff --git a/tests/nv_ingest/util/nim/test_helpers.py b/tests/nv_ingest/util/nim/test_helpers.py index f0fae039..4486a301 100644 --- a/tests/nv_ingest/util/nim/test_helpers.py +++ b/tests/nv_ingest/util/nim/test_helpers.py @@ -4,7 +4,6 @@ from unittest.mock import Mock, patch import numpy as np -import packaging.version import pytest import requests @@ -493,52 +492,50 @@ def test_create_inference_client_http_endpoint_whitespace_no_infer_protocol(mock # Preprocess image for paddle -def test_preprocess_image_paddle_version_none(sample_image): +def test_preprocess_image_transpose(sample_image): """ - Test that when paddle_version is None, the function returns the input image unchanged. + Test that the output image is transposed correctly. """ - result = preprocess_image_for_paddle(sample_image, paddle_version=None) - assert np.array_equal( - result, sample_image - ), "The output should be the same as the input when paddle_version is None." + result, _ = preprocess_image_for_paddle(sample_image) + # The output should have shape (channels, height, width) + assert result.shape[0] == sample_image.shape[2], "The output should have channels in the first dimension." + assert result.shape[1] > 0 and result.shape[2] > 0, "The output height and width should be greater than zero." -def test_preprocess_image_paddle_version_old(sample_image): - """ - Test that when paddle_version is less than '0.2.0-rc1', the function returns the input image unchanged. - """ - result = preprocess_image_for_paddle(sample_image, paddle_version="0.1.0") - assert np.array_equal( - result, sample_image - ), "The output should be the same as the input when paddle_version is less than '0.2.0-rc1'." 
+def test_preprocess_image_for_paddle_metadata(): + # Create a dummy image (300x500, RGB) + height, width = 300, 500 + image = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8) + # Run the preprocessing function + processed_image, metadata = preprocess_image_for_paddle(image, image_max_dimension=960) -def test_preprocess_image_paddle_version_new(sample_image): - """ - Test that when paddle_version is '0.2.0-rc1' or higher, the function processes the image. - """ - result = preprocess_image_for_paddle(sample_image, paddle_version="0.2.0-rc1") - assert not np.array_equal( - result, sample_image - ), "The output should be different from the input when paddle_version is '0.2.0-rc1' or higher." - assert result.shape[0] == sample_image.shape[2], "The output should have shape (channels, height, width)." + # Expected scale factor + expected_scale_factor = 960 / max(height, width) + expected_new_height = int(height * expected_scale_factor) + expected_new_width = int(width * expected_scale_factor) + # Compute expected padding + expected_padded_height = (expected_new_height + 31) // 32 * 32 + expected_padded_width = (expected_new_width + 31) // 32 * 32 + expected_pad_height = expected_padded_height - expected_new_height + expected_pad_width = expected_padded_width - expected_new_width -def test_preprocess_image_transpose(sample_image): - """ - Test that the output image is transposed correctly. - """ - result = preprocess_image_for_paddle(sample_image, paddle_version="0.2.0") - # The output should have shape (channels, height, width) - assert result.shape[0] == sample_image.shape[2], "The output should have channels in the first dimension." - assert result.shape[1] > 0 and result.shape[2] > 0, "The output height and width should be greater than zero." + # Assertions + assert metadata["original_height"] == height + assert metadata["original_width"] == width + assert pytest.approx(metadata["scale_factor"]) == expected_scale_factor + assert metadata["new_height"] == expected_padded_height + assert metadata["new_width"] == expected_padded_width + assert metadata["pad_height"] == expected_pad_height + assert metadata["pad_width"] == expected_pad_width def test_preprocess_image_dtype(sample_image): """ Test that the output image has dtype float32. """ - result = preprocess_image_for_paddle(sample_image, paddle_version="0.2.0") + result, _ = preprocess_image_for_paddle(sample_image) assert result.dtype == np.float32, "The output image should have dtype float32." @@ -547,7 +544,7 @@ def test_preprocess_image_large_image(): Test processing of a large image. """ image = np.random.randint(0, 256, size=(3000, 2000, 3), dtype=np.uint8) - result = preprocess_image_for_paddle(image, paddle_version="0.2.0") + result, _ = preprocess_image_for_paddle(image) height, width = image.shape[:2] scale_factor = 960 / max(height, width) new_height = int(height * scale_factor) @@ -564,7 +561,7 @@ def test_preprocess_image_small_image(): Test processing of a small image. """ image = np.random.randint(0, 256, size=(50, 50, 3), dtype=np.uint8) - result = preprocess_image_for_paddle(image, paddle_version="0.2.0") + result, _ = preprocess_image_for_paddle(image) height, width = image.shape[:2] scale_factor = 960 / max(height, width) new_height = int(height * scale_factor) @@ -581,7 +578,7 @@ def test_preprocess_image_non_multiple_of_32(): Test that images with dimensions not multiples of 32 are padded correctly. 
""" image = np.random.randint(0, 256, size=(527, 319, 3), dtype=np.uint8) - result = preprocess_image_for_paddle(image, paddle_version="0.2.0") + result, _ = preprocess_image_for_paddle(image) height, width = image.shape[:2] scale_factor = 960 / max(height, width) new_height = int(height * scale_factor) @@ -598,7 +595,7 @@ def test_preprocess_image_dtype_uint8(): Test that the function works with images of dtype uint8. """ image = np.random.randint(0, 256, size=(700, 500, 3), dtype=np.uint8) - result = preprocess_image_for_paddle(image, paddle_version="0.2.0") + result, _ = preprocess_image_for_paddle(image) assert result.dtype == np.float32, "The output image should be converted to dtype float32." @@ -607,7 +604,7 @@ def test_preprocess_image_max_dimension_less_than_960(): Test that images with max dimension less than 960 are scaled up. """ image = np.random.randint(0, 256, size=(800, 600, 3), dtype=np.uint8) - result = preprocess_image_for_paddle(image, paddle_version="0.2.0") + result, _ = preprocess_image_for_paddle(image) height, width = image.shape[:2] scale_factor = 960 / max(height, width) new_height = int(height * scale_factor) @@ -625,7 +622,7 @@ def test_preprocess_image_zero_dimension(): """ image = np.zeros((0, 0, 3), dtype=np.uint8) with pytest.raises(Exception): - preprocess_image_for_paddle(image, paddle_version="0.2.0") + preprocess_image_for_paddle(image) def test_preprocess_image_invalid_input(): @@ -634,24 +631,7 @@ def test_preprocess_image_invalid_input(): """ image = "not an image" with pytest.raises(Exception): - preprocess_image_for_paddle(image, paddle_version="0.2.0") - - -def test_preprocess_image_different_paddle_versions(sample_image): - """ - Test the function with different paddle_version inputs. - """ - versions = ["0.1.0", "0.2.0-rc0", "0.2.0-rc1", "0.2.1"] - for version in versions: - result = preprocess_image_for_paddle(sample_image, paddle_version=version) - if packaging.version.parse(version) < packaging.version.parse("0.2.0-rc1"): - assert np.array_equal( - result, sample_image - ), f"The output should be the same as the input when paddle_version is {version}." - else: - assert not np.array_equal( - result, sample_image - ), f"The output should be different from the input when paddle_version is {version}." + preprocess_image_for_paddle(image) # Tests for `remove_url_endpoints` diff --git a/tests/nv_ingest/util/nim/test_paddle.py b/tests/nv_ingest/util/nim/test_paddle.py index ada46a2a..f3efa7af 100644 --- a/tests/nv_ingest/util/nim/test_paddle.py +++ b/tests/nv_ingest/util/nim/test_paddle.py @@ -56,12 +56,7 @@ def create_valid_grpc_response_batched(text="mock_text"): @pytest.fixture def paddle_ocr_model(): - return PaddleOCRModelInterface(paddle_version="0.2.1") - - -@pytest.fixture -def legacy_paddle_ocr_model(): - return PaddleOCRModelInterface(paddle_version="0.2.0") + return PaddleOCRModelInterface() @pytest.fixture @@ -113,16 +108,13 @@ def test_prepare_data_for_inference(paddle_ocr_model): assert len(result["image_arrays"]) == 1 assert result["image_arrays"][0].shape == (100, 100, 3) - # We also store dimensions in self._dims - assert paddle_ocr_model._dims == [(100, 100)] - def test_format_input_grpc(paddle_ocr_model): """ Now we place the images under 'image_arrays' and return a batched input for gRPC. 
""" with patch(f"{_MODULE_UNDER_TEST}.preprocess_image_for_paddle") as mock_preprocess: - mock_preprocess.return_value = np.zeros((32, 32, 3)) + mock_preprocess.return_value = (np.zeros((32, 32, 3)), {}) # For gRPC, we rely on 'image_arrays' data = {"image_arrays": [np.zeros((32, 32, 3))]} @@ -134,8 +126,7 @@ def test_format_input_grpc(paddle_ocr_model): def test_format_input_http(paddle_ocr_model): """ - For HTTP, if not legacy, we expect a payload with 'input': [...] - containing a valid base64 PNG. + For HTTP, we expect a payload with 'input': [...] containing a valid base64 PNG. """ valid_b64 = create_valid_base64_image() data = {"base64_image": valid_b64} @@ -145,7 +136,7 @@ def test_format_input_http(paddle_ocr_model): result = paddle_ocr_model.format_input(data, protocol="http") - # For non-legacy => {"input": [ {"type":"image_url","url": "..."} ]} + # {"input": [ {"type":"image_url","url": "..."} ]} assert "input" in result assert len(result["input"]) == 1 first_item = result["input"][0] @@ -155,45 +146,6 @@ def test_format_input_http(paddle_ocr_model): assert len(first_item["url"]) > len("data:image/png;base64,") -def test_format_input_http_legacy(legacy_paddle_ocr_model): - """ - For legacy version (<0.2.1-rc2), the code should produce the 'messages' structure. - """ - valid_b64 = create_valid_base64_image() - data = {"base64_image": valid_b64} - - data = legacy_paddle_ocr_model.prepare_data_for_inference(data) - result = legacy_paddle_ocr_model.format_input(data, protocol="http") - - # Now we expect => {"messages":[{"content":[ { "type":"image_url", - # "image_url":{"url":"data:image/png;base64,..."}}, ... ]}]} - assert "messages" in result - assert len(result["messages"]) == 1 - content_list = result["messages"][0]["content"] - assert len(content_list) == 1 - item = content_list[0] - assert item["type"] == "image_url" - assert item["image_url"]["url"].startswith("data:image/png;base64,") - - -def test_parse_output_http_pseudo_markdown(paddle_ocr_model, mock_paddle_http_response): - """ - parse_output should now return a list of (content, table_content_format) tuples. - e.g. [("| mock_text |\n", "pseudo_markdown")] - """ - with patch(f"{_MODULE_UNDER_TEST}.base64_to_numpy") as mock_base64_to_numpy: - mock_base64_to_numpy.return_value = np.zeros((3, 100, 100)) - - data = {"base64_image": "mock_base64_string"} - _ = paddle_ocr_model.prepare_data_for_inference(data) - - result = paddle_ocr_model.parse_output(mock_paddle_http_response, protocol="http") - # It's a list with one tuple => (content, format). - assert len(result) == 1 - assert result[0][0] == "| mock_text |\n" - assert result[0][1] == "pseudo_markdown" - - def test_parse_output_http_simple(paddle_ocr_model, mock_paddle_http_response): """ The new parse_output also returns a list of (content, format). @@ -205,55 +157,11 @@ def test_parse_output_http_simple(paddle_ocr_model, mock_paddle_http_response): data = {"base64_image": "mock_base64_string"} _ = paddle_ocr_model.prepare_data_for_inference(data) - result = paddle_ocr_model.parse_output(mock_paddle_http_response, protocol="http", table_content_format="simple") - # Should be [("mock_text", "simple")] - assert len(result) == 1 - assert result[0][0] == "mock_text" - assert result[0][1] == "simple" - - -def test_parse_output_http_simple_legacy(legacy_paddle_ocr_model): - """ - For the legacy version, parse_output also returns a list of (content, format), - but it forces 'simple' format if the user requested something else. 
- """ - with patch(f"{_MODULE_UNDER_TEST}.base64_to_numpy") as mock_base64_to_numpy: - mock_base64_to_numpy.return_value = np.zeros((100, 100, 3)) - - data = {"base64_image": "mock_base64_string"} - _ = legacy_paddle_ocr_model.prepare_data_for_inference(data) - - mock_legacy_paddle_http_response = {"data": [{"content": "mock_text"}]} - - result = legacy_paddle_ocr_model.parse_output( - mock_legacy_paddle_http_response, protocol="http", table_content_format="foo" - ) - # Expect => [("mock_text", "simple")] - assert len(result) == 1 - assert result[0][0] == "mock_text" - assert result[0][1] == "simple" - - -def test_parse_output_grpc_pseudo_markdown(paddle_ocr_model): - """ - Provide a valid (1,2) shape for bounding boxes & text. - The interface should parse them into [("| mock_text |\n", "pseudo_markdown")]. - """ - valid_b64 = create_valid_base64_image() - data = {"base64_image": valid_b64} - paddle_ocr_model.prepare_data_for_inference(data) - - # Create a valid shape => (1,2), with bounding-box JSON, text JSON - grpc_response = create_valid_grpc_response_batched("mock_text") - - # parse_output with default => pseudo_markdown for non-legacy - result = paddle_ocr_model.parse_output(grpc_response, protocol="grpc") - + result = paddle_ocr_model.parse_output(mock_paddle_http_response, protocol="http") + # Should be (bounding_boxes, text_predictions) assert len(result) == 1 - content, fmt = result[0] - # content might contain a markdown table row => "| mock_text |" - assert "mock_text" in content - assert fmt == "pseudo_markdown" + assert result[0][0] == [[[0.1, 0.2], [0.2, 0.2], [0.2, 0.3], [0.1, 0.3]]] + assert result[0][1] == ["mock_text"] def test_parse_output_grpc_simple(paddle_ocr_model): @@ -264,28 +172,18 @@ def test_parse_output_grpc_simple(paddle_ocr_model): data = {"base64_image": valid_b64} paddle_ocr_model.prepare_data_for_inference(data) + data["_dimensions"] = [ + { + "new_width": 1, + "new_height": 1, + "pad_width": 0, + "pad_height": 0, + "scale_factor": 1.0, + } + ] grpc_response = create_valid_grpc_response_batched("mock_text") - result = paddle_ocr_model.parse_output(grpc_response, protocol="grpc", table_content_format="simple") - - assert len(result) == 1 - content, fmt = result[0] - assert content == "mock_text" - assert fmt == "simple" - - -def test_parse_output_grpc_legacy(legacy_paddle_ocr_model): - """ - For legacy gRPC, we also return a list of (content, format). - We force 'simple' if table_content_format was not 'simple'. 
- """ - valid_b64 = create_valid_base64_image() - data = {"base64_image": valid_b64} - legacy_paddle_ocr_model.prepare_data_for_inference(data) - - grpc_response = create_valid_grpc_response_batched("mock_text") + result = paddle_ocr_model.parse_output(grpc_response, protocol="grpc", data=data) - # Pass a non-"simple" format => should be forced to 'simple' in legacy - result = legacy_paddle_ocr_model.parse_output(grpc_response, protocol="grpc", table_content_format="foo") assert len(result) == 1 - assert result[0][0] == "mock_text" - assert result[0][1] == "simple" + assert result[0][0] == [[[0.1, 0.2], [0.2, 0.2], [0.2, 0.3], [0.1, 0.3]]] + assert result[0][1] == ["mock_text"] diff --git a/tests/nv_ingest/util/nim/test_yolox.py b/tests/nv_ingest/util/nim/test_yolox.py index b3a84fd3..595b3e58 100644 --- a/tests/nv_ingest/util/nim/test_yolox.py +++ b/tests/nv_ingest/util/nim/test_yolox.py @@ -182,11 +182,6 @@ def test_process_inference_results_grpc(model_interface): output_array, "grpc", original_image_shapes=original_image_shapes, - num_classes=3, - conf_thresh=0.5, - iou_thresh=0.4, - min_score=0.3, - final_thresh=0.6, ) assert isinstance(inference_results, list) assert len(inference_results) == 2 @@ -194,10 +189,10 @@ def test_process_inference_results_grpc(model_interface): assert isinstance(result, dict) if "table" in result: for bbox in result["table"]: - assert bbox[4] >= 0.6 + assert bbox[4] >= 0.48 if "chart" in result: for bbox in result["chart"]: - assert bbox[4] >= 0.6 + assert bbox[4] >= 0.48 if "title" in result: assert isinstance(result["title"], list) @@ -211,14 +206,11 @@ def test_process_inference_results_http(model_interface): } for _ in range(10) ] + original_image_shapes = [[100, 100] for _ in range(10)] inference_results = model_interface.process_inference_results( output, "http", - num_classes=3, - conf_thresh=0.5, - iou_thresh=0.4, - min_score=0.3, - final_thresh=0.6, + original_image_shapes=original_image_shapes, ) assert isinstance(inference_results, list) assert len(inference_results) == 10 @@ -226,9 +218,9 @@ def test_process_inference_results_http(model_interface): assert isinstance(result, dict) if "table" in result: for bbox in result["table"]: - assert bbox[4] >= 0.6 + assert bbox[4] >= 0.48 if "chart" in result: for bbox in result["chart"]: - assert bbox[4] >= 0.6 + assert bbox[4] >= 0.48 if "title" in result: assert isinstance(result["title"], list) From 3eeb5926ca3b7af6b0ac74b569257e25cbdd2947 Mon Sep 17 00:00:00 2001 From: edknv Date: Fri, 31 Jan 2025 13:15:03 -0800 Subject: [PATCH 09/28] minor fix --- src/nv_ingest/stages/nim/chart_extraction.py | 2 +- src/nv_ingest/stages/nim/table_extraction.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index cbedd184..fbe8e818 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -260,7 +260,7 @@ def meets_criteria(row): # The order of base64_images in bulk_results should match their original # indices if we process them in the same order. 
for row_id, idx in enumerate(valid_indices): - (_, chart_content) = bulk_results[row_id] + _, chart_content = bulk_results[row_id] df.at[idx, "metadata"]["table_metadata"]["table_content"] = chart_content return df, {"trace_info": trace_info} diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index 06ed0d64..40213ee3 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -265,7 +265,7 @@ def meets_criteria(row): if table_content_format == TableFormatEnum.SIMPLE: table_content = " ".join(text_predictions) elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN: - table_content = convert_paddle_response_to_psuedo_markdown(text_predictions, bounding_boxes) + table_content = convert_paddle_response_to_psuedo_markdown(bounding_boxes, text_predictions) else: raise ValueError(f"Unexpected table format: {table_content_format}") From bbe165d431ac55b51fa07706e48535f84fa77d8d Mon Sep 17 00:00:00 2001 From: edknv Date: Sat, 1 Feb 2025 20:09:08 -0800 Subject: [PATCH 10/28] enable retrying for RemoteDisconnected ConnectionError --- client/src/nv_ingest_client/client/client.py | 5 ++++- .../src/nv_ingest_client/message_clients/rest/rest_client.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/client/src/nv_ingest_client/client/client.py b/client/src/nv_ingest_client/client/client.py index 0b2f3017..70b68188 100644 --- a/client/src/nv_ingest_client/client/client.py +++ b/client/src/nv_ingest_client/client/client.py @@ -93,10 +93,13 @@ def __init__( self._message_client_hostname = message_client_hostname or "localhost" self._message_client_port = message_client_port or 7670 self._message_counter_id = msg_counter_id or "nv-ingest-message-id" + self._message_client_kwargs = message_client_kwargs or {} logger.debug("Instantiate NvIngestClient:\n%s", str(self)) self._message_client = message_client_allocator( - host=self._message_client_hostname, port=self._message_client_port + host=self._message_client_hostname, + port=self._message_client_port, + **self._message_client_kwargs, ) # Initialize the worker pool with the specified size diff --git a/client/src/nv_ingest_client/message_clients/rest/rest_client.py b/client/src/nv_ingest_client/message_clients/rest/rest_client.py index c65952f2..4794e875 100644 --- a/client/src/nv_ingest_client/message_clients/rest/rest_client.py +++ b/client/src/nv_ingest_client/message_clients/rest/rest_client.py @@ -236,7 +236,7 @@ def fetch_message(self, job_id: str, timeout: float = 10) -> ResponseSchema: except RuntimeError as rte: raise rte - except requests.HTTPError as err: + except (ConnectionError, requests.HTTPError, requests.exceptions.ConnectionError) as err: logger.error(f"Error during fetching, retrying... 
Error: {err}") self._client = None # Invalidate client to force reconnection try: From d29e95a69743efce9759dcee25402466ce8805f1 Mon Sep 17 00:00:00 2001 From: edknv Date: Sat, 1 Feb 2025 22:22:52 -0800 Subject: [PATCH 11/28] fix merge --- src/nv_ingest/util/nim/paddle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nv_ingest/util/nim/paddle.py b/src/nv_ingest/util/nim/paddle.py index c29267cd..03f1f853 100644 --- a/src/nv_ingest/util/nim/paddle.py +++ b/src/nv_ingest/util/nim/paddle.py @@ -361,8 +361,8 @@ def _extract_content_from_paddle_grpc_response( def _postprocess_paddle_response( bounding_boxes: List[Any], text_predictions: List[str], - img_index: int = 0, dims: Optional[List[Dict[str, Any]]] = None, + img_index: int = 0, ) -> Tuple[List[Any], List[str]]: """ Convert bounding boxes with normalized coordinates to pixel cooridnates by using From 088f2f669988caa41810a13a18379ac5d58abb22 Mon Sep 17 00:00:00 2001 From: edknv Date: Sat, 1 Feb 2025 22:47:45 -0800 Subject: [PATCH 12/28] fix merge --- src/nv_ingest/util/nim/paddle.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/nv_ingest/util/nim/paddle.py b/src/nv_ingest/util/nim/paddle.py index 03f1f853..ca8df5cf 100644 --- a/src/nv_ingest/util/nim/paddle.py +++ b/src/nv_ingest/util/nim/paddle.py @@ -61,14 +61,11 @@ def prepare_data_for_inference(self, data: Dict[str, Any]) -> Dict[str, Any]: raise ValueError("The 'base64_images' key must contain a list of base64-encoded strings.") image_arrays: List[np.ndarray] = [] - dims: List[Tuple[int, int]] = [] for b64 in base64_list: img = base64_to_numpy(b64) image_arrays.append(img) - dims.append((img.shape[0], img.shape[1])) data["image_arrays"] = image_arrays - data["image_dims"] = dims elif "base64_image" in data: # Single-image fallback @@ -112,7 +109,9 @@ def format_input(self, data: Dict[str, Any], protocol: str, **kwargs: Any) -> An raise KeyError("Expected 'image_arrays' in data. 
Call prepare_data_for_inference first.") images = data["image_arrays"] - dims = data["image_dims"] + + dims: List[Dict[str, Any]] = [] + data["image_dims"] = dims if protocol == "grpc": logger.debug("Formatting input for gRPC PaddleOCR model (batched).") From 668f1c0b7f3c4c7234be70c2c322655cf29f8187 Mon Sep 17 00:00:00 2001 From: edknv Date: Sun, 2 Feb 2025 19:50:58 -0800 Subject: [PATCH 13/28] place all nims on the first gpu --- docker-compose.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 67a6e826..004b0ace 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -34,7 +34,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] runtime: nvidia @@ -77,7 +77,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] runtime: nvidia @@ -99,7 +99,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] runtime: nvidia @@ -167,7 +167,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] otel-collector: From c9d8ff3b5c35092253b5a261956d67a25e91826c Mon Sep 17 00:00:00 2001 From: edknv Date: Sat, 8 Feb 2025 14:12:08 -0800 Subject: [PATCH 14/28] segfault fix --- src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py | 2 ++ src/nv_ingest/util/pdf/pdfium.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py index b75925aa..f7521ec7 100644 --- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py @@ -410,6 +410,8 @@ def pdfium_extractor( futures.append(future) pages_for_tables.clear() + page.close() + # After page loop, if we still have leftover pages_for_tables, submit one last job if (extract_tables or extract_charts) and pages_for_tables: future = executor.submit( diff --git a/src/nv_ingest/util/pdf/pdfium.py b/src/nv_ingest/util/pdf/pdfium.py index daa2ee84..19cedd3e 100644 --- a/src/nv_ingest/util/pdf/pdfium.py +++ b/src/nv_ingest/util/pdf/pdfium.py @@ -176,7 +176,7 @@ def pdfium_pages_to_numpy( pil_image.thumbnail(scale_tuple, Image.LANCZOS) # Convert the PIL image to a NumPy array - img_arr = np.array(pil_image) + img_arr = np.array(pil_image).copy() # Apply padding if specified if padding_tuple: From a14f496b6362c3dcc9ac3719ed53089be4fe67e9 Mon Sep 17 00:00:00 2001 From: edknv Date: Sat, 8 Feb 2025 14:45:29 -0800 Subject: [PATCH 15/28] update docker-compose.yaml --- docker-compose.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 004b0ace..c9f36116 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -39,7 +39,7 @@ services: runtime: nvidia yolox-graphic-elements: - image: ${YOLOX_GRAPHIC_ELEMENTS_IMAGE:-nvcr.io/ohlfw0olaadg/ea-participants/nv-yolox-graphic-elements-v1}:${YOLOX_GRAPHIC_ELEMENTS_TAG:-1.1.0} + image: ${YOLOX_GRAPHIC_ELEMENTS_IMAGE:-nvcr.io/nvidia/nemo-microservices/nv-yolox-graphic-elements-v1}:${YOLOX_GRAPHIC_ELEMENTS_TAG:-1.1.0} ports: - "8003:8000" - "8004:8001" @@ -48,7 +48,7 @@ services: environment: - NIM_HTTP_API_PORT=8000 - NIM_TRITON_LOG_VERBOSE=1 - - NGC_API_KEY=${STAGING_NIM_NGC_API_KEY} + - NGC_API_KEY=${NIM_NGC_API_KEY:-${NGC_API_KEY:-ngcapikey}} - CUDA_VISIBLE_DEVICES=0 
deploy: resources: From b5fd7fa7b5b96402d90e31e04c5b8ab4d9ce0ec2 Mon Sep 17 00:00:00 2001 From: edknv Date: Sun, 9 Feb 2025 09:22:56 -0800 Subject: [PATCH 16/28] move yolox-graphics nim to device 1 --- docker-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index c9f36116..1230d83c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -55,7 +55,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["0"] + device_ids: ["1"] capabilities: [gpu] runtime: nvidia From a0a87cfcb3c8fc15937090d557fcc0f8edaba071 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 11 Feb 2025 08:35:32 -0800 Subject: [PATCH 17/28] move yolox-graphics nim to device 0 --- docker-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 1230d83c..c9f36116 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -55,7 +55,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] runtime: nvidia From cef036908d9ea653119088b2eaf2786231a78e75 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 11 Feb 2025 13:37:18 -0800 Subject: [PATCH 18/28] merge main --- Dockerfile | 5 + README.md | 3 +- ci/scripts/build_pip_packages.sh | 7 +- client/LICENSE | 201 ++++++++++++++++++ client/MANIFEST.in | 8 + .../client_examples/docker/Dockerfile.client | 1 - client/pyproject.toml | 59 +++++ client/requirements.txt | 22 -- client/setup.py | 73 ------- client/src/version.py | 36 ++++ .../developer-guide/environment-config.md | 84 ++------ .../getting-started/quickstart-guide.md | 3 +- .../image/image_handlers.py | 68 +++--- .../extraction_workflows/pdf/pdfium_helper.py | 89 ++++---- src/nv_ingest/stages/nim/chart_extraction.py | 162 ++++++-------- src/nv_ingest/stages/nim/table_extraction.py | 180 ++++++---------- src/nv_ingest/util/nim/helpers.py | 85 +++++--- src/nv_ingest/util/nim/yolox.py | 2 +- .../stages/nims/test_chart_extraction.py | 68 +++--- .../stages/nims/test_table_extraction.py | 67 +++--- .../util/flow_control/test_filter_by_task.py | 19 +- 21 files changed, 668 insertions(+), 574 deletions(-) create mode 100644 client/LICENSE create mode 100644 client/MANIFEST.in create mode 100644 client/pyproject.toml delete mode 100644 client/requirements.txt delete mode 100644 client/setup.py create mode 100644 client/src/version.py diff --git a/Dockerfile b/Dockerfile index 348335c1..4e6bf22d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -91,10 +91,15 @@ COPY client client COPY src/nv_ingest src/nv_ingest RUN rm -rf ./src/nv_ingest/dist ./client/dist +# Install python build from pip, version needed not present in conda +RUN source activate nv_ingest_runtime \ + && pip install 'build>=1.2.2' + # Add pip cache path to match conda's package cache RUN --mount=type=cache,target=/opt/conda/pkgs \ --mount=type=cache,target=/root/.cache/pip \ chmod +x ./ci/scripts/build_pip_packages.sh \ + && source activate nv_ingest_runtime \ && ./ci/scripts/build_pip_packages.sh --type ${RELEASE_TYPE} --lib client \ && ./ci/scripts/build_pip_packages.sh --type ${RELEASE_TYPE} --lib service diff --git a/README.md b/README.md index 503574ab..700f9f6d 100644 --- a/README.md +++ b/README.md @@ -95,7 +95,7 @@ Password: > During the early access (EA) phase, you must apply for early access here: https://developer.nvidia.com/nemo-microservices-early-access/join. 
> When your early access is approved, follow the instructions in the email to create an organization and team, link your profile, and generate your NGC API key. -4. Create a .env file that contains your NGC API keys. For more information, refer to [](docs/docs/user-guide/developer-guide/environment-config.md). +4. Create a .env file that contains your NGC API keys. For more information, refer to [Environment Configuration Variables](docs/docs/user-guide/developer-guide/environment-config.md). ``` # Container images must access resources from NGC. @@ -184,7 +184,6 @@ pip install . # When not using Conda, pip dependencies for the client can be installed directly via pip. Pip based installation of # the ingest service is not supported. cd client -pip install -r requirements.txt pip install . ``` diff --git a/ci/scripts/build_pip_packages.sh b/ci/scripts/build_pip_packages.sh index 2f0ca3fa..34433ad6 100755 --- a/ci/scripts/build_pip_packages.sh +++ b/ci/scripts/build_pip_packages.sh @@ -41,15 +41,14 @@ fi if [[ "$LIBRARY" == "client" ]]; then NV_INGEST_CLIENT_VERSION_OVERRIDE="${VERSION_SUFFIX}" export NV_INGEST_CLIENT_VERSION_OVERRIDE - SETUP_PATH="$SCRIPT_DIR/../../client/setup.py" + SETUP_PATH="$SCRIPT_DIR/../../client" + (cd "$(dirname "$SETUP_PATH")/client" && python -m build) elif [[ "$LIBRARY" == "service" ]]; then NV_INGEST_SERVICE_VERSION_OVERRIDE="${VERSION_SUFFIX}" export NV_INGEST_SERVICE_VERSION_OVERRIDE SETUP_PATH="$SCRIPT_DIR/../../setup.py" + (cd "$(dirname "$SETUP_PATH")" && python setup.py sdist bdist_wheel) else echo "Invalid library: $LIBRARY" usage fi - -# Build the wheel -(cd "$(dirname "$SETUP_PATH")" && python setup.py sdist bdist_wheel) diff --git a/client/LICENSE b/client/LICENSE new file mode 100644 index 00000000..261eeb9e --- /dev/null +++ b/client/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/client/MANIFEST.in b/client/MANIFEST.in new file mode 100644 index 00000000..406ded85 --- /dev/null +++ b/client/MANIFEST.in @@ -0,0 +1,8 @@ +exclude *.egg-info + +include README.md +include LICENSE +recursive-include src/nv_ingest_client +include src/version.py +global-exclude __pycache__ +global-exclude *.pyc diff --git a/client/client_examples/docker/Dockerfile.client b/client/client_examples/docker/Dockerfile.client index 198e88b5..20885396 100644 --- a/client/client_examples/docker/Dockerfile.client +++ b/client/client_examples/docker/Dockerfile.client @@ -20,7 +20,6 @@ RUN apt update && apt install -y python3-pip git tree \ RUN cd /workspace \ && git clone https://github.com/NVIDIA/nv-ingest.git \ && cd /workspace/nv-ingest/client \ - && pip install -r ./requirements.txt \ && pip install . 
COPY examples /workspace/client_examples/examples diff --git a/client/pyproject.toml b/client/pyproject.toml new file mode 100644 index 00000000..d9d06e8e --- /dev/null +++ b/client/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "nv-ingest-client" +description = "Python client for the nv-ingest service" +dynamic = ["version"] # Declare attrs that will be generated at build time +readme = "README.md" +authors = [ + {name = "Jeremy Dyer", email = "jdyer@nvidia.com"} +] +license = {file = "LICENSE"} +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", +] +dependencies = [ + "azure-storage-blob==12.24.0", + "build>=1.2.2", + "charset-normalizer>=3.4.1", + "click>=8.1.8", + "fsspec>=2025.2.0", + "httpx==0.27.2", + "langchain-milvus==0.1.7", + "langchain-nvidia-ai-endpoints>=0.3.7", + "llama-index-embeddings-nvidia==0.1.5", + "minio>=7.2.15", + "openai==1.40.6", + "pyarrow>=19.0.0", + "pydantic>2.0.0", + "pydantic-settings>2.0.0", + "pymilvus==2.5.4", + "pymilvus[bulk_writer,model]", + "pypdfium2>=4.30.1", + "python-docx>=1.1.2", + "python-magic>=0.4.27", + "python-pptx==0.6.23", + "redis==5.0.8", + "requests>=2.32.3", + "setuptools>=75.8.0", + "tqdm>=4.67.1", +] + +[project.urls] +homepage = "https://github.com/NVIDIA/nv-ingest" +repository = "https://github.com/NVIDIA/nv-ingest" +documentation = "https://docs.nvidia.com/nv-ingest" + +[project.scripts] +nv-ingest-cli = "nv_ingest_client.nv_ingest_cli:main" +process-json-files = "nv_ingest_client.util.process_json_files:main" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.setuptools.dynamic] +version = {attr = "version.get_version"} diff --git a/client/requirements.txt b/client/requirements.txt deleted file mode 100644 index 116a2561..00000000 --- a/client/requirements.txt +++ /dev/null @@ -1,22 +0,0 @@ -azure-storage-blob==12.24.0 -charset-normalizer -click -fsspec -httpx==0.27.2 -llama-index-embeddings-nvidia==0.1.5 -openai==1.40.6 -pydantic>=2.0.0 -pymilvus==2.4.9 -pymilvus[bulk_writer,model] -pypdfium2 -python-docx -python-magic -python-pptx==0.6.23 -redis~=5.0.1 -requests -setuptools -tqdm -langchain-milvus -langchain-nvidia-ai-endpoints>=0.3.7 -minio -pyarrow diff --git a/client/setup.py b/client/setup.py deleted file mode 100644 index d703d311..00000000 --- a/client/setup.py +++ /dev/null @@ -1,73 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - - -import datetime -import os -import re - -from setuptools import find_packages -from setuptools import setup - - -def get_version(): - release_type = os.getenv("NV_INGEST_RELEASE_TYPE", "dev") - version = os.getenv("NV_INGEST_CLIENT_VERSION") - rev = os.getenv("NV_INGEST_REV", "0") - - if not version: - version = f"{datetime.datetime.now().strftime('%Y.%m.%d')}" - - # Ensure the version is PEP 440 compatible - pep440_regex = r"^\d{4}\.\d{1,2}\.\d{1,2}$" - if not re.match(pep440_regex, version): - raise ValueError(f"Version '{version}' is not PEP 440 compatible") - - # Construct the final version string - if release_type == "dev": - final_version = f"{version}.dev{rev}" - elif release_type == "release": - final_version = f"{version}.post{rev}" if int(rev) > 0 else version - else: - raise ValueError(f"Invalid release type: {release_type}") - - return final_version - - -def read_requirements(file_name): - """Read a requirements file and return a list of its packages.""" - with open(file_name) as f: - return f.read().splitlines() - - -# Specify your requirements files -base_dir = os.path.abspath(os.path.dirname(__file__)) -requirements_files = [] - -# Read and combine requirements from all specified files -combined_requirements = [] -for file in requirements_files: - combined_requirements.extend(read_requirements(file)) - -combined_requirements = list(set(combined_requirements)) - -setup( - author="Anuradha Karuppiah", - author_email="anuradhak@nvidia.com", - classifiers=[], - description="Python client for the nv-ingest service", - entry_points={ - "console_scripts": [ - "nv-ingest-cli=nv_ingest_client.nv_ingest_cli:main", - "process-json-files=nv_ingest_client.util.process_json_files:main", - ] - }, - install_requires=combined_requirements, - name="nv_ingest_client", - package_dir={"": "src"}, - packages=find_packages(where="src"), - python_requires=">=3.10", - version=get_version(), - license="Apache-2.0", -) diff --git a/client/src/version.py b/client/src/version.py new file mode 100644 index 00000000..2037ed0f --- /dev/null +++ b/client/src/version.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + +import datetime +import os +import re + + +def get_version(): + release_type = os.getenv("NV_INGEST_RELEASE_TYPE", "dev") + version = os.getenv("NV_INGEST_CLIENT_VERSION") + rev = os.getenv("NV_INGEST_REV", "0") + + if not version: + version = f"{datetime.datetime.now().strftime('%Y.%m.%d')}" + + # Ensure the version is PEP 440 compatible + pep440_regex = r"^\d{4}\.\d{1,2}\.\d{1,2}$" + if not re.match(pep440_regex, version): + raise ValueError(f"Version '{version}' is not PEP 440 compatible") + + # Construct the final version string + if release_type == "dev": + # If rev is not specified and defaults to 0 lets create a more meaningful development + # identifier that is pep440 compliant + if int(rev) == 0: + rev = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + final_version = f"{version}.dev{rev}" + elif release_type == "release": + final_version = f"{version}.post{rev}" if int(rev) > 0 else version + else: + raise ValueError(f"Invalid release type: {release_type}") + + return final_version diff --git a/docs/docs/user-guide/developer-guide/environment-config.md b/docs/docs/user-guide/developer-guide/environment-config.md index 6ed322a0..a3dba36e 100644 --- a/docs/docs/user-guide/developer-guide/environment-config.md +++ b/docs/docs/user-guide/developer-guide/environment-config.md @@ -1,68 +1,20 @@ # Environment Configuration Variables - -- **`MESSAGE_CLIENT_HOST`**: - - - **Description**: Specifies the hostname or IP address of the message broker used for communication between - services. - - **Example**: `redis`, `localhost`, `192.168.1.10` - -- **`MESSAGE_CLIENT_PORT`**: - - - **Description**: Specifies the port number on which the message broker is listening. - - **Example**: `7670`, `6379` - -- **`CAPTION_CLASSIFIER_GRPC_TRITON`**: - - - **Description**: The endpoint where the caption classifier model is hosted using gRPC for communication. This is - used to send requests for caption classification. - You must specify only ONE of an http or gRPC endpoint. If both are specified gRPC will take precedence. - - **Example**: `triton:8001` - -- **`CAPTION_CLASSIFIER_MODEL_NAME`**: - - - **Description**: The name of the caption classifier model. - - **Example**: `deberta_large` - -- **`REDIS_MORPHEUS_TASK_QUEUE`**: - - - **Description**: The name of the task queue in Redis where tasks are stored and processed. - - **Example**: `morpheus_task_queue` - -- **`DOUGHNUT_TRITON_HOST`**: - - - **Description**: The hostname or IP address of the DOUGHNUT model service. - - **Example**: `triton-doughnut` - -- **`DOUGHNUT_TRITON_PORT`**: - - - **Description**: The port number on which the DOUGHNUT model service is listening. - - **Example**: `8001` - -- **`OTEL_EXPORTER_OTLP_ENDPOINT`**: - - - **Description**: The endpoint for the OpenTelemetry exporter, used for sending telemetry data. - - **Example**: `http://otel-collector:4317` - -- **`NGC_API_KEY`**: - - - **Description**: An authorized NGC API key, used to interact with hosted NIMs and can be generated here: https://org.ngc.nvidia.com/setup/personal-keys. - - **Example**: `nvapi-*************` - -- **`MINIO_BUCKET`**: - - - **Description**: Name of MinIO bucket, used to store image, table, and chart extractions. - - **Example**: `nv-ingest` - -- **`INGEST_LOG_LEVEL`**: - - - **Description**: The log level for the ingest service, which controls the verbosity of the logging output. 
-  - **Example**: `DEBUG`, `INFO`, `WARNING`, `ERROR`, `CRITICAL`
-
-- **`NVIDIA_BUILD_API_KEY`**:
-  - **Description**: This key is for when you are using the build.nvidia.com endpoint instead of a self hosted Deplot NIM.
-    This is necessary only in some cases when it is different from `NGC_API_KEY`. If this is not specified, `NGC_API_KEY` is used for build.nvidia.com.
-
-- **`NIM_NGC_API_KEY`**:
-  - **Description**: This key is by NIM microservices inside docker containers to access NGC resources.
-    This is necessary only in some cases when it is different from `NGC_API_KEY`. If this is not specified, `NGC_API_KEY` is used to access NGC resources.
+The following are the environment configuration variables that you can specify in your .env file.
+
+
+| Name | Example | Description |
+|----------------------------------|--------------------------------|-----------------------------------------------------------------------|
+| `CAPTION_CLASSIFIER_GRPC_TRITON` | - `triton:8001`<br/> | The endpoint where the caption classifier model is hosted using gRPC for communication. This is used to send requests for caption classification. You must specify only ONE of an http or gRPC endpoint. If both are specified gRPC will take precedence. |
+| `CAPTION_CLASSIFIER_MODEL_NAME` | - `deberta_large`<br/> | The name of the caption classifier model. |
+| `DOUGHNUT_TRITON_HOST` | - `triton-doughnut`<br/> | The hostname or IP address of the DOUGHNUT model service. |
+| `DOUGHNUT_TRITON_PORT` | - `8001`<br/> | The port number on which the DOUGHNUT model service is listening. |
+| `INGEST_LOG_LEVEL` | - `DEBUG`<br/>- `INFO`<br/>- `WARNING`<br/>- `ERROR`<br/>- `CRITICAL`<br/> | The log level for the ingest service, which controls the verbosity of the logging output. |
+| `MESSAGE_CLIENT_HOST` | - `redis`<br/>- `localhost`<br/>- `192.168.1.10`<br/> | Specifies the hostname or IP address of the message broker used for communication between services. |
+| `MESSAGE_CLIENT_PORT` | - `7670`<br/>- `6379`<br/> | Specifies the port number on which the message broker is listening. |
+| `MINIO_BUCKET` | - `nv-ingest`<br/> | Name of MinIO bucket, used to store image, table, and chart extractions. |
+| `NGC_API_KEY` | - `nvapi-*************`<br/> | An authorized NGC API key, used to interact with hosted NIMs and can be generated here: https://org.ngc.nvidia.com/setup/personal-keys. |
+| `NIM_NGC_API_KEY` | — | The key that NIM microservices inside docker containers use to access NGC resources. This is necessary only in some cases when it is different from `NGC_API_KEY`. If this is not specified, `NGC_API_KEY` is used to access NGC resources. |
+| `NVIDIA_BUILD_API_KEY` | — | The key to access NIMs that are hosted on build.nvidia.com instead of a self-hosted NIM. This is necessary only in some cases when it is different from `NGC_API_KEY`. If this is not specified, `NGC_API_KEY` is used for build.nvidia.com. |
+| `OTEL_EXPORTER_OTLP_ENDPOINT` | - `http://otel-collector:4317`<br/> | The endpoint for the OpenTelemetry exporter, used for sending telemetry data. |
+| `REDIS_MORPHEUS_TASK_QUEUE` | - `morpheus_task_queue`<br/> | The name of the task queue in Redis where tasks are stored and processed. |
diff --git a/docs/docs/user-guide/getting-started/quickstart-guide.md b/docs/docs/user-guide/getting-started/quickstart-guide.md
index 091cfe3b..d06684ad 100644
--- a/docs/docs/user-guide/getting-started/quickstart-guide.md
+++ b/docs/docs/user-guide/getting-started/quickstart-guide.md
@@ -33,7 +33,7 @@ Password:
 
 > When your early access is approved, follow the instructions in the email to create an organization and team, link your profile, and generate your NGC API key.
 
-4. Create a .env file containing your NGC API key and the following paths. For more information, refer to [](../developer-guide/environment-config.md).
+4. Create a .env file containing your NGC API key and the following paths. For more information, refer to [Environment Configuration Variables](../developer-guide/environment-config.md).
 
 ```
 # Container images must access resources from NGC.
@@ -112,7 +112,6 @@ To interact from the host, you'll need a Python environment and install the clie
 conda create --name nv-ingest-dev python=3.10
 conda activate nv-ingest-dev
 cd client
-pip install -r ./requirements.txt
 pip install .
 ```
diff --git a/src/nv_ingest/extraction_workflows/image/image_handlers.py b/src/nv_ingest/extraction_workflows/image/image_handlers.py
index e645b4f2..315ce2d0 100644
--- a/src/nv_ingest/extraction_workflows/image/image_handlers.py
+++ b/src/nv_ingest/extraction_workflows/image/image_handlers.py
@@ -27,7 +27,6 @@
 
 import numpy as np
 from PIL import Image
-from math import log
 from wand.image import Image as WandImage
 
 import nv_ingest.util.nim.yolox as yolox_utils
@@ -186,7 +185,7 @@ def extract_tables_and_charts_from_images(
     ----------
     images : List[np.ndarray]
         List of images in NumPy array format.
-    config : PDFiumConfigSchema
+    config : ImageConfigSchema
         Configuration object containing YOLOX endpoints, auth token, etc.
     trace_info : Optional[List], optional
         Optional tracing data for debugging/performance profiling.
@@ -194,8 +193,8 @@
     Returns
     -------
     List[Tuple[int, object]]
-        A list of (image_index, CroppedImageWithContent)
-        representing extracted table/chart data from each image.
+        A list of (image_index, CroppedImageWithContent) representing extracted
+        table/chart data from each image.
     """
     tables_and_charts = []
     yolox_client = None
@@ -209,41 +208,31 @@
             config.yolox_infer_protocol,
         )
 
-        max_batch_size = YOLOX_MAX_BATCH_SIZE
-        batches = []
-        i = 0
-        while i < len(images):
-            batch_size = min(2 ** int(log(len(images) - i, 2)), max_batch_size)
-            batches.append(images[i : i + batch_size])  # noqa: E203
-            i += batch_size
-
-        img_index = 0
-        for batch in batches:
-            data = {"images": batch}
-
-            # NimClient inference
-            inference_results = yolox_client.infer(
-                data,
-                model_name="yolox",
-                max_batch_size=YOLOX_MAX_BATCH_SIZE,
-                num_classes=YOLOX_NUM_CLASSES,
-                conf_thresh=YOLOX_CONF_THRESHOLD,
-                iou_thresh=YOLOX_IOU_THRESHOLD,
-                min_score=YOLOX_MIN_SCORE,
-                final_thresh=YOLOX_FINAL_SCORE,
-                trace_info=trace_info,  # traceable_func arg
-                stage_name="pdf_content_extractor",  # traceable_func arg
-            )
+        # Prepare the payload with all images.
+        data = {"images": images}
+
+        # Perform inference in a single call. The NimClient handles batching internally.
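+        # NimClient.infer caps each request at the model's max batch size and dispatches the
+        # batches through a thread pool (see the helpers.py changes later in this patch),
+        # which replaces the manual power-of-two batching removed above.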
+ inference_results = yolox_client.infer( + data, + model_name="yolox", + max_batch_size=YOLOX_MAX_BATCH_SIZE, + num_classes=YOLOX_NUM_CLASSES, + conf_thresh=YOLOX_CONF_THRESHOLD, + iou_thresh=YOLOX_IOU_THRESHOLD, + min_score=YOLOX_MIN_SCORE, + final_thresh=YOLOX_FINAL_SCORE, + trace_info=trace_info, + stage_name="pdf_content_extractor", + ) - # 5) Extract table/chart info from each image's annotations - for annotation_dict, original_image in zip(inference_results, batch): - extract_table_and_chart_images( - annotation_dict, - original_image, - img_index, - tables_and_charts, - ) - img_index += 1 + # Process each result along with its corresponding image. + for i, (annotation_dict, original_image) in enumerate(zip(inference_results, images)): + extract_table_and_chart_images( + annotation_dict, + original_image, + i, + tables_and_charts, + ) except TimeoutError: logger.error("Timeout error during table/chart extraction.") @@ -252,14 +241,13 @@ def extract_tables_and_charts_from_images( except Exception as e: logger.error(f"Unhandled error during table/chart extraction: {str(e)}") traceback.print_exc() - raise e + raise finally: if yolox_client: yolox_client.close() logger.debug(f"Extracted {len(tables_and_charts)} tables and charts from image.") - return tables_and_charts diff --git a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py index f7521ec7..b1608ddc 100644 --- a/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py +++ b/src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py @@ -19,7 +19,6 @@ import concurrent.futures import logging import traceback -from math import log from typing import List from typing import Optional from typing import Tuple @@ -61,55 +60,58 @@ def extract_tables_and_charts_using_image_ensemble( pages: List[Tuple[int, np.ndarray]], config: PDFiumConfigSchema, trace_info: Optional[List] = None, -) -> List[Tuple[int, object]]: # List[Tuple[int, CroppedImageWithContent]] +) -> List[Tuple[int, object]]: + """ + Given a list of (page_index, image) tuples, this function calls the YOLOX-based + inference service to extract table and chart annotations from all pages. + + Returns + ------- + List[Tuple[int, object]] + For each page, returns (page_index, joined_content) where joined_content + is the result of combining annotations from the inference. 
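+        (each entry wraps a detected table or chart region as a CroppedImageWithContent).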
+ """ tables_and_charts = [] + yolox_client = None try: model_interface = yolox_utils.YoloxPageElementsModelInterface() yolox_client = create_inference_client( - config.yolox_endpoints, model_interface, config.auth_token, config.yolox_infer_protocol + config.yolox_endpoints, + model_interface, + config.auth_token, + config.yolox_infer_protocol, ) - batches = [] - i = 0 - max_batch_size = YOLOX_MAX_BATCH_SIZE - while i < len(pages): - batch_size = min(2 ** int(log(len(pages) - i, 2)), max_batch_size) - batches.append(pages[i : i + batch_size]) # noqa: E203 - i += batch_size - - page_index = 0 - for batch in batches: - image_page_indices = [page[0] for page in batch] - original_images = [page[1] for page in batch] - - # Prepare data - data = {"images": original_images} - - # Perform inference using NimClient - inference_results = yolox_client.infer( - data, - model_name="yolox", - max_batch_size=YOLOX_MAX_BATCH_SIZE, - num_classes=YOLOX_NUM_CLASSES, - conf_thresh=YOLOX_CONF_THRESHOLD, - iou_thresh=YOLOX_IOU_THRESHOLD, - min_score=YOLOX_MIN_SCORE, - final_thresh=YOLOX_FINAL_SCORE, - trace_info=trace_info, # traceable_func arg - stage_name="pdf_content_extractor", # traceable_func arg - ) + # Collect all page indices and images in order. + image_page_indices = [page[0] for page in pages] + original_images = [page[1] for page in pages] + + # Prepare the data payload with all images. + data = {"images": original_images} + + # Perform inference using the NimClient. + inference_results = yolox_client.infer( + data, + model_name="yolox", + max_batch_size=YOLOX_MAX_BATCH_SIZE, + num_classes=YOLOX_NUM_CLASSES, + conf_thresh=YOLOX_CONF_THRESHOLD, + iou_thresh=YOLOX_IOU_THRESHOLD, + min_score=YOLOX_MIN_SCORE, + final_thresh=YOLOX_FINAL_SCORE, + trace_info=trace_info, + stage_name="pdf_content_extractor", + ) - # Process results - for annotation_dict, page_index, original_image in zip( - inference_results, image_page_indices, original_images - ): - extract_table_and_chart_images( - annotation_dict, - original_image, - page_index, - tables_and_charts, - ) + # Process results: iterate over each image's inference output. + for annotation_dict, page_index, original_image in zip(inference_results, image_page_indices, original_images): + extract_table_and_chart_images( + annotation_dict, + original_image, + page_index, + tables_and_charts, + ) except TimeoutError: logger.error("Timeout error during table/chart extraction.") @@ -118,14 +120,13 @@ def extract_tables_and_charts_using_image_ensemble( except Exception as e: logger.error(f"Unhandled error during table/chart extraction: {str(e)}") traceback.print_exc() - raise e + raise finally: if yolox_client: yolox_client.close() logger.debug(f"Extracted {len(tables_and_charts)} tables and charts.") - return tables_and_charts diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index e3c55907..f13d9321 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -11,6 +11,7 @@ from typing import Optional from typing import Tuple +import numpy as np import pandas as pd from morpheus.config import Config @@ -35,107 +36,79 @@ def _update_metadata( yolox_client: NimClient, paddle_client: NimClient, trace_info: Dict, - batch_size: int = 1, - worker_pool_size: int = 1, + worker_pool_size: int = 8, # Not currently used. ) -> List[Tuple[str, Dict]]: """ - Given a list of base64-encoded chart images, this function: - - Splits them into batches of size `batch_size`. 
-      - Calls Yolox and Paddle with *all images* in each batch in a single request if protocol != 'grpc'.
-        If protocol == 'grpc', calls Yolox and Paddle individually for each image in the batch.
-      - Joins the results for each image into a final combined inference result.
+    Given a list of base64-encoded chart images, this function calls both the Yolox and Paddle
+    inference services concurrently to extract chart data for all images.
 
-    Returns
-    -------
-    List[Tuple[str, Dict]]
-        For each base64-encoded image, returns (original_image_str, joined_chart_content_dict).
+    For each base64-encoded image, returns:
+        (original_image_str, joined_chart_content_dict)
     """
-    logger.debug(f"Running chart extraction: batch_size={batch_size}, worker_pool_size={worker_pool_size}")
+    logger.debug("Running chart extraction using updated concurrency handling.")
+
+    valid_images: List[str] = []
+    valid_arrays: List[np.ndarray] = []
+
+    # Pre-decode image dimensions and filter valid images.
+    for i, img in enumerate(base64_images):
+        array = base64_to_numpy(img)
+        height, width = array.shape[0], array.shape[1]
+        if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT:
+            valid_images.append(img)
+            valid_arrays.append(array)
+
+    # Prepare data payloads for both clients.
+    data_yolox = {"images": valid_arrays}
+    data_paddle = {"base64_images": valid_images}
+
+    _ = worker_pool_size
+    with ThreadPoolExecutor(max_workers=2) as executor:
+        future_yolox = executor.submit(
+            yolox_client.infer,
+            data=data_yolox,
+            model_name="yolox",
+            stage_name="chart_data_extraction",
+            max_batch_size=8,
+            trace_info=trace_info,
+        )
+        future_paddle = executor.submit(
+            paddle_client.infer,
+            data=data_paddle,
+            model_name="paddle",
+            stage_name="chart_data_extraction",
+            max_batch_size=1,
+            trace_info=trace_info,
+        )
+
+        try:
+            yolox_results = future_yolox.result()
+        except Exception as e:
+            logger.error(f"Error calling yolox_client.infer: {e}", exc_info=True)
+            raise
+
+        try:
+            paddle_results = future_paddle.result()
+        except Exception as e:
+            logger.error(f"Error calling paddle_client.infer: {e}", exc_info=True)
+            raise
+
+    # Ensure both clients returned lists of results matching the number of input images.
+    if not (isinstance(yolox_results, list) and isinstance(paddle_results, list)):
+        raise ValueError("Expected list results from both yolox_client and paddle_client infer calls.")
 
-    def chunk_list(lst, chunk_size):
-        for i in range(0, len(lst), chunk_size):
-            yield lst[i : i + chunk_size]
+    if len(yolox_results) != len(base64_images):
+        raise ValueError(f"Expected {len(base64_images)} yolox results, got {len(yolox_results)}")
+    if len(paddle_results) != len(base64_images):
+        raise ValueError(f"Expected {len(base64_images)} paddle results, got {len(paddle_results)}")
 
+    # Join the corresponding results from both services for each image.
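+    # Each entry in paddle_results is a (bounding_boxes, text_predictions) pair; below it is
+    # joined with the matching yolox annotations to produce the final chart content string.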
results = [] - image_arrays = [base64_to_numpy(img) for img in base64_images] - - with ThreadPoolExecutor(max_workers=worker_pool_size) as executor: - for batch, arrays in zip(chunk_list(base64_images, batch_size), chunk_list(image_arrays, batch_size)): - # 1) Yolox calls - # Single request for the entire batch - data = {"images": arrays} - yolox_futures = executor.submit( - yolox_client.infer, - data=data, - model_name="yolox", - stage_name="chart_data_extraction", - trace_info=trace_info, - ) - - # 2) Paddle calls - paddle_futures = [] - if paddle_client.protocol == "grpc": - # Submit each image in the batch separately - paddle_futures = [] - for image_str, image_arr in zip(batch, arrays): - width, height = image_arr.shape[:2] - if width < PADDLE_MIN_WIDTH or height < PADDLE_MIN_HEIGHT: - # Too small, skip inference - continue - - data = {"base64_images": [image_str]} - fut = executor.submit( - paddle_client.infer, - data=data, - model_name="paddle", - stage_name="chart_data_extraction", - max_batch_size=1, - trace_info=trace_info, - ) - paddle_futures.append(fut) - else: - # Single request for the entire batch - data = {"base64_images": batch} - paddle_futures = executor.submit( - paddle_client.infer, - data=data, - model_name="paddle", - stage_name="chart_data_extraction", - max_batch_size=batch_size, - trace_info=trace_info, - ) - - try: - # Retrieve results from Yolox - yolox_results = yolox_futures.result() - - # 3) Retrieve results from Yolox - if paddle_client.protocol == "grpc": - # Each future should return a single-element list - # We take the 0th item to align with single-image results - paddle_results = [] - for fut in paddle_futures: - res = fut.result() - if isinstance(res, list) and len(res) == 1: - paddle_results.append(res[0]) - else: - # Fallback in case the service returns something unexpected - logger.warning(f"Unexpected PaddleOCR result format: {res}") - paddle_results.append(res) - else: - # Single call returning a list of the same length as 'batch' - paddle_results = paddle_futures.result() - - # 4) Zip them together, one by one - for img_str, yolox_res, paddle_res in zip(batch, yolox_results, paddle_results): - bounding_boxes, text_predictions = paddle_res - yolox_elements = join_yolox_and_paddle_output(yolox_res, bounding_boxes, text_predictions) - chart_content = process_yolox_graphic_elements(yolox_elements) - results.append((img_str, chart_content)) - - except Exception as e: - logger.error(f"Error processing batch: {batch}, error: {e}", exc_info=True) - raise + for img_str, yolox_res, paddle_res in zip(base64_images, yolox_results, paddle_results): + bounding_boxes, text_predictions = paddle_res + yolox_elements = join_yolox_and_paddle_output(yolox_res, bounding_boxes, text_predictions) + chart_content = process_yolox_graphic_elements(yolox_elements) + results.append((img_str, chart_content)) return results @@ -253,7 +226,6 @@ def meets_criteria(row): base64_images=base64_images, yolox_client=yolox_client, paddle_client=paddle_client, - batch_size=stage_config.nim_batch_size, worker_pool_size=stage_config.workers_per_progress_engine, trace_info=trace_info, ) diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index ec535f77..4836e107 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -4,7 +4,6 @@ import functools import logging -from concurrent.futures import ThreadPoolExecutor from typing import Any from typing import Dict from typing import 
List
@@ -21,6 +20,7 @@
 from nv_ingest.util.image_processing.transforms import base64_to_numpy
 from nv_ingest.util.nim.helpers import NimClient
 from nv_ingest.util.nim.helpers import create_inference_client
+from nv_ingest.util.nim.helpers import get_version
 from nv_ingest.util.nim.paddle import PaddleOCRModelInterface
 
 logger = logging.getLogger(__name__)
@@ -32,136 +32,67 @@
 
 def _update_metadata(
     base64_images: List[str],
     paddle_client: NimClient,
-    batch_size: int = 1,
-    worker_pool_size: int = 1,
+    worker_pool_size: int = 8,  # Not currently used
     trace_info: Dict = None,
 ) -> List[Tuple[str, Tuple[Any, Any]]]:
     """
-    Given a list of base64-encoded images, this function processes them either individually
-    (if paddle_client.protocol == 'grpc') or in batches (if paddle_client.protocol == 'http'),
-    then calls the PaddleOCR model to extract data.
+    Given a list of base64-encoded images, this function filters out images that do not meet the minimum
+    size requirements and then calls the PaddleOCR model via paddle_client.infer to extract table data.
 
     For each base64-encoded image, the result is:
         (base64_image, (text_predictions, bounding_boxes))
 
-    Images that do not meet the minimum size are skipped (("", "")).
+    Images that do not meet the minimum size are skipped (resulting in (None, None) for that image).
+    The paddle_client is expected to handle any necessary batching and concurrency.
     """
-    logger.debug(
-        f"Running table extraction: batch_size={batch_size}, "
-        f"worker_pool_size={worker_pool_size}, protocol={paddle_client.protocol}"
-    )
+    logger.debug(f"Running table extraction using protocol {paddle_client.protocol}")
 
-    # We'll build the final results in the same order as base64_images.
+    # Initialize the results list in the same order as base64_images.
     results: List[Optional[Tuple[str, Tuple[Any, Any]]]] = [None] * len(base64_images)
 
-    # Pre-decode dimensions once (optional, but efficient if we want to skip small images).
-    decoded_shapes = []
-    for img in base64_images:
+    valid_images: List[str] = []
+    valid_indices: List[int] = []
+
+    _ = worker_pool_size
+    # Pre-decode image dimensions and filter valid images.
+    for i, img in enumerate(base64_images):
         array = base64_to_numpy(img)
-        decoded_shapes.append(array.shape)  # e.g. (height, width, channels)
-
-    # ------------------------------------------------
-    # GRPC path: submit one request per valid image.
- # ------------------------------------------------ - if paddle_client.protocol == "grpc": - with ThreadPoolExecutor(max_workers=worker_pool_size) as executor: - future_to_index = {} - - # Submit individual requests - for i, b64_image in enumerate(base64_images): - height, width = decoded_shapes[i][0], decoded_shapes[i][1] - if width < PADDLE_MIN_WIDTH or height < PADDLE_MIN_HEIGHT: - # Too small, skip inference - results[i] = (b64_image, (None, None)) - continue - - # Enqueue a single-image inference - data = {"base64_images": [b64_image]} # single item - future = executor.submit( - paddle_client.infer, - data=data, - model_name="paddle", - stage_name="table_data_extraction", - max_batch_size=1, - trace_info=trace_info, - ) - future_to_index[future] = i - - # Gather results - for future, i in future_to_index.items(): - b64_image = base64_images[i] - try: - paddle_result = future.result() - # We expect exactly one result for one image - if not isinstance(paddle_result, list) or len(paddle_result) != 1: - raise ValueError(f"Expected 1 result list, got: {paddle_result}") - bounding_boxes, text_predictions = paddle_result[0] - results[i] = (b64_image, (bounding_boxes, text_predictions)) - except Exception as e: - logger.error(f"Error processing image {i}. Error: {e}", exc_info=True) - results[i] = (b64_image, (None, None)) - raise - - # ------------------------------------------------ - # HTTP path: submit requests in batches. - # ------------------------------------------------ - else: - with ThreadPoolExecutor(max_workers=worker_pool_size) as executor: - # Process images in chunks - for start_idx in range(0, len(base64_images), batch_size): - chunk_indices = range(start_idx, min(start_idx + batch_size, len(base64_images))) - valid_indices = [] - valid_images = [] - - # Check dimensions & collect valid images - for i in chunk_indices: - height, width = decoded_shapes[i][0], decoded_shapes[i][1] - if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT: - valid_indices.append(i) - valid_images.append(base64_images[i]) - else: - # Too small, skip inference - results[i] = (base64_images[i], (None, None)) - - if not valid_images: - # All images in this chunk were too small - continue - - # Submit a single batch inference - data = {"base64_images": valid_images} - future = executor.submit( - paddle_client.infer, - data=data, - model_name="paddle", - stage_name="table_data_extraction", - max_batch_size=batch_size, - trace_info=trace_info, - ) - - try: - # This should be a list of (text_predictions, bounding_boxes) - # in the same order as valid_images - paddle_result = future.result() - - if not isinstance(paddle_result, list): - raise ValueError(f"Expected a list of tuples, got {type(paddle_result)}") - - if len(paddle_result) != len(valid_images): - raise ValueError(f"Expected {len(valid_images)} results, got {len(paddle_result)}") - - # Match each result back to its original index - for idx_in_batch, (tc, tf) in enumerate(paddle_result): - i = valid_indices[idx_in_batch] - results[i] = (base64_images[i], (tc, tf)) - - except Exception as e: - logger.error(f"Error processing batch {valid_images}. 
Error: {e}", exc_info=True) - # If inference fails, we can fill them with empty or re-raise - for vi in valid_indices: - results[vi] = (base64_images[vi], (None, None)) - raise - - # 'results' now has an entry for every image in base64_images + height, width = array.shape[0], array.shape[1] + if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT: + valid_images.append(img) + valid_indices.append(i) + else: + # Image is too small; mark as skipped. + results[i] = (img, (None, None)) + + if valid_images: + data = {"base64_images": valid_images} + try: + # Call infer once for all valid images. The NimClient will handle batching internally. + paddle_result = paddle_client.infer( + data=data, + model_name="paddle", + stage_name="table_data_extraction", + max_batch_size=1, + trace_info=trace_info, + ) + + if not isinstance(paddle_result, list): + raise ValueError(f"Expected a list of tuples, got {type(paddle_result)}") + if len(paddle_result) != len(valid_images): + raise ValueError(f"Expected {len(valid_images)} results, got {len(paddle_result)}") + + # Assign each result back to its original position. + for idx, result in enumerate(paddle_result): + original_index = valid_indices[idx] + results[original_index] = (base64_images[original_index], result) + + except Exception as e: + logger.error(f"Error processing images. Error: {e}", exc_info=True) + for i in valid_indices: + results[i] = (base64_images[i], (None, None)) + raise + return results @@ -170,6 +101,16 @@ def _create_paddle_client(stage_config) -> NimClient: Helper to create a NimClient for PaddleOCR, retrieving the paddle version from the endpoint. """ # Attempt to obtain PaddleOCR version from the second endpoint + paddle_endpoint = stage_config.paddle_endpoints[1] + try: + paddle_version = get_version(paddle_endpoint) + if not paddle_version: + logger.warning("Failed to obtain PaddleOCR version from the endpoint. Falling back to the latest version.") + paddle_version = None + except Exception: + logger.warning("Failed to get PaddleOCR version after 30 seconds. Falling back to the latest version.") + paddle_version = None + paddle_model_interface = PaddleOCRModelInterface() paddle_client = create_inference_client( @@ -251,7 +192,6 @@ def meets_criteria(row): bulk_results = _update_metadata( base64_images=base64_images, paddle_client=paddle_client, - batch_size=stage_config.nim_batch_size, worker_pool_size=stage_config.workers_per_progress_engine, trace_info=trace_info, ) diff --git a/src/nv_ingest/util/nim/helpers.py b/src/nv_ingest/util/nim/helpers.py index 10c80938..9448ba5b 100644 --- a/src/nv_ingest/util/nim/helpers.py +++ b/src/nv_ingest/util/nim/helpers.py @@ -6,6 +6,8 @@ import re import threading import time +from concurrent.futures import ThreadPoolExecutor +from functools import partial from typing import Any from typing import Optional from typing import Tuple @@ -111,7 +113,7 @@ def __init__( protocol: str, endpoints: Tuple[str, str], auth_token: Optional[str] = None, - timeout: float = 30.0, + timeout: float = 120.0, max_retries: int = 5, ): """ @@ -187,6 +189,39 @@ def _fetch_max_batch_size(self, model_name, model_version: str = "") -> int: return self._max_batch_sizes[model_name] + def _process_batch(self, batch_input, *, prepared_data, model_name, **kwargs): + """ + Process a single batch input for inference. + + Parameters + ---------- + batch_input : Any + The batch input data to process. + prepared_data : Any + The prepared data used for inference. + model_name : str + The model name to use for inference. 
+ kwargs : dict + Additional parameters for inference. + + Returns + ------- + Any + The parsed output from the inference request. + """ + if self.protocol == "grpc": + logger.debug("Performing gRPC inference for a batch...") + response = self._grpc_infer(batch_input, model_name) + logger.debug("gRPC inference received response for a batch") + elif self.protocol == "http": + logger.debug("Performing HTTP inference for a batch...") + response = self._http_infer(batch_input) + logger.debug("HTTP inference received response for a batch") + else: + raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + + return self.model_interface.parse_output(response, protocol=self.protocol, data=prepared_data, **kwargs) + def try_set_max_batch_size(self, model_name, model_version: str = ""): """Attempt to set the max batch size for the model if it is not already set, ensuring thread safety.""" self._fetch_max_batch_size(model_name, model_version) @@ -203,7 +238,8 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: model_name : str The name of the model to use for inference. kwargs : dict - Additional parameters for inference. + Additional parameters for inference. Optionally supports "max_pool_workers" to set + the number of worker threads in the thread pool. Returns ------- @@ -215,54 +251,38 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: ValueError If an invalid protocol is specified. """ - try: - # 1. Retrieve or default to the model's maximum batch size + # 1. Retrieve or default to the model's maximum batch size. batch_size = self._fetch_max_batch_size(model_name) max_requested_batch_size = kwargs.get("max_batch_size", batch_size) force_requested_batch_size = kwargs.get("force_max_batch_size", False) - # 1a. In some cases we can't use the absolute max batch size (or don't want to) so we allow override - # 1b. In some cases we can't reliably retrieve the max batch size so we default to 1 and allow forced - # override if not force_requested_batch_size: max_batch_size = min(batch_size, max_requested_batch_size) else: max_batch_size = max_requested_batch_size - # 2. Prepare data for inference + # 2. Prepare data for inference. prepared_data = self.model_interface.prepare_data_for_inference(data) - # 3. Format the input based on protocol + # 3. Format the input based on protocol. # NOTE: This now returns a list of batches. formatted_batches = self.model_interface.format_input( prepared_data, protocol=self.protocol, max_batch_size=max_batch_size ) - # Container for all parsed outputs - all_parsed_outputs = [] - - # 4. Loop over each batch - for batch_input in formatted_batches: - if self.protocol == "grpc": - logger.debug("Performing gRPC inference for a batch...") - response = self._grpc_infer(batch_input, model_name) - logger.debug("gRPC inference received response for a batch") - elif self.protocol == "http": - logger.debug("Performing HTTP inference for a batch...") - response = self._http_infer(batch_input) - logger.debug("HTTP inference received response for a batch") - else: - raise ValueError("Invalid protocol specified. Must be 'grpc' or 'http'.") + # Check for a custom maximum pool worker count, and remove it from kwargs. + max_pool_workers = kwargs.pop("max_pool_workers", 16) - # Parse the output of this batch - parsed_output = self.model_interface.parse_output( - response, protocol=self.protocol, data=prepared_data, **kwargs - ) - # Accumulate parsed outputs - all_parsed_outputs.append(parsed_output) + # 4. 
Process each batch concurrently using a thread pool. + process_batch_partial = partial( + self._process_batch, prepared_data=prepared_data, model_name=model_name, **kwargs + ) + + with ThreadPoolExecutor(max_workers=max_pool_workers) as executor: + all_parsed_outputs = list(executor.map(process_batch_partial, formatted_batches)) - # 5. Process the parsed outputs for each batch + # 5. Process the parsed outputs for each batch. all_results = [] for parsed_output in all_parsed_outputs: batch_results = self.model_interface.process_inference_results( @@ -271,8 +291,6 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: protocol=self.protocol, **kwargs, ) - # Extend or append based on how `batch_results` is structured - # (assuming it's a list of result items): if isinstance(batch_results, list): all_results.extend(batch_results) else: @@ -283,7 +301,6 @@ def infer(self, data: dict, model_name: str, **kwargs) -> Any: logger.error(error_str) raise RuntimeError(error_str) - # 6. Return final accumulated results return all_results def _grpc_infer(self, formatted_input: np.ndarray, model_name: str) -> np.ndarray: diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py index ad51df16..e411a8ae 100644 --- a/src/nv_ingest/util/nim/yolox.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -184,7 +184,7 @@ def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, ) if new_size != original_size: - logger.warning(f"Image was scaled from {original_size} to {new_size}.") + logger.debug(f"Image was scaled from {original_size} to {new_size}.") # Add to content_list content_list.append({"type": "image_url", "url": f"data:image/png;base64,{scaled_image_b64}"}) diff --git a/tests/nv_ingest/stages/nims/test_chart_extraction.py b/tests/nv_ingest/stages/nims/test_chart_extraction.py index 92627c2a..14672d2c 100644 --- a/tests/nv_ingest/stages/nims/test_chart_extraction.py +++ b/tests/nv_ingest/stages/nims/test_chart_extraction.py @@ -24,7 +24,6 @@ def valid_chart_extractor_config(): yolox_infer_protocol="grpc", paddle_endpoints=("paddle_grpc_url", "paddle_http_url"), paddle_infer_protocol="grpc", - nim_batch_size=2, workers_per_progress_engine=5, ) @@ -50,31 +49,50 @@ class FakeValidated: def test_update_metadata_empty_list(): """ - If the base64_images list is empty, _update_metadata should - skip all logic and return an empty list. + If the base64_images list is empty, _update_metadata should return an empty list. + With the updated implementation, both clients are still invoked (with an empty list) + so we set their return values to [] and then verify the calls. """ yolox_mock = MagicMock() paddle_mock = MagicMock() trace_info = {} + # When given an empty list, both clients return an empty list. + yolox_mock.infer.return_value = [] + paddle_mock.infer.return_value = [] + result = _update_metadata( base64_images=[], yolox_client=yolox_mock, paddle_client=paddle_mock, trace_info=trace_info, - batch_size=1, worker_pool_size=1, ) assert result == [] - yolox_mock.infer.assert_not_called() + + # Each client's infer should be called once with an empty list. 
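
+    # (The max_batch_size values asserted below mirror what _update_metadata
+    # requests per client: 8 for the yolox client, and 2 for the paddle client
+    # whenever its protocol is not gRPC.)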

+    yolox_mock.infer.assert_called_once_with(
+        data={"images": []},
+        model_name="yolox",
+        stage_name="chart_data_extraction",
+        max_batch_size=8,
+        trace_info=trace_info,
+    )
+    paddle_mock.infer.assert_called_once_with(
+        data={"base64_images": []},
+        model_name="paddle",
+        stage_name="chart_data_extraction",
+        max_batch_size=2,
+        trace_info=trace_info,
+    )
 
 
 def test_update_metadata_single_batch_single_worker(mocker, base64_image):
     """
-    Test a simple scenario with a small list of base64_images, batch_size=2,
-    worker_pool_size=1. We verify that yolox is called once per batch,
-    and paddle is called once per image in that batch.
+    Test a simple scenario with a small list of base64_images using worker_pool_size=1.
+    In the updated _update_metadata implementation, both the yolox and paddle clients are
+    called once with the full list of images. The join function is applied per image.
     """
     # Mock out the clients
     yolox_mock = MagicMock()
@@ -99,12 +117,12 @@ def test_update_metadata_single_batch_single_worker(mocker, base64_image):
 
     result = _update_metadata(base64_images, yolox_mock, paddle_mock, trace_info, batch_size=2, worker_pool_size=1)
 
-    # We expect result => [("img1", "joined_1"), ("img2", "joined_2")]
+    # Expect the result to combine each original image with its corresponding joined output.
     assert len(result) == 2
     assert result[0] == (base64_image, "joined_1")
     assert result[1] == (base64_image, "joined_2")
 
-    # Check calls
+    # yolox.infer should be called once with the full list of arrays.
     assert yolox_mock.infer.call_count == 1
     assert np.all(yolox_mock.infer.call_args.kwargs["data"]["images"][0] == base64_to_numpy(base64_image))
     assert np.all(yolox_mock.infer.call_args.kwargs["data"]["images"][1] == base64_to_numpy(base64_image))
@@ -112,9 +130,8 @@ def test_update_metadata_single_batch_single_worker(mocker, base64_image):
     assert yolox_mock.infer.call_args.kwargs["stage_name"] == "chart_data_extraction"
     assert yolox_mock.infer.call_args.kwargs["trace_info"] == trace_info
 
-    # paddle.infer called once per image
-    assert paddle_mock.infer.call_count == 1
-    paddle_mock.infer.assert_any_call(
+    # paddle.infer should be called once with the full list of images.
+    paddle_mock.infer.assert_called_once_with(
         data={"base64_images": [base64_image, base64_image]},
         model_name="paddle",
         stage_name="chart_data_extraction",
@@ -122,14 +139,15 @@ def test_update_metadata_single_batch_single_worker(mocker, base64_image):
         trace_info=trace_info,
     )
 
-    # join_yolox_and_paddle_output called twice
+    # The join function should be invoked once per image.
     assert mock_join.call_count == 2
 
 
 def test_update_metadata_multiple_batches_multi_worker(mocker, base64_image):
     """
-    If batch_size=1 but we have multiple images, each image forms its own batch.
-    We also can use worker_pool_size=2 for parallel calls.
+    With the new _update_metadata implementation, both yolox_client.infer and paddle_client.infer
+    are called once with the full list of images. Their results are expected to be lists with one
+    item per image. The join function is still invoked for each image.
     """
     yolox_mock = MagicMock()
     paddle_mock = MagicMock()
@@ -160,7 +178,6 @@ def paddle_side_effect(**kwargs):
         yolox_mock,
         paddle_mock,
         trace_info,
-        batch_size=1,  # each image in its own batch
         worker_pool_size=2,
     )
 
@@ -177,8 +194,7 @@ def paddle_side_effect(**kwargs):
 
 def test_update_metadata_exception_in_yolox_call(base64_image, caplog):
     """
-    If the yolox call fails for a batch, we expect an exception to bubble up
-    and the error logged.
+    If the yolox call fails, we expect an exception to bubble up and the error to be logged.
     """
     yolox_mock = MagicMock()
     paddle_mock = MagicMock()
@@ -187,24 +203,24 @@ def test_update_metadata_exception_in_yolox_call(base64_image, caplog):
     with pytest.raises(Exception, match="Yolox call error"):
         _update_metadata([base64_image], yolox_mock, paddle_mock, trace_info={}, batch_size=1, worker_pool_size=1)
 
-    # Check log
-    assert f"Error processing batch: ['{base64_image}']" in caplog.text
+    # Verify that the error message from the yolox client is logged.
+    assert "Error calling yolox_client.infer: Yolox call error" in caplog.text
 
 
 def test_update_metadata_exception_in_paddle_call(base64_image, caplog):
     """
-    If any paddle call fails for one of the images, the entire process fails
-    and logs the error.
+    If the paddle call fails, we expect an exception to bubble up and the error to be logged.
     """
     yolox_mock = MagicMock()
-    yolox_mock.infer.return_value = ["yolox_result"]  # 1-element list
+    yolox_mock.infer.return_value = ["yolox_result"]  # Single-element list for one image
     paddle_mock = MagicMock()
     paddle_mock.infer.side_effect = Exception("Paddle error")
 
     with pytest.raises(Exception, match="Paddle error"):
         _update_metadata([base64_image], yolox_mock, paddle_mock, trace_info={}, batch_size=1, worker_pool_size=2)
 
-    assert f"Error processing batch: ['{base64_image}']" in caplog.text
+    # Verify that the error message from the paddle client is logged.
+    assert "Error calling paddle_client.infer: Paddle error" in caplog.text
 
 
 def test_create_clients(mocker):
@@ -339,7 +355,6 @@ def test_extract_chart_data_all_valid(validated_config, mocker):
         base64_images=["imgA", "imgB"],
         yolox_client=yolox_mock,
         paddle_client=paddle_mock,
-        batch_size=validated_config.stage_config.nim_batch_size,
         worker_pool_size=validated_config.stage_config.workers_per_progress_engine,
         trace_info=ti.get("trace_info"),
     )
@@ -398,7 +413,6 @@ def test_extract_chart_data_mixed_rows(validated_config, mocker):
         base64_images=["base64img1", "base64img2"],
         yolox_client=yolox_mock,
         paddle_client=paddle_mock,
-        batch_size=validated_config.stage_config.nim_batch_size,
         worker_pool_size=validated_config.stage_config.workers_per_progress_engine,
         trace_info=trace.get("trace_info"),
     )
diff --git a/tests/nv_ingest/stages/nims/test_table_extraction.py b/tests/nv_ingest/stages/nims/test_table_extraction.py
index 8f1a1356..e3da6c90 100644
--- a/tests/nv_ingest/stages/nims/test_table_extraction.py
+++ b/tests/nv_ingest/stages/nims/test_table_extraction.py
@@ -36,7 +36,6 @@ def validated_config():
 
 class FakeStageConfig:
     # Values that _extract_table_data expects
-    nim_batch_size = 4
     workers_per_progress_engine = 5
     auth_token = "fake-token"  # For _create_paddle_client
@@ -143,7 +142,6 @@ def test_extract_table_data_all_valid(mocker, validated_config):
     mock_update_metadata.assert_called_once_with(
         base64_images=["imgA", "imgB"],
         paddle_client=mock_client,
-        batch_size=validated_config.stage_config.nim_batch_size,
         worker_pool_size=validated_config.stage_config.workers_per_progress_engine,
         trace_info=trace_info.get("trace_info"),
     )
@@ -202,7 +200,6 @@ def test_extract_table_data_mixed_rows(mocker, validated_config):
     mock_update_metadata.assert_called_once_with(
         base64_images=["good1", "good2"],
         paddle_client=mock_client,
-        batch_size=validated_config.stage_config.nim_batch_size,
         worker_pool_size=validated_config.stage_config.workers_per_progress_engine,
         trace_info=trace_info.get("trace_info"),
     )
@@ -258,32 +255,30 @@ def 
test_update_metadata_empty_list(paddle_mock): def test_update_metadata_all_valid(mocker, paddle_mock): - """ - If all images meet the minimum size, we pass them all to paddle.infer - in a single batch (batch_size=2), then use the returned results. - """ imgs = ["b64imgA", "b64imgB"] + # Patch base64_to_numpy so that both images are valid. mock_dim = mocker.patch(f"{MODULE_UNDER_TEST}.base64_to_numpy") - # Return actual NumPy arrays mock_dim.side_effect = [ - np.zeros((100, 120, 3), dtype=np.uint8), - np.zeros((80, 80, 3), dtype=np.uint8), + np.zeros((100, 120, 3), dtype=np.uint8), # b64imgA is valid + np.zeros((80, 80, 3), dtype=np.uint8), # b64imgB is valid ] + # Set minimum dimensions so that both images pass. mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_WIDTH", 50) mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_HEIGHT", 50) - # Suppose inference returns a list of (table_content, table_content_format) + # The paddle client returns a result for each valid image. paddle_mock.infer.return_value = [ ("tableA", "fmtA"), ("tableB", "fmtB"), ] - res = _update_metadata(imgs, paddle_mock, batch_size=2, worker_pool_size=1) + res = _update_metadata(imgs, paddle_mock, worker_pool_size=1) assert len(res) == 2 assert res[0] == ("b64imgA", ("tableA", "fmtA")) assert res[1] == ("b64imgB", ("tableB", "fmtB")) + # Expect one call to infer with all valid images. paddle_mock.infer.assert_called_once_with( data={"base64_images": ["b64imgA", "b64imgB"]}, model_name="paddle", @@ -300,7 +295,7 @@ def test_update_metadata_skip_small(mocker, paddle_mock): """ imgs = ["imgSmall", "imgBig"] mock_dim = mocker.patch(f"{MODULE_UNDER_TEST}.base64_to_numpy") - # Return NumPy arrays of certain shape to emulate dimension checks + # Return NumPy arrays of certain shape to emulate dimension checks. mock_dim.side_effect = [ np.zeros((40, 40, 3), dtype=np.uint8), # too small np.zeros((60, 70, 3), dtype=np.uint8), # big enough @@ -310,10 +305,11 @@ def test_update_metadata_skip_small(mocker, paddle_mock): paddle_mock.infer.return_value = [("valid_table", "valid_fmt")] - res = _update_metadata(imgs, paddle_mock, batch_size=2) + res = _update_metadata(imgs, paddle_mock) assert len(res) == 2 - # First was too small => ("", "") + # The first image is too small and is skipped. assert res[0] == ("imgSmall", (None, None)) + # The second image is valid and processed. assert res[1] == ("imgBig", ("valid_table", "valid_fmt")) paddle_mock.infer.assert_called_once_with( @@ -326,36 +322,39 @@ def test_update_metadata_skip_small(mocker, paddle_mock): def test_update_metadata_multiple_batches(mocker, paddle_mock): - """ - If batch_size=1 but we have 3 images => we chunk them => 3 calls to paddle.infer, - ignoring any that are too small. - """ imgs = ["img1", "img2", "img3"] + # Patch base64_to_numpy so that all images are valid. mock_dim = mocker.patch(f"{MODULE_UNDER_TEST}.base64_to_numpy") - # All valid => each call returns a 3D array big enough mock_dim.side_effect = [ - np.zeros((80, 80, 3), dtype=np.uint8), - np.zeros((100, 50, 3), dtype=np.uint8), - np.zeros((64, 64, 3), dtype=np.uint8), + np.zeros((80, 80, 3), dtype=np.uint8), # img1 + np.zeros((100, 50, 3), dtype=np.uint8), # img2 + np.zeros((64, 64, 3), dtype=np.uint8), # img3 ] + # Set minimum dimensions such that all images are considered valid. 
mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_WIDTH", 40) mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_HEIGHT", 40) - # We'll side_effect 3 calls => each returns a single pair - paddle_mock.infer.side_effect = [ - [("table1", "fmt1")], - [("table2", "fmt2")], - [("table3", "fmt3")], + # Since all images are valid, infer is called once with the full list. + paddle_mock.infer.return_value = [ + ("table1", "fmt1"), + ("table2", "fmt2"), + ("table3", "fmt3"), ] - res = _update_metadata(imgs, paddle_mock, batch_size=1, worker_pool_size=2) + res = _update_metadata(imgs, paddle_mock, worker_pool_size=2) assert len(res) == 3 assert res[0] == ("img1", ("table1", "fmt1")) assert res[1] == ("img2", ("table2", "fmt2")) assert res[2] == ("img3", ("table3", "fmt3")) - # 3 calls to infer, each with a single base64_images list - assert paddle_mock.infer.call_count == 3 + # Verify that infer is called only once with all valid images. + paddle_mock.infer.assert_called_once_with( + data={"base64_images": ["img1", "img2", "img3"]}, + model_name="paddle", + stage_name="table_data_extraction", + max_batch_size=2, + trace_info=None, + ) def test_update_metadata_inference_error(mocker, paddle_mock): @@ -372,7 +371,7 @@ def test_update_metadata_inference_error(mocker, paddle_mock): paddle_mock.infer.side_effect = RuntimeError("paddle error") with pytest.raises(RuntimeError, match="paddle error"): - _update_metadata(imgs, paddle_mock, batch_size=2) + _update_metadata(imgs, paddle_mock) # The code sets them to ("", "") before re-raising # We can’t see final 'res', but that’s the logic. @@ -391,7 +390,7 @@ def test_update_metadata_mismatch_length(mocker, paddle_mock): paddle_mock.infer.return_value = [("tableOnly", "fmtOnly")] with pytest.raises(ValueError, match="Expected 2 results"): - _update_metadata(imgs, paddle_mock, batch_size=2) + _update_metadata(imgs, paddle_mock) def test_update_metadata_non_list_return(mocker, paddle_mock): @@ -419,7 +418,7 @@ def test_update_metadata_all_small(mocker, paddle_mock): mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_WIDTH", 30) mocker.patch(f"{MODULE_UNDER_TEST}.PADDLE_MIN_HEIGHT", 30) - res = _update_metadata(imgs, paddle_mock, batch_size=2) + res = _update_metadata(imgs, paddle_mock) assert res[0] == ("imgA", (None, None)) assert res[1] == ("imgB", (None, None)) diff --git a/tests/nv_ingest/util/flow_control/test_filter_by_task.py b/tests/nv_ingest/util/flow_control/test_filter_by_task.py index 595beaf0..3edd9a5e 100644 --- a/tests/nv_ingest/util/flow_control/test_filter_by_task.py +++ b/tests/nv_ingest/util/flow_control/test_filter_by_task.py @@ -6,13 +6,11 @@ import pytest -from nv_ingest.util.flow_control.filter_by_task import filter_by_task -from nv_ingest.util.flow_control.filter_by_task import remove_task_subset - -from ....import_checks import CUDA_DRIVER_OK from ....import_checks import MORPHEUS_IMPORT_OK -if CUDA_DRIVER_OK and MORPHEUS_IMPORT_OK: +if MORPHEUS_IMPORT_OK: + from nv_ingest.util.flow_control.filter_by_task import remove_task_subset + from nv_ingest.util.flow_control.filter_by_task import filter_by_task from morpheus.messages import ControlMessage @@ -40,6 +38,7 @@ def process_message(message): return message +@pytest.mark.skipif(not MORPHEUS_IMPORT_OK, reason="Morpheus modules are not available.") def test_filter_by_task_with_required_task(mock_control_message): decorated_func = filter_by_task(["task1"])(process_message) assert ( @@ -47,6 +46,7 @@ def test_filter_by_task_with_required_task(mock_control_message): ), "Should process the message when 
required task is present." +@pytest.mark.skipif(not MORPHEUS_IMPORT_OK, reason="Morpheus modules are not available.") def test_filter_by_task_with_required_task_properties(mock_control_message): decorated_func = filter_by_task([("task1", {"prop1": "foo"})])(process_message) assert ( @@ -54,6 +54,7 @@ def test_filter_by_task_with_required_task_properties(mock_control_message): ), "Should process the message when both required task and required property are present." +@pytest.mark.skipif(not MORPHEUS_IMPORT_OK, reason="Morpheus modules are not available.") def test_filter_by_task_without_required_task_no_forward_func(mock_control_message): decorated_func = filter_by_task(["task3"])(process_message) assert ( @@ -61,6 +62,7 @@ def test_filter_by_task_without_required_task_no_forward_func(mock_control_messa ), "Should return the original message when required task is not present and no forward_func is provided." +@pytest.mark.skipif(not MORPHEUS_IMPORT_OK, reason="Morpheus modules are not available.") def test_filter_by_task_without_required_task_properteies_no_forward_func(mock_control_message): decorated_func = filter_by_task([("task1", {"prop1": "bar"})])(process_message) assert ( @@ -68,6 +70,7 @@ def test_filter_by_task_without_required_task_properteies_no_forward_func(mock_c ), "Should return the original message when required task is present but required task property is not present." +@pytest.mark.skipif(not MORPHEUS_IMPORT_OK, reason="Morpheus modules are not available.") def test_filter_by_task_without_required_task_with_forward_func(mock_control_message): # Create a simple mock function to be decorated mock_function = Mock(return_value="some_value") @@ -88,6 +91,7 @@ def test_filter_by_task_without_required_task_with_forward_func(mock_control_mes assert result == mock_control_message, "Should return the mock_control_message from the forward function." +@pytest.mark.skipif(not MORPHEUS_IMPORT_OK, reason="Morpheus modules are not available.") def test_filter_by_task_without_required_task_properties_with_forward_func(mock_control_message): # Create a simple mock function to be decorated mock_function = Mock(return_value="some_value") @@ -108,6 +112,7 @@ def test_filter_by_task_without_required_task_properties_with_forward_func(mock_ assert result == mock_control_message, "Should return the mock_control_message from the forward function." 
+@pytest.mark.skipif(not MORPHEUS_IMPORT_OK, reason="Morpheus modules are not available.") def test_filter_by_task_with_invalid_argument(): decorated_func = filter_by_task(["task1"])(process_message) with pytest.raises(ValueError): @@ -125,10 +130,6 @@ def create_ctrl_msg(task, task_props_list): @pytest.mark.skipif(not MORPHEUS_IMPORT_OK, reason="Morpheus modules are not available.") -@pytest.mark.skipif( - not CUDA_DRIVER_OK, - reason="Test environment does not have a compatible CUDA driver.", -) def test_remove_task_subset(): task_props_list = [ {"prop0": "foo0", "prop1": "bar1"}, From 458ea86963f840a9c50305defb7d9d510bb2b5d2 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 11 Feb 2025 13:56:45 -0800 Subject: [PATCH 19/28] merge main --- .github/workflows/pypi-nightly-publish.yml | 11 ++++-- Dockerfile | 6 ++-- api/MANIFEST.in | 7 ++++ api/pyproject.toml | 7 ++-- api/src/version.py | 36 ++++++++++++++++++++ ci/scripts/build_pip_packages.sh | 17 +++++---- client/src/version.py | 2 +- src/nv_ingest/stages/nim/chart_extraction.py | 16 +++++---- src/nv_ingest/stages/nim/table_extraction.py | 2 +- 9 files changed, 83 insertions(+), 21 deletions(-) create mode 100644 api/MANIFEST.in create mode 100644 api/src/version.py diff --git a/.github/workflows/pypi-nightly-publish.yml b/.github/workflows/pypi-nightly-publish.yml index 10a4812b..c1ee1f5c 100644 --- a/.github/workflows/pypi-nightly-publish.yml +++ b/.github/workflows/pypi-nightly-publish.yml @@ -22,14 +22,19 @@ jobs: run: | pip install build twine - - name: Build wheel + - name: Build nv-ingest-api wheel + run: | + cd api && python -m build + + - name: Build nv-ingest-client wheel run: | cd client && python -m build - - name: Publish to Artifactory + - name: Publish wheels to Artifactory env: ARTIFACTORY_URL: ${{ secrets.ARTIFACTORY_URL }} ARTIFACTORY_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }} ARTIFACTORY_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }} run: | - twine upload --repository-url $ARTIFACTORY_URL -u $ARTIFACTORY_USERNAME -p $ARTIFACTORY_PASSWORD client/dist/* + twine upload --repository-url $ARTIFACTORY_URL -u $ARTIFACTORY_USERNAME -p $ARTIFACTORY_PASSWORD api/dist/* \ + && twine upload --repository-url $ARTIFACTORY_URL -u $ARTIFACTORY_USERNAME -p $ARTIFACTORY_PASSWORD client/dist/* diff --git a/Dockerfile b/Dockerfile index 4e6bf22d..e058323f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -81,15 +81,15 @@ RUN if [ -z "${VERSION}" ]; then \ ENV NV_INGEST_RELEASE_TYPE=${RELEASE_TYPE} ENV NV_INGEST_VERSION_OVERRIDE=${NV_INGEST_VERSION_OVERRIDE} -ENV NV_INGEST_CLIENT_VERSION_OVERRIDE=${NV_INGEST_VERSION_OVERRIDE} SHELL ["/bin/bash", "-c"] COPY tests tests COPY data data +COPY api api COPY client client COPY src/nv_ingest src/nv_ingest -RUN rm -rf ./src/nv_ingest/dist ./client/dist +RUN rm -rf ./src/nv_ingest/dist ./client/dist ./api/dist # Install python build from pip, version needed not present in conda RUN source activate nv_ingest_runtime \ @@ -100,6 +100,7 @@ RUN --mount=type=cache,target=/opt/conda/pkgs \ --mount=type=cache,target=/root/.cache/pip \ chmod +x ./ci/scripts/build_pip_packages.sh \ && source activate nv_ingest_runtime \ + && ./ci/scripts/build_pip_packages.sh --type ${RELEASE_TYPE} --lib api \ && ./ci/scripts/build_pip_packages.sh --type ${RELEASE_TYPE} --lib client \ && ./ci/scripts/build_pip_packages.sh --type ${RELEASE_TYPE} --lib service @@ -107,6 +108,7 @@ RUN --mount=type=cache,target=/opt/conda/pkgs\ --mount=type=cache,target=/root/.cache/pip \ source activate nv_ingest_runtime \ && pip install ./dist/*.whl 
\ + && pip install ./api/dist/*.whl \ && pip install ./client/dist/*.whl RUN rm -rf src diff --git a/api/MANIFEST.in b/api/MANIFEST.in new file mode 100644 index 00000000..f0c39304 --- /dev/null +++ b/api/MANIFEST.in @@ -0,0 +1,7 @@ +exclude *.egg-info + +include README.md +include LICENSE +recursive-include src * +global-exclude __pycache__ +global-exclude *.pyc diff --git a/api/pyproject.toml b/api/pyproject.toml index 7d633c42..0ede7d00 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,11 +1,11 @@ [build-system] -requires = ["setuptools", "wheel"] # Tools needed to build the project +requires = ["setuptools", "wheel"] build-backend = "setuptools.build_meta" [project] name = "nv-ingest-api" -version = "24.12.dev0" description = "Python module with core document ingestion functions." +dynamic = ["version"] # Declare attrs that will be generated at build time readme = "README.md" authors = [ {name = "Jeremy Dyer", email = "jdyer@nvidia.com"} @@ -29,3 +29,6 @@ documentation = "https://docs.nvidia.com/nv-ingest" [tool.setuptools.packages.find] where = ["src"] + +[tool.setuptools.dynamic] +version = {attr = "version.get_version"} diff --git a/api/src/version.py b/api/src/version.py new file mode 100644 index 00000000..ac7fc9d2 --- /dev/null +++ b/api/src/version.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import datetime +import os +import re + + +def get_version(): + release_type = os.getenv("NV_INGEST_RELEASE_TYPE", "dev") + version = os.getenv("NV_INGEST_VERSION") + rev = os.getenv("NV_INGEST_REV", "0") + + if not version: + version = f"{datetime.datetime.now().strftime('%Y.%m.%d')}" + + # Ensure the version is PEP 440 compatible + pep440_regex = r"^\d{4}\.\d{1,2}\.\d{1,2}$" + if not re.match(pep440_regex, version): + raise ValueError(f"Version '{version}' is not PEP 440 compatible") + + # Construct the final version string + if release_type == "dev": + # If rev is not specified and defaults to 0 lets create a more meaningful development + # identifier that is pep440 compliant + if int(rev) == 0: + rev = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + final_version = f"{version}.dev{rev}" + elif release_type == "release": + final_version = f"{version}.post{rev}" if int(rev) > 0 else version + else: + raise ValueError(f"Invalid release type: {release_type}") + + return final_version diff --git a/ci/scripts/build_pip_packages.sh b/ci/scripts/build_pip_packages.sh index 34433ad6..c12a5c56 100755 --- a/ci/scripts/build_pip_packages.sh +++ b/ci/scripts/build_pip_packages.sh @@ -2,7 +2,7 @@ # Function to display usage usage() { - echo "Usage: $0 --type --lib " + echo "Usage: $0 --type --lib " exit 1 } @@ -38,11 +38,16 @@ else fi # Set library-specific variables and paths -if [[ "$LIBRARY" == "client" ]]; then - NV_INGEST_CLIENT_VERSION_OVERRIDE="${VERSION_SUFFIX}" - export NV_INGEST_CLIENT_VERSION_OVERRIDE - SETUP_PATH="$SCRIPT_DIR/../../client" - (cd "$(dirname "$SETUP_PATH")/client" && python -m build) +if [[ "$LIBRARY" == "api" ]]; then + NV_INGEST_VERSION_OVERRIDE="${VERSION_SUFFIX}" + export NV_INGEST_VERSION_OVERRIDE + SETUP_PATH="$SCRIPT_DIR/../../api/pyproject.toml" + (cd "$(dirname "$SETUP_PATH")" && python -m build) +elif [[ "$LIBRARY" == "client" ]]; then + NV_INGEST_VERSION_OVERRIDE="${VERSION_SUFFIX}" + export NV_INGEST_VERSION_OVERRIDE + SETUP_PATH="$SCRIPT_DIR/../../client/pyproject.toml" + (cd "$(dirname "$SETUP_PATH")" && python -m build) 
elif [[ "$LIBRARY" == "service" ]]; then NV_INGEST_SERVICE_VERSION_OVERRIDE="${VERSION_SUFFIX}" export NV_INGEST_SERVICE_VERSION_OVERRIDE diff --git a/client/src/version.py b/client/src/version.py index 2037ed0f..ac7fc9d2 100644 --- a/client/src/version.py +++ b/client/src/version.py @@ -10,7 +10,7 @@ def get_version(): release_type = os.getenv("NV_INGEST_RELEASE_TYPE", "dev") - version = os.getenv("NV_INGEST_CLIENT_VERSION") + version = os.getenv("NV_INGEST_VERSION") rev = os.getenv("NV_INGEST_REV", "0") if not version: diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index f13d9321..8d956847 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -39,11 +39,15 @@ def _update_metadata( worker_pool_size: int = 8, # Not currently used. ) -> List[Tuple[str, Dict]]: """ - Given a list of base64-encoded chart images, this function calls both the Yolox and Paddle - inference services concurrently to extract chart data for all images. - - For each base64-encoded image, returns: - (original_image_str, joined_chart_content_dict) + <<<<<<< HEAD + Given a list of base64-encoded chart images, this function calls both the Yolox and Paddle + ======= + Given a list of base64-encoded chart images, this function calls both the Cached and Deplot + >>>>>>> main + inference services concurrently to extract chart data for all images. + + For each base64-encoded image, returns: + (original_image_str, joined_chart_content_dict) """ logger.debug("Running chart extraction using updated concurrency handling.") @@ -77,7 +81,7 @@ def _update_metadata( data=data_paddle, model_name="paddle", stage_name="chart_data_extraction", - max_batch_size=1, + max_batch_size=1 if paddle_client.protocol == "grpc" else 2, trace_info=trace_info, ) diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index 4836e107..000902c1 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -73,7 +73,7 @@ def _update_metadata( data=data, model_name="paddle", stage_name="table_data_extraction", - max_batch_size=1, + max_batch_size=1 if paddle_client.protocol == "grpc" else 2, trace_info=trace_info, ) From db1f3b5940dc5e4da87f122b8d3da414ad4a4087 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 11 Feb 2025 14:04:46 -0800 Subject: [PATCH 20/28] fix paddle image shape check --- src/nv_ingest/util/nim/paddle.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/nv_ingest/util/nim/paddle.py b/src/nv_ingest/util/nim/paddle.py index aaca0790..db1176a9 100644 --- a/src/nv_ingest/util/nim/paddle.py +++ b/src/nv_ingest/util/nim/paddle.py @@ -127,13 +127,17 @@ def chunk_list(lst, chunk_size): arr = np.expand_dims(arr, axis=0) # => shape (1, H, W, C) processed.append(arr) - # Check that all images have the same shape (excluding batch dimension) - shapes = [p.shape[1:] for p in processed] # List of (H, W, C) shapes - if not all(s == shapes[0] for s in shapes[1:]): - raise ValueError(f"All images must have the same dimensions for gRPC batching. 
Found: {shapes}") + # Chunk the images into groups of size up to max_batch_size + batched_image_chunks = chunk_list(processed, max_batch_size) + + # Check that all images in each chunk have the same shape (excluding batch dimension) + for chunk in batched_image_chunks: + shapes = [p.shape[1:] for p in chunk] # List of (H, W, C) shapes + if not all(s == shapes[0] for s in shapes[1:]): + raise ValueError(f"All images must have the same dimensions for gRPC batching. Found: {shapes}") batches = [] - for chunk in chunk_list(processed, max_batch_size): + for chunk in batched_image_chunks: # Concatenate arrays in the chunk along the batch dimension => shape (B, H, W, C) batched_input = np.concatenate(chunk, axis=0) batches.append(batched_input) From 5d9f0ed93cf13f51f32f31379f5b2514167ede5a Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 11 Feb 2025 14:27:00 -0800 Subject: [PATCH 21/28] update image name and tag --- docker-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index c9f36116..04fccd4c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -39,7 +39,7 @@ services: runtime: nvidia yolox-graphic-elements: - image: ${YOLOX_GRAPHIC_ELEMENTS_IMAGE:-nvcr.io/nvidia/nemo-microservices/nv-yolox-graphic-elements-v1}:${YOLOX_GRAPHIC_ELEMENTS_TAG:-1.1.0} + image: ${YOLOX_GRAPHIC_ELEMENTS_IMAGE:-nvcr.io/nvidia/nemo-microservices/nemoretriever-graphic-elements-v1}:${YOLOX_GRAPHIC_ELEMENTS_TAG:-1.1} ports: - "8003:8000" - "8004:8001" From b1b0e18ca8c9f55ffa09b3dce080c8feff1c79c3 Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 11 Feb 2025 15:13:51 -0800 Subject: [PATCH 22/28] move milvus to device 0 --- docker-compose.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 04fccd4c..7bc26bd6 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -289,7 +289,7 @@ services: reservations: devices: - driver: nvidia - device_ids: ["1"] + device_ids: ["0"] capabilities: [gpu] depends_on: - "etcd" From 252179b3d8ea96a49db86abb1d0ed32f463d9c0a Mon Sep 17 00:00:00 2001 From: edknv Date: Tue, 11 Feb 2025 16:22:07 -0800 Subject: [PATCH 23/28] adjust yolox batch size --- src/nv_ingest/util/nim/yolox.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py index e411a8ae..9962b312 100644 --- a/src/nv_ingest/util/nim/yolox.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -7,6 +7,7 @@ import io import logging import warnings +from math import log from typing import Any from typing import Dict from typing import List @@ -66,11 +67,20 @@ ] -def chunkify(lst, chunk_size): +def chunkify_linearly(lst, chunk_size): for i in range(0, len(lst), chunk_size): yield lst[i : i + chunk_size] +def chunkify_geometrically(lst, max_size): + # TRT engine in Yolox NIM (gRPC) only allows a batch size in multiples of 2. 
+ i = 0 + while i < len(lst): + chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size) + yield lst[i : i + chunk_size] + i += chunk_size + + # YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements class YoloxModelInterfaceBase(ModelInterface): """ @@ -158,7 +168,7 @@ def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, # Create a list of smaller batches (chunkify) batches = [] - for chunk in chunkify(resized_images, max_batch_size): + for chunk in chunkify_geometrically(resized_images, max_batch_size): # Reorder axes to match model input (batch, channels, height, width) input_array = np.einsum("bijk->bkij", chunk).astype(np.float32) batches.append(input_array) @@ -191,7 +201,7 @@ def format_input(self, data: Dict[str, Any], protocol: str, max_batch_size: int, # Now split content_list into batches of up to max_batch_size batches = [] - for chunk in chunkify(content_list, max_batch_size): + for chunk in chunkify_linearly(content_list, max_batch_size): payload = { "input": content_list, "confidence_threshold": self.conf_threshold, From b63ea19b92b951e6989f84f9333b086894fef077 Mon Sep 17 00:00:00 2001 From: Edward Kim Date: Wed, 12 Feb 2025 21:30:18 +0000 Subject: [PATCH 24/28] A100 fix --- src/nv_ingest/util/nim/yolox.py | 32 ++++++++++++++------------------ 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/src/nv_ingest/util/nim/yolox.py b/src/nv_ingest/util/nim/yolox.py index 9562832d..47df3569 100644 --- a/src/nv_ingest/util/nim/yolox.py +++ b/src/nv_ingest/util/nim/yolox.py @@ -67,20 +67,6 @@ ] -def chunkify_linearly(lst, chunk_size): - for i in range(0, len(lst), chunk_size): - yield lst[i : i + chunk_size] - - -def chunkify_geometrically(lst, max_size): - # TRT engine in Yolox NIM (gRPC) only allows a batch size in multiples of 2. - i = 0 - while i < len(lst): - chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size) - yield lst[i : i + chunk_size] - i += chunk_size - - # YoloxModelInterfaceBase implements methods that are common to yolox-page-elements and yolox-graphic-elements class YoloxModelInterfaceBase(ModelInterface): """ @@ -175,10 +161,20 @@ def format_input( If the protocol is invalid. """ - # Helper to chunk a list into sublists of length up to chunk_size. + # Helper functions to chunk a list into sublists of length up to chunk_size. def chunk_list(lst: list, chunk_size: int) -> List[list]: return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + def chunk_list_geometrically(lst: list, max_size: int) -> List[list]: + # TRT engine in Yolox NIM (gRPC) only allows a batch size in powers of 2. + chunks = [] + i = 0 + while i < len(lst): + chunk_size = min(2 ** int(log(len(lst) - i, 2)), max_size) + chunks.append(lst[i : i + chunk_size]) + i += chunk_size + return chunks + if protocol == "grpc": logger.debug("Formatting input for gRPC Yolox model") # Resize images for model input (Yolox expects 1024x1024). @@ -186,9 +182,9 @@ def chunk_list(lst: list, chunk_size: int) -> List[list]: resize_image(image, (self.image_preproc_width, self.image_preproc_height)) for image in data["images"] ] # Chunk the resized images, the original images, and their shapes. 

-            resized_chunks = chunk_list(resized_images, max_batch_size)
-            original_chunks = chunk_list(data["images"], max_batch_size)
-            shape_chunks = chunk_list(data["original_image_shapes"], max_batch_size)
+            resized_chunks = chunk_list_geometrically(resized_images, max_batch_size)
+            original_chunks = chunk_list_geometrically(data["images"], max_batch_size)
+            shape_chunks = chunk_list_geometrically(data["original_image_shapes"], max_batch_size)
 
             batched_inputs = []
             formatted_batch_data = []
From fa6e75072789ac4052218dae6d80cca0fb194eb4 Mon Sep 17 00:00:00 2001
From: edknv 
Date: Thu, 13 Feb 2025 08:42:56 -0800
Subject: [PATCH 25/28] add banner to readme

---
 README.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/README.md b/README.md
index 5ae93af9..c979dcef 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,9 @@ All rights reserved.
 SPDX-License-Identifier: Apache-2.0
 -->
 
+> [!Note]
+> Cached and Deplot are deprecated; docker-compose now points to a beta version of the yolox-graphic-elements container instead. That model and container are slated for full release in March.
+> With this change, you should now be able to run on a single 80GB A100 or H100 GPU.
 
 ## NVIDIA-Ingest: Multi-modal data extraction
 
From fd65f2d30ec8d88e0dc21084a2985bdeddf42307 Mon Sep 17 00:00:00 2001
From: edknv 
Date: Thu, 13 Feb 2025 08:53:43 -0800
Subject: [PATCH 26/28] further readme update

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index c979dcef..7f6c7409 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,7 @@ SPDX-License-Identifier: Apache-2.0
 > [!Note]
 > Cached and Deplot are deprecated; docker-compose now points to a beta version of the yolox-graphic-elements container instead. That model and container are slated for full release in March.
 > With this change, you should now be able to run on a single 80GB A100 or H100 GPU.
+> If you want to continue using the old pipeline with Cached and Deplot, please use the [24.12.1 release](https://github.com/NVIDIA/nv-ingest/tree/24.12.1). 
## NVIDIA-Ingest: Multi-modal data extraction From 41f9e7be45d649b77b8971ae40925ffc527b38e4 Mon Sep 17 00:00:00 2001 From: edknv Date: Thu, 13 Feb 2025 08:54:11 -0800 Subject: [PATCH 27/28] change default embedding model to llama embedder --- docker-compose.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker-compose.yaml b/docker-compose.yaml index 7bc26bd6..7b02948c 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -83,7 +83,7 @@ services: embedding: # NIM ON - image: ${EMBEDDING_IMAGE:-nvcr.io/nim/nvidia/nv-embedqa-e5-v5}:${EMBEDDING_TAG:-1.1.0} + image: ${EMBEDDING_IMAGE:-nvcr.io/nim/nvidia/llama-3.2-nv-embedqa-1b-v2}:${EMBEDDING_TAG:-1.3.0} shm_size: 16gb ports: - "8012:8000" @@ -121,7 +121,7 @@ services: environment: - CUDA_VISIBLE_DEVICES=0 - DOUGHNUT_GRPC_TRITON=triton-doughnut:8001 - - EMBEDDING_NIM_MODEL_NAME=${EMBEDDING_NIM_MODEL_NAME:-nvidia/nv-embedqa-e5-v5} + - EMBEDDING_NIM_MODEL_NAME=${EMBEDDING_NIM_MODEL_NAME:-nvidia/llama-3.2-nv-embedqa-1b-v2} - INGEST_LOG_LEVEL=DEFAULT # Message client for development #- MESSAGE_CLIENT_HOST=0.0.0.0 From 20c290f8e915b01a6a615900149164d35f49a15b Mon Sep 17 00:00:00 2001 From: edknv Date: Thu, 13 Feb 2025 08:55:57 -0800 Subject: [PATCH 28/28] change minimum requirement to 1 GPU --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7f6c7409..d16f317c 100644 --- a/README.md +++ b/README.md @@ -48,8 +48,8 @@ A service that: | GPU | Family | Memory | # of GPUs (min.) | | ------ | ------ | ------ | ------ | -| H100 | SXM or PCIe | 80GB | 2 | -| A100 | SXM or PCIe | 80GB | 2 | +| H100 | SXM or PCIe | 80GB | 1 | +| A100 | SXM or PCIe | 80GB | 1 | ### Software