From 9b1668d3cca83c38db3d1ff9554b6feaf90f5cc7 Mon Sep 17 00:00:00 2001 From: Devin Robison Date: Thu, 16 Jan 2025 19:50:31 -0700 Subject: [PATCH] Extend docx and pptx extractors to attempt to extract tables/charts from images (#334) Co-authored-by: Edward Kim <109497216+edknv@users.noreply.github.com> --- .../extraction_workflows/docx/docx_helper.py | 21 +- .../extraction_workflows/docx/docxreader.py | 526 ++++++++++++++---- .../image/image_handlers.py | 120 ++-- .../extraction_workflows/pptx/pptx_helper.py | 301 ++++++---- .../schemas/docx_extractor_schema.py | 124 +++++ .../schemas/ingest_pipeline_config_schema.py | 4 +- .../schemas/pptx_extractor_schema.py | 120 +++- src/nv_ingest/stages/docx_extractor_stage.py | 107 ++-- .../extractors/image_extractor_stage.py | 2 - src/nv_ingest/stages/nim/chart_extraction.py | 1 + src/nv_ingest/stages/nim/table_extraction.py | 1 + src/nv_ingest/stages/pptx_extractor_stage.py | 110 ++-- .../util/pdf/metadata_aggregators.py | 1 - .../util/pipeline/pipeline_builders.py | 4 +- src/nv_ingest/util/pipeline/stage_builders.py | 30 +- src/util/image_viewer.py | 27 +- .../docx/test_docx_helper.py | 17 +- .../pptx/test_pptx_helper.py | 20 +- 18 files changed, 1147 insertions(+), 389 deletions(-) create mode 100644 src/nv_ingest/schemas/docx_extractor_schema.py diff --git a/src/nv_ingest/extraction_workflows/docx/docx_helper.py b/src/nv_ingest/extraction_workflows/docx/docx_helper.py index 6bfa12b6..44a946ad 100644 --- a/src/nv_ingest/extraction_workflows/docx/docx_helper.py +++ b/src/nv_ingest/extraction_workflows/docx/docx_helper.py @@ -36,7 +36,14 @@ logger = logging.getLogger(__name__) -def python_docx(docx: Union[str, Path, IO], extract_text: bool, extract_images: bool, extract_tables: bool, **kwargs): +def python_docx( + docx: Union[str, Path, IO], + extract_text: bool, + extract_images: bool, + extract_tables: bool, + extract_charts: bool, + **kwargs +): """ Helper function that use python-docx to extract text from a bytestream document @@ -57,6 +64,8 @@ def python_docx(docx: Union[str, Path, IO], extract_text: bool, extract_images: Specifies whether to extract images. extract_tables : bool Specifies whether to extract tables. + extract_charts : bool + Specifies whether to extract charts. **kwargs The keyword arguments are used for additional extraction parameters. @@ -73,10 +82,12 @@ def python_docx(docx: Union[str, Path, IO], extract_text: bool, extract_images: source_id = row_data["source_id"] # get text_depth text_depth = kwargs.get("text_depth", "document") - text_depth = TextTypeEnum[text_depth.upper()] + text_depth = TextTypeEnum(text_depth) # get base metadata metadata_col = kwargs.get("metadata_column", "metadata") + docx_extractor_config = kwargs.get("docx_extraction_config", {}) + base_unified_metadata = row_data[metadata_col] if metadata_col in row_data.index else {} # get base source_metadata @@ -103,7 +114,9 @@ def python_docx(docx: Union[str, Path, IO], extract_text: bool, extract_images: } # Extract data from the document using python-docx - doc = DocxReader(docx, source_metadata) - extracted_data = doc.extract_data(base_unified_metadata, text_depth, extract_text, extract_tables, extract_images) + doc = DocxReader(docx, source_metadata, extraction_config=docx_extractor_config) + extracted_data = doc.extract_data( + base_unified_metadata, text_depth, extract_text, extract_charts, extract_tables, extract_images + ) return extracted_data diff --git a/src/nv_ingest/extraction_workflows/docx/docxreader.py b/src/nv_ingest/extraction_workflows/docx/docxreader.py index b550d936..b2920203 100644 --- a/src/nv_ingest/extraction_workflows/docx/docxreader.py +++ b/src/nv_ingest/extraction_workflows/docx/docxreader.py @@ -23,14 +23,16 @@ """ Parse document content and properties using python-docx """ - +import io import logging import re import uuid -from typing import Dict +from typing import Dict, Optional, Union from typing import List from typing import Tuple +from collections import defaultdict + import pandas as pd from docx import Document from docx.image.constants import MIME_TYPE @@ -42,7 +44,11 @@ from docx.text.hyperlink import Hyperlink from docx.text.paragraph import Paragraph from docx.text.run import Run +from pandas import DataFrame +from build.lib.nv_ingest.extraction_workflows.image.image_handlers import load_and_preprocess_image +from nv_ingest.extraction_workflows.image.image_handlers import extract_tables_and_charts_from_images +from nv_ingest.schemas.image_extractor_schema import ImageConfigSchema from nv_ingest.schemas.metadata_schema import ContentTypeEnum from nv_ingest.schemas.metadata_schema import ImageTypeEnum from nv_ingest.schemas.metadata_schema import StdContentDescEnum @@ -50,6 +56,7 @@ from nv_ingest.schemas.metadata_schema import validate_metadata from nv_ingest.util.converters import bytetools from nv_ingest.util.detectors.language import detect_language +from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata, CroppedImageWithContent PARAGRAPH_FORMATS = ["text", "markdown"] TABLE_FORMATS = ["markdown", "markdown_light", "csv", "tag"] @@ -92,7 +99,7 @@ def __str__(self): def _update_source_meta_data(self): """ - Update the source meta data with the document's core properties + Update the source metadata with the document's core properties """ self.source_metadata.update( { @@ -132,9 +139,11 @@ def __init__( handle_text_styles: bool = True, image_tag="", table_tag="", + extraction_config: Dict = None, ): if paragraph_format not in PARAGRAPH_FORMATS: raise ValueError(f"Unknown paragraph format {paragraph_format}. Supported formats are: {PARAGRAPH_FORMATS}") + if table_format not in TABLE_FORMATS: raise ValueError(f"Unknown table format {table_format}. Supported formats are: {TABLE_FORMATS}") @@ -161,18 +170,47 @@ def __init__( # placeholders for metadata extraction self._accumulated_text = [] self._extracted_data = [] - self._prev_para_images = [] + self._extraction_config = extraction_config if extraction_config else {} + self._pending_images = [] self._prev_para_image_idx = 0 + self._prev_para_images = [] def is_text_empty(self, text: str) -> bool: """ - Check if text is available + Check if the given text is empty or matches the empty text pattern. + + Parameters + ---------- + text : str + The text to check. + + Returns + ------- + bool + True if the text is empty or matches the empty text pattern, False otherwise. """ + return self.empty_text_pattern.match(text) is not None - def format_text(self, text, bold: bool, italic: bool, underline: bool) -> str: + def format_text(self, text: str, bold: bool, italic: bool, underline: bool) -> str: """ - Apply markdown style to text (bold, italic, underline). + Apply markdown styling (bold, italic, underline) to the given text. + + Parameters + ---------- + text : str + The text to format. + bold : bool + Whether to apply bold styling. + italic : bool + Whether to apply italic styling. + underline : bool + Whether to apply underline styling. + + Returns + ------- + str + The formatted text with the applied styles. """ if self.is_text_empty(text): @@ -198,9 +236,20 @@ def format_text(self, text, bold: bool, italic: bool, underline: bool) -> str: return text - def format_paragraph(self, paragraph: Paragraph) -> Tuple[str, List[Image]]: - f""" - Format a paragraph into text. Supported formats are: {PARAGRAPH_FORMATS} + def format_paragraph(self, paragraph: "Paragraph") -> Tuple[str, List["Image"]]: + """ + Format a paragraph into styled text and extract associated images. + + Parameters + ---------- + paragraph : Paragraph + The paragraph to format. This includes text and potentially embedded images. + + Returns + ------- + tuple of (str, list of Image) + - The formatted paragraph text with markdown styling applied. + - A list of extracted images from the paragraph. """ paragraph_images = [] @@ -257,10 +306,22 @@ def format_paragraph(self, paragraph: Paragraph) -> Tuple[str, List[Image]]: paragraph_text = paragraph_text.strip() return paragraph_text, paragraph_images - def format_cell(self, cell: _Cell) -> Tuple[str, List[Image]]: + def format_cell(self, cell: "_Cell") -> Tuple[str, List["Image"]]: """ - Format a table cell into markdown text + Format a table cell into Markdown text and extract associated images. + + Parameters + ---------- + cell : _Cell + The table cell to format. + + Returns + ------- + tuple of (str, list of Image) + - The formatted text of the cell with markdown styling applied. + - A list of images extracted from the cell. """ + if self.paragraph_format == "markdown": newline = "
" else: @@ -268,10 +329,23 @@ def format_cell(self, cell: _Cell) -> Tuple[str, List[Image]]: paragraph_texts, paragraph_images = zip(*[self.format_paragraph(p) for p in cell.paragraphs]) return newline.join(paragraph_texts), paragraph_images - def format_table(self, table: Table) -> Tuple[str, List[Image]]: - f""" - Format a table into text. Supported formats are: {TABLE_FORMATS} + def format_table(self, table: "Table") -> Tuple[Optional[str], List["Image"], DataFrame]: + """ + Format a table into text, extract images, and represent it as a DataFrame. + + Parameters + ---------- + table : Table + The table to format. + + Returns + ------- + tuple of (str or None, list of Image, DataFrame) + - The formatted table as text, using the specified format (e.g., markdown, CSV). + - A list of images extracted from the table. + - A DataFrame representation of the table's content. """ + rows = [[self.format_cell(cell) for cell in row.cells] for row in table.rows] texts = [[text for text, _ in row] for row in rows] table_images = [image for row in rows for _, images in row for image in images] @@ -295,9 +369,24 @@ def format_table(self, table: Table) -> Tuple[str, List[Image]]: @staticmethod def apply_text_style(style: str, text: str, level: int = 0) -> str: """ - Apply style on a paragraph (heading, list, title, subtitle). - Not recommended if the document has been converted from pdf. + Apply a specific text style (e.g., heading, list, title, subtitle) to the given text. + + Parameters + ---------- + style : str + The style to apply. Supported styles include headings ("Heading 1" to "Heading 9"), + list items ("List"), and document structures ("Title", "Subtitle"). + text : str + The text to style. + level : int, optional + The indentation level for the styled text. Default is 0. + + Returns + ------- + str + The text with the specified style and indentation applied. """ + if re.match(r"^Heading [1-9]$", style): n = int(style.split(" ")[-1]) text = f"{'#' * n} {text}" @@ -313,43 +402,62 @@ def apply_text_style(style: str, text: str, level: int = 0) -> str: return text @staticmethod - def docx_content_type_to_image_type(content_type: MIME_TYPE) -> str: + def docx_content_type_to_image_type(content_type: "MIME_TYPE") -> str: """ - python-docx stores the content type in the image header as a string of format - "image/jpeg" etc. This is converted into one of ImageTypeEnum. - Reference: src/docx/image/jpeg.py + Convert a DOCX content type string to an image type. + + Parameters + ---------- + content_type : MIME_TYPE + The content type string from the image header, e.g., "image/jpeg". + + Returns + ------- + str + The image type extracted from the content type string. """ + return content_type.split("/")[1] - def _construct_image_metadata(self, image, para_idx, caption, base_unified_metadata): + def _construct_image_metadata( + self, para_idx: int, caption: str, base_unified_metadata: Dict, base64_img: str + ) -> List[Union[str, dict]]: """ - Fill the metadata for the extracted image + Build metadata for an image in a DOCX file. + + Parameters + ---------- + para_idx : int + The paragraph index containing the image. + caption : str + The caption associated with the image. + base_unified_metadata : dict + The base metadata to build upon. + base64_img : str + The image content encoded as a base64 string. + + Returns + ------- + list + A list containing the content type, validated metadata, and a unique identifier. """ - image_type = self.docx_content_type_to_image_type(image.content_type) - if ImageTypeEnum.has_value(image_type): - image_type = ImageTypeEnum[image_type.upper()] - - base64_img = bytetools.base64frombytes(image.blob) - # For docx there is no bounding box. The paragraph that follows the image is typically - # the caption. Add that para to the page nearby for now. fixme bbox = (0, 0, 0, 0) + caption_len = len(caption.splitlines()) + + page_idx = 0 # docx => single page + page_count = 1 + page_nearby_blocks = { "text": {"content": [], "bbox": []}, "images": {"content": [], "bbox": []}, "structured": {"content": [], "bbox": []}, } - caption_len = len(caption.splitlines()) + if caption_len: page_nearby_blocks["text"]["content"].append(caption) page_nearby_blocks["text"]["bbox"] = [[-1, -1, -1, -1]] * caption_len - page_block = para_idx - - # python-docx treats the entire document as a single page - page_count = 1 - page_idx = 0 - content_metadata = { "type": ContentTypeEnum.IMAGE, "description": StdContentDescEnum.DOCX_IMAGE, @@ -357,16 +465,15 @@ def _construct_image_metadata(self, image, para_idx, caption, base_unified_metad "hierarchy": { "page_count": page_count, "page": page_idx, - "block": page_block, + "block": para_idx, "line": -1, "span": -1, "nearby_objects": page_nearby_blocks, }, } - # bbox is not available in docx. the para following the image is typically the caption. image_metadata = { - "image_type": image_type, + "image_type": ImageTypeEnum.image_type_1, "structured_image_type": ImageTypeEnum.image_type_1, "caption": caption, "text": "", @@ -374,7 +481,6 @@ def _construct_image_metadata(self, image, para_idx, caption, base_unified_metad } unified_metadata = base_unified_metadata.copy() - unified_metadata.update( { "content": base64_img, @@ -386,24 +492,64 @@ def _construct_image_metadata(self, image, para_idx, caption, base_unified_metad validated_unified_metadata = validate_metadata(unified_metadata) - # Work around until https://github.com/apache/arrow/pull/40412 is resolved - return [ContentTypeEnum.IMAGE.value, validated_unified_metadata.model_dump(), str(uuid.uuid4())] + return [ + ContentTypeEnum.IMAGE.value, + validated_unified_metadata.model_dump(), + str(uuid.uuid4()), + ] - def _extract_para_images(self, images, para_idx, caption, base_unified_metadata, extracted_data): + def _extract_para_images( + self, images: List["Image"], para_idx: int, caption: str, base_unified_metadata: Dict + ) -> None: """ - Extract all images in a paragraph. These images share the same metadata. + Collect images from a paragraph and store them for metadata construction. + + Parameters + ---------- + images : list of Image + The images found in the paragraph. + para_idx : int + The index of the paragraph containing the images. + caption : str + The caption associated with the images. + base_unified_metadata : dict + The base metadata to associate with the images. + + Returns + ------- + None """ + for image in images: logger.debug("image content_type %s para_idx %d", image.content_type, para_idx) logger.debug("image caption %s", caption) - extracted_image = self._construct_image_metadata(image, para_idx, caption, base_unified_metadata) - extracted_data.append(extracted_image) - def _construct_text_metadata(self, accumulated_text, para_idx, text_depth, base_unified_metadata): + # Simply append a tuple so we can build the final metadata in _finalize_images + self._pending_images.append((image, para_idx, caption, base_unified_metadata)) + + def _construct_text_metadata( + self, accumulated_text: List[str], para_idx: int, text_depth: "TextTypeEnum", base_unified_metadata: Dict + ) -> List[Union[str, dict]]: """ - Store the text with associated metadata. Docx uses the same scheme as - PDF. + Build metadata for text content in a DOCX file. + + Parameters + ---------- + accumulated_text : list of str + The accumulated text to include in the metadata. + para_idx : int + The paragraph index containing the text. + text_depth : TextTypeEnum + The depth of the text content (e.g., page-level, paragraph-level). + base_unified_metadata : dict + The base metadata to build upon. + + Returns + ------- + list + A list containing the content type, validated metadata, and a unique identifier. """ + if len(accumulated_text) < 1: return [] @@ -447,36 +593,37 @@ def _construct_text_metadata(self, accumulated_text, para_idx, text_depth, base_ return [ContentTypeEnum.TEXT.value, validated_unified_metadata.model_dump(), str(uuid.uuid4())] - def _extract_para_data( - self, child, base_unified_metadata, text_depth: TextTypeEnum, extract_images: bool, para_idx: int - ): + def _extract_para_text( + self, + paragraph, + paragraph_text, + base_unified_metadata: Dict, + text_depth: "TextTypeEnum", + para_idx: int, + ) -> None: """ - Process the text and images in a docx paragraph + Process the text, images, and styles in a DOCX paragraph. + + Parameters + ---------- + paragraph: Paragraph + The paragraph to process. + paragraph_text: str + The text content of the paragraph. + base_unified_metadata : dict + The base metadata to associate with extracted data. + text_depth : TextTypeEnum + The depth of text extraction (e.g., block-level, document-level). + para_idx : int + The index of the paragraph being processed. + + Returns + ------- + None """ - # Paragraph - paragraph = Paragraph(child, self.document) - paragraph_text, paragraph_images = self.format_paragraph(paragraph) - - if self._prev_para_images: - # build image metadata with image from previous paragraph and text from current - self._extract_para_images( - self._prev_para_images, - self._prev_para_image_idx, - paragraph_text, - base_unified_metadata, - self._extracted_data, - ) - self._prev_para_images = [] - - if extract_images and paragraph_images: - # cache the images till the next paragraph is read - self._prev_para_images = paragraph_images - self._prev_para_image_idx = para_idx - - self.images += paragraph_images + # Handle text styles if desired if self.handle_text_styles: - # Get the level of the paragraph (especially for lists) try: numPr = paragraph._element.xpath("./w:pPr/w:numPr")[0] level = int(numPr.xpath("./w:ilvl/@w:val")[0]) @@ -486,6 +633,7 @@ def _extract_para_data( self._accumulated_text.append(paragraph_text + "\n") + # If text_depth is BLOCK, we flush after each paragraph if text_depth == TextTypeEnum.BLOCK: text_extraction = self._construct_text_metadata( self._accumulated_text, para_idx, text_depth, base_unified_metadata @@ -493,77 +641,233 @@ def _extract_para_data( self._extracted_data.append(text_extraction) self._accumulated_text = [] - def _extract_table_data(self, child, base_unified_metadata, text_depth: TextTypeEnum, para_idx: int): + def _finalize_images(self, extract_tables: bool, extract_charts: bool, **kwargs) -> None: + """ + Build and append final metadata for each pending image in batches. + + Parameters + ---------- + extract_tables : bool + Whether to attempt table detection in images. + extract_charts : bool + Whether to attempt chart detection in images. + **kwargs + Additional configuration for image processing. + + Returns + ------- + None + """ + if not self._pending_images: + return + + # 1) Convert all pending images into numpy arrays (and also store base64 + context), + # so we can run detection on them in one go. + all_image_arrays = [] + image_info = [] # parallel list to hold (para_idx, caption, base_unified_metadata, base64_img) + + for docx_image, para_idx, caption, base_unified_metadata in self._pending_images: + # Convert docx image blob to BytesIO, then to numpy array + image_bytes = docx_image.blob + image_stream = io.BytesIO(image_bytes) + image_array = load_and_preprocess_image(image_stream) + base64_img = str(bytetools.base64frombytes(image_bytes)) + + all_image_arrays.append(image_array) + + # Keep track of all needed metadata so we can rebuild final entries + image_info.append((para_idx, caption, base_unified_metadata, base64_img)) + + # 2) If the user wants to detect tables/charts, do it in one pass for all images. + detection_map = defaultdict(list) # maps image_index -> list of CroppedImageWithContent + + if extract_tables or extract_charts: + try: + # Perform the batched detection on all images + detection_results = extract_tables_and_charts_from_images( + images=all_image_arrays, + config=ImageConfigSchema(**self._extraction_config.model_dump()), + trace_info=kwargs.get("trace_info"), + ) + # detection_results is typically List[Tuple[int, CroppedImageWithContent]] + # Group by image_index + for image_idx, cropped_item in detection_results: + detection_map[image_idx].append(cropped_item) + + except Exception as e: + logger.error(f"Error extracting tables/charts in batch: {e}") + # If something goes wrong, we can fall back to empty detection map + # so that all images are treated normally + detection_map = {} + + # 3) For each pending image, decide if we found tables/charts or not. + for i, _ in enumerate(self._pending_images): + para_idx_i, caption_i, base_unified_metadata_i, base64_img_i = image_info[i] + + # If detection_map[i] is non-empty, we have found table(s)/chart(s). + if i in detection_map and detection_map[i]: + for table_chart_data in detection_map[i]: + # Build structured metadata for each table or chart + structured_entry = construct_table_and_chart_metadata( + structured_image=table_chart_data, # A CroppedImageWithContent + page_idx=0, # docx => single page + page_count=1, + source_metadata=self.properties.source_metadata, + base_unified_metadata=base_unified_metadata_i, + ) + self._extracted_data.append(structured_entry) + else: + # Either detection was not requested, or no table/chart was found + image_entry = self._construct_image_metadata( + para_idx_i, + caption_i, + base_unified_metadata_i, + base64_img_i, + ) + self._extracted_data.append(image_entry) + + # 4) Clear out the pending images after finalizing + self._pending_images = [] + + def _extract_table_data( + self, + child, + base_unified_metadata: Dict, + ) -> None: """ - Process the text in a docx paragraph + Process the text and images in a DOCX table. + + Parameters + ---------- + child : element + The table element to process. + base_unified_metadata : dict + The base metadata to associate with extracted data. + text_depth : TextTypeEnum + The depth of text extraction (e.g., block-level, document-level). + para_idx : int + The index of the table being processed. + + Returns + ------- + None """ + # Table table = Table(child, self.document) table_text, table_images, table_dataframe = self.format_table(table) + self.images += table_images self.tables.append(table_dataframe) - self._accumulated_text.append(table_text + "\n") - if text_depth == TextTypeEnum.BLOCK: - text_extraction = self._construct_text_metadata( - self._accumulated_text, para_idx, text_depth, base_unified_metadata + cropped_image_with_content = CroppedImageWithContent( + content=table_text, + image="", # no image content + bbox=(0, 0, 0, 0), + max_width=0, + max_height=0, + type_string="table", + ) + + self._extracted_data.append( + construct_table_and_chart_metadata( + structured_image=cropped_image_with_content, + page_idx=0, # docx => single page + page_count=1, + source_metadata=self.properties.source_metadata, + base_unified_metadata=base_unified_metadata, ) - if len(text_extraction) > 0: - self._extracted_data.append(text_extraction) - self._accumulated_text = [] + ) def extract_data( self, - base_unified_metadata, - text_depth: TextTypeEnum, + base_unified_metadata: Dict, + text_depth: "TextTypeEnum", extract_text: bool, + extract_charts: bool, extract_tables: bool, extract_images: bool, - ) -> Dict: + ) -> list[list[str | dict]]: """ - Iterate over paragraphs and tables + Iterate over paragraphs and tables in a DOCX document to extract data. + + Parameters + ---------- + base_unified_metadata : dict + The base metadata to associate with all extracted content. + text_depth : TextTypeEnum + The depth of text extraction (e.g., block-level, document-level). + extract_text : bool + Whether to extract text from the document. + extract_charts : bool + Whether to extract charts from the document. + extract_tables : bool + Whether to extract tables from the document. + extract_images : bool + Whether to extract images from the document. + + Returns + ------- + dict + A dictionary containing the extracted data from the document. """ + self._accumulated_text = [] self._extracted_data = [] - - para_idx = 0 + self._pending_images = [] self._prev_para_images = [] self._prev_para_image_idx = 0 + para_idx = 0 + for child in self.document.element.body.iterchildren(): if isinstance(child, CT_P): - if not extract_text: - continue - self._extract_para_data(child, base_unified_metadata, text_depth, extract_images, para_idx) - - if isinstance(child, CT_Tbl): - if not extract_tables: - continue - self._extract_table_data(child, base_unified_metadata, text_depth, para_idx) + paragraph = Paragraph(child, self.document) + paragraph_text, paragraph_images = self.format_paragraph(paragraph) + + if extract_text: + self._extract_para_text( + paragraph, + paragraph_text, + base_unified_metadata, + text_depth, + para_idx, + ) + + if (extract_charts or extract_images or extract_tables) and paragraph_images: + self._prev_para_images = paragraph_images + self._prev_para_image_idx = para_idx + self._pending_images += [(image, para_idx, "", base_unified_metadata) for image in paragraph_images] + self.images += paragraph_images + + elif isinstance(child, CT_Tbl): + if extract_tables or extract_charts: + self._extract_table_data(child, base_unified_metadata) para_idx += 1 - # We treat the document as a single page + # If there's leftover text at the doc’s end if ( extract_text and text_depth in (TextTypeEnum.DOCUMENT, TextTypeEnum.PAGE) and len(self._accumulated_text) > 0 ): text_extraction = self._construct_text_metadata( - self._accumulated_text, -1, text_depth, base_unified_metadata + self._accumulated_text, + -1, + text_depth, + base_unified_metadata, ) - if len(text_extraction) > 0: + + if text_extraction: self._extracted_data.append(text_extraction) - if self._prev_para_images: - # if we got here it means that image was at the end of the document and there - # was no caption for the image - self._extract_para_images( - self._prev_para_images, - self._prev_para_image_idx, - "", - base_unified_metadata, - self._extracted_data, + # Final pass: Decide if images are just images or contain tables/charts + if extract_images or extract_tables or extract_charts: + self._finalize_images( + extract_tables=extract_tables, + extract_charts=extract_charts, + trace_info=None, ) return self._extracted_data diff --git a/src/nv_ingest/extraction_workflows/image/image_handlers.py b/src/nv_ingest/extraction_workflows/image/image_handlers.py index f6a80b8d..f7b12982 100644 --- a/src/nv_ingest/extraction_workflows/image/image_handlers.py +++ b/src/nv_ingest/extraction_workflows/image/image_handlers.py @@ -27,11 +27,12 @@ import numpy as np from PIL import Image +from math import log from wand.image import Image as WandImage import nv_ingest.util.nim.yolox as yolox_utils from nv_ingest.extraction_workflows.pdf.doughnut_utils import crop_image -from nv_ingest.schemas.image_extractor_schema import ImageExtractorSchema +from nv_ingest.schemas.image_extractor_schema import ImageConfigSchema from nv_ingest.schemas.metadata_schema import AccessLevelEnum from nv_ingest.util.image_processing.transforms import numpy_to_base64 from nv_ingest.util.nim.helpers import create_inference_client @@ -173,79 +174,85 @@ def extract_table_and_chart_images( tables_and_charts.append((page_idx, table_data)) -def extract_tables_and_charts_from_image( - image: np.ndarray, - config: ImageExtractorSchema, - num_classes: int = YOLOX_NUM_CLASSES, - conf_thresh: float = YOLOX_CONF_THRESHOLD, - iou_thresh: float = YOLOX_IOU_THRESHOLD, - min_score: float = YOLOX_MIN_SCORE, - final_thresh: float = YOLOX_FINAL_SCORE, +def extract_tables_and_charts_from_images( + images: List[np.ndarray], + config: ImageConfigSchema, trace_info: Optional[List] = None, -) -> List[CroppedImageWithContent]: +) -> List[Tuple[int, object]]: """ - Extract tables and charts from a single image using an ensemble of image-based models. - - This function processes a single image to detect and extract tables and charts. - It uses a sequence of models hosted on different inference servers to achieve this. + Detect and extract tables/charts from a list of NumPy images using YOLOX. Parameters ---------- - image : np.ndarray - A preprocessed image array for table and chart detection. - config : ImageExtractorSchema - Configuration for the inference client, including endpoint URLs and authentication. - num_classes : int, optional - The number of classes the model is trained to detect (default is 3). - conf_thresh : float, optional - The confidence threshold for detection (default is 0.01). - iou_thresh : float, optional - The Intersection Over Union (IoU) threshold for non-maximum suppression (default is 0.5). - min_score : float, optional - The minimum score threshold for considering a detection valid (default is 0.1). - final_thresh: float, optional - Threshold for keeping a bounding box applied after postprocessing (default is 0.48). + images : List[np.ndarray] + List of images in NumPy array format. + config : PDFiumConfigSchema + Configuration object containing YOLOX endpoints, auth token, etc. trace_info : Optional[List], optional - Tracing information for logging or debugging purposes. + Optional tracing data for debugging/performance profiling. Returns ------- - List[CroppedImageWithContent] - A list of `CroppedImageWithContent` objects representing detected tables or charts, - each containing metadata about the detected region. + List[Tuple[int, object]] + A list of (image_index, CroppedImageWithContent) + representing extracted table/chart data from each image. """ tables_and_charts = [] - yolox_client = None + try: - model_interface = yolox_utils.YoloxModelInterface() + model_interface = yolox_utils.YoloxPageElementsModelInterface() yolox_client = create_inference_client( - config.yolox_endpoints, model_interface, config.auth_token, config.yolox_infer_protocol + config.yolox_endpoints, + model_interface, + config.auth_token, + config.yolox_infer_protocol, ) - data = {"images": [image]} + max_batch_size = YOLOX_MAX_BATCH_SIZE + batches = [] + i = 0 + while i < len(images): + batch_size = min(2 ** int(log(len(images) - i, 2)), max_batch_size) + batches.append(images[i : i + batch_size]) # noqa: E203 + i += batch_size + + img_index = 0 + for batch in batches: + data = {"images": batch} + + # NimClient inference + inference_results = yolox_client.infer( + data, + model_name="yolox", + num_classes=YOLOX_NUM_CLASSES, + conf_thresh=YOLOX_CONF_THRESHOLD, + iou_thresh=YOLOX_IOU_THRESHOLD, + min_score=YOLOX_MIN_SCORE, + final_thresh=YOLOX_FINAL_SCORE, + trace_info=trace_info, # traceable_func arg + stage_name="pdf_content_extractor", # traceable_func arg + ) - inference_results = yolox_client.infer( - data, - model_name="yolox", - num_classes=YOLOX_NUM_CLASSES, - conf_thresh=YOLOX_CONF_THRESHOLD, - iou_thresh=YOLOX_IOU_THRESHOLD, - min_score=YOLOX_MIN_SCORE, - final_thresh=YOLOX_FINAL_SCORE, - ) + # 5) Extract table/chart info from each image's annotations + for annotation_dict, original_image in zip(inference_results, batch): + extract_table_and_chart_images( + annotation_dict, + original_image, + img_index, + tables_and_charts, + ) + img_index += 1 - extract_table_and_chart_images( - inference_results, - image, - page_idx=0, # Single image treated as one page - tables_and_charts=tables_and_charts, - ) + except TimeoutError: + logger.error("Timeout error during table/chart extraction.") + raise except Exception as e: - logger.error(f"Error during table/chart extraction from image: {str(e)}") + logger.error(f"Unhandled error during table/chart extraction: {str(e)}") traceback.print_exc() raise e + finally: if yolox_client: yolox_client.close() @@ -282,6 +289,8 @@ def image_data_extractor( Specifies whether to extract tables. extract_charts : bool Specifies whether to extract charts. + trace_info : dict, optional + Tracing information for logging or debugging purposes. **kwargs Additional extraction parameters. @@ -352,13 +361,13 @@ def image_data_extractor( # Table and chart extraction if extract_tables or extract_charts: try: - tables_and_charts = extract_tables_and_charts_from_image( - image_array, + tables_and_charts = extract_tables_and_charts_from_images( + [image_array], config=kwargs.get("image_extraction_config"), trace_info=trace_info, ) logger.debug("Extracted table/chart data from image") - for _, table_chart_data in tables_and_charts: + for _, table_chart_data in tables_and_charts[0]: extracted_data.append( construct_table_and_chart_metadata( table_chart_data, @@ -370,6 +379,7 @@ def image_data_extractor( ) except Exception as e: logger.error(f"Error extracting tables/charts from image: {e}") + raise logger.debug(f"Extracted {len(extracted_data)} items from the image.") diff --git a/src/nv_ingest/extraction_workflows/pptx/pptx_helper.py b/src/nv_ingest/extraction_workflows/pptx/pptx_helper.py index 7e6b6d89..df930f1c 100644 --- a/src/nv_ingest/extraction_workflows/pptx/pptx_helper.py +++ b/src/nv_ingest/extraction_workflows/pptx/pptx_helper.py @@ -1,7 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. # All rights reserved. # SPDX-License-Identifier: Apache-2.0 - +import io # Copyright (c) 2024, NVIDIA CORPORATION. # @@ -21,8 +21,9 @@ import operator import re import uuid +from collections import defaultdict from datetime import datetime -from typing import Dict +from typing import Dict, List, Tuple from typing import Optional import pandas as pd @@ -31,8 +32,14 @@ from pptx.enum.dml import MSO_THEME_COLOR from pptx.enum.shapes import MSO_SHAPE_TYPE from pptx.enum.shapes import PP_PLACEHOLDER +from pptx.shapes.autoshape import Shape from pptx.slide import Slide +from nv_ingest.extraction_workflows.image.image_handlers import ( + load_and_preprocess_image, + extract_tables_and_charts_from_images, +) +from nv_ingest.schemas.image_extractor_schema import ImageConfigSchema from nv_ingest.schemas.metadata_schema import AccessLevelEnum from nv_ingest.schemas.metadata_schema import ContentTypeEnum from nv_ingest.schemas.metadata_schema import ImageTypeEnum @@ -41,70 +48,144 @@ from nv_ingest.schemas.metadata_schema import TableFormatEnum from nv_ingest.schemas.metadata_schema import TextTypeEnum from nv_ingest.schemas.metadata_schema import validate_metadata +from nv_ingest.schemas.pptx_extractor_schema import PPTXConfigSchema from nv_ingest.util.converters import bytetools from nv_ingest.util.detectors.language import detect_language +from nv_ingest.util.pdf.metadata_aggregators import construct_table_and_chart_metadata logger = logging.getLogger(__name__) -# Define a helper function to use python-pptx to extract text from a base64 -# encoded bytestram PPTX -def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_tables: bool, **kwargs): +def _finalize_images( + pending_images: List[Tuple[Shape, int, int, int, dict, dict, dict]], + extracted_data: List, + pptx_extraction_config: PPTXConfigSchema, + extract_tables: bool = False, + extract_charts: bool = False, + trace_info: Optional[Dict] = None, +): + """ + Post-process all pending images. + - Convert shape image -> NumPy or base64 + - If `extract_tables` or `extract_charts`, do detection (table/chart) + - Build the appropriate metadata, either table/chart or image. + + This mimics the docx approach, but adapted for python-pptx shapes. + """ + if not pending_images: + return + + # Convert each shape to image data (base64 or ndarray). + # We'll store them for a single call to your model if you'd like (batching). + image_arrays = [] + image_contexts = [] + for ( + shape, + shape_idx, + slide_idx, + slide_count, + page_nearby_blocks, + source_metadata, + base_unified_metadata, + ) in pending_images: + try: + image_bytes = shape.image.blob + image_array = load_and_preprocess_image(io.BytesIO(image_bytes)) + base64_img = bytetools.base64frombytes(image_bytes) + + image_arrays.append(image_array) + image_contexts.append( + ( + shape_idx, + slide_idx, + slide_count, + page_nearby_blocks, + source_metadata, + base_unified_metadata, + base64_img, + ) + ) + except Exception as e: + logger.warning(f"Unable to process shape image: {e}") + + # If you want table/chart detection for these images, do it now + # (similar to docx approach). This might use your YOLO or other method: + detection_map = defaultdict(list) # image_idx -> list of CroppedImageWithContent + if extract_tables or extract_charts: + try: + # For example, a call to your function that checks for tables/charts + detection_results = extract_tables_and_charts_from_images( + images=image_arrays, + config=ImageConfigSchema(**(pptx_extraction_config.model_dump())), + trace_info=trace_info, + ) + # detection_results is something like [(image_idx, CroppedImageWithContent), ...] + for img_idx, cropped_obj in detection_results: + detection_map[img_idx].append(cropped_obj) + except Exception as e: + logger.error(f"Error while running table/chart detection on PPTX images: {e}") + detection_map = {} + + # Now build the final metadata objects + for i, context in enumerate(image_contexts): + (shape_idx, slide_idx, slide_count, page_nearby_blocks, source_metadata, base_unified_metadata, base64_img) = ( + context + ) + + # If there's a detection result for this image, handle it + if i in detection_map and detection_map[i]: + # We found table(s)/chart(s) in the image + for cropped_item in detection_map[i]: + structured_entry = construct_table_and_chart_metadata( + structured_image=cropped_item, + page_idx=slide_idx, + page_count=slide_count, + source_metadata=source_metadata, + base_unified_metadata=base_unified_metadata, + ) + extracted_data.append(structured_entry) + else: + # No table detected => build normal image metadata + image_entry = _construct_image_metadata( + shape_idx=shape_idx, + slide_idx=slide_idx, + slide_count=slide_count, + page_nearby_blocks=page_nearby_blocks, + base64_img=base64_img, + source_metadata=source_metadata, + base_unified_metadata=base_unified_metadata, + ) + extracted_data.append(image_entry) + + +def python_pptx( + pptx_stream, extract_text: bool, extract_images: bool, extract_tables: bool, extract_charts: bool, **kwargs +): """ - Helper function to use python-pptx to extract text from a bytestream PPTX. - - A document has five levels - presentation, slides, shapes, paragraphs, and runs. - To align with the pdf extraction, we map the levels as follows: - - Document -> Presention - - Pages -> Slides - - Blocks -> Shapes - - Lines -> Paragraphs - - Spans -> Runs - - Parameters - ---------- - pptx_stream : io.BytesIO - A bytestream PPTX. - extract_text : bool - Specifies whether to extract text. - extract_images : bool - Specifies whether to extract images. - extract_tables : bool - Specifies whether to extract tables. - **kwargs - The keyword arguments are used for additional extraction parameters. - - Returns - ------- - str - A string of extracted text. + Helper function to use python-pptx to extract text from a bytestream PPTX, + while deferring image classification into tables/charts if requested. """ logger.debug("Extracting PPTX with python-pptx backend.") row_data = kwargs.get("row_data") - # get source_id source_id = row_data["source_id"] - # get text_depth + text_depth = kwargs.get("text_depth", "page") text_depth = TextTypeEnum[text_depth.upper()] - # Not configurable anywhere at the moment paragraph_format = kwargs.get("paragraph_format", "markdown") identify_nearby_objects = kwargs.get("identify_nearby_objects", True) - # get base metadata metadata_col = kwargs.get("metadata_column", "metadata") + pptx_extractor_config = kwargs.get("pptx_extraction_config", {}) + trace_info = kwargs.get("trace_info", {}) + base_unified_metadata = row_data[metadata_col] if metadata_col in row_data.index else {} - # get base source_metadata base_source_metadata = base_unified_metadata.get("source_metadata", {}) - # get source_location source_location = base_source_metadata.get("source_location", "") - # get collection_id (assuming coming in from source_metadata...) collection_id = base_source_metadata.get("collection_id", "") - # get partition_id (assuming coming in from source_metadata...) partition_id = base_source_metadata.get("partition_id", -1) - # get access_level (assuming coming in from source_metadata...) access_level = base_source_metadata.get("access_level", AccessLevelEnum.LEVEL_1) presentation = Presentation(pptx_stream) @@ -140,6 +221,10 @@ def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_t accumulated_text = [] extracted_data = [] + # Hold images here for final classification + # Each item is (shape, shape_idx, slide_idx, page_nearby_blocks, base_unified_metadata) + pending_images = [] + for slide_idx, slide in enumerate(presentation.slides): shapes = sorted(ungroup_shapes(slide.shapes), key=operator.attrgetter("top", "left")) @@ -153,6 +238,9 @@ def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_t block_text = [] added_title = added_subtitle = False + # --------------------------------------------- + # 1) Text Extraction + # --------------------------------------------- if extract_text and shape.has_text_frame: for paragraph_idx, paragraph in enumerate(shape.text_frame.paragraphs): if not paragraph.text.strip(): @@ -162,21 +250,22 @@ def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_t text = run.text if not text: continue + text = escape_text(text) if paragraph_format == "markdown": - # For titles/subtitles, process them on the block/shape level, and - # skip formatting. if is_title(shape): - if added_title: + if not added_title: + text = process_title(shape) # format a heading or something + added_title = True + else: continue - text = process_title(shape) - added_title = True elif is_subtitle(shape): - if added_subtitle: + if not added_subtitle: + text = process_subtitle(shape) + added_subtitle = True + else: continue - text = process_subtitle(shape) - added_subtitle = True else: if run.hyperlink.address: text = get_hyperlink(text, run.hyperlink.address) @@ -193,9 +282,11 @@ def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_t accumulated_text.append(text) + # For "nearby objects", store block text if extract_images and identify_nearby_objects: block_text.append(text) + # If we only want text at SPAN level, flush after each run if text_depth == TextTypeEnum.SPAN: text_extraction = _construct_text_metadata( presentation, @@ -211,17 +302,15 @@ def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_t source_metadata, base_unified_metadata, ) - if len(text_extraction) > 0: extracted_data.append(text_extraction) - accumulated_text = [] - # Avoid excessive newline characters and add them only at - # the line/paragraph level or higher. + # Add newlines for separation at line/paragraph level if accumulated_text and not accumulated_text[-1].endswith("\n\n"): accumulated_text.append("\n\n") + # If text_depth is LINE, flush after each paragraph if text_depth == TextTypeEnum.LINE: text_extraction = _construct_text_metadata( presentation, @@ -237,12 +326,11 @@ def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_t source_metadata, base_unified_metadata, ) - if len(text_extraction) > 0: extracted_data.append(text_extraction) - accumulated_text = [] + # If text_depth is BLOCK, flush after we've read the entire shape if text_depth == TextTypeEnum.BLOCK: text_extraction = _construct_text_metadata( presentation, @@ -258,54 +346,60 @@ def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_t source_metadata, base_unified_metadata, ) - if len(text_extraction) > 0: extracted_data.append(text_extraction) - accumulated_text = [] - if extract_images and identify_nearby_objects and (len(block_text) > 0): + # If we have text in this shape and the user wants "nearby objects" references: + if extract_images and identify_nearby_objects and block_text: page_nearby_blocks["text"]["content"].append("".join(block_text)) page_nearby_blocks["text"]["bbox"].append(get_bbox(shape_object=shape)) + # --------------------------------------------- + # 2) Image Handling (DEFERRED) + # --------------------------------------------- + # If shape is a picture (or a placeholder that is an embedded image) + # Instead of building metadata now, we'll store it in pending_images. if extract_images and ( shape.shape_type == MSO_SHAPE_TYPE.PICTURE or ( shape.is_placeholder and shape.placeholder_format.type == PP_PLACEHOLDER.OBJECT and hasattr(shape, "image") - and getattr(shape, "image") ) ): try: - image_extraction = _construct_image_metadata( - shape, - shape_idx, - slide_idx, - slide_count, - source_metadata, - base_unified_metadata, - page_nearby_blocks, + # Just accumulate the shape + context; don't build the final item yet. + pending_images.append( + ( + shape, # so we can later pull shape.image.blob + shape_idx, + slide_idx, + slide_count, + page_nearby_blocks, + source_metadata, + base_unified_metadata, + ) ) - extracted_data.append(image_extraction) except ValueError as e: - # Handle the specific case where no embedded image is found logger.warning(f"No embedded image found for shape {shape_idx} on slide {slide_idx}: {e}") except Exception as e: - # Handle any other exceptions that might occur - logger.warning(f"An error occurred while processing shape {shape_idx} on slide {slide_idx}: {e}") + logger.warning(f"Error processing shape {shape_idx} on slide {slide_idx}: {e}") + # --------------------------------------------- + # 3) Table Handling + # --------------------------------------------- if extract_tables and shape.has_table: table_extraction = _construct_table_metadata( shape, slide_idx, slide_count, source_metadata, base_unified_metadata ) extracted_data.append(table_extraction) - # Extract text - slide (b) + # If text_depth is PAGE, flush once per slide if (extract_text) and (text_depth == TextTypeEnum.PAGE) and (len(accumulated_text) > 0): text_extraction = _construct_text_metadata( presentation, - shape, + shape, # might pass None if you prefer accumulated_text, keywords, slide_idx, @@ -317,17 +411,15 @@ def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_t source_metadata, base_unified_metadata, ) - if len(text_extraction) > 0: extracted_data.append(text_extraction) - accumulated_text = [] - # Extract text - presentation (c) + # If text_depth is DOCUMENT, flush once at the end if (extract_text) and (text_depth == TextTypeEnum.DOCUMENT) and (len(accumulated_text) > 0): text_extraction = _construct_text_metadata( presentation, - shape, + shape, # might pass None accumulated_text, keywords, -1, @@ -339,12 +431,23 @@ def python_pptx(pptx_stream, extract_text: bool, extract_images: bool, extract_t source_metadata, base_unified_metadata, ) - if len(text_extraction) > 0: extracted_data.append(text_extraction) - accumulated_text = [] + # --------------------------------------------- + # FINAL STEP: Finalize images + # --------------------------------------------- + if extract_images or extract_tables or extract_charts: + _finalize_images( + pending_images, + extracted_data, + pptx_extractor_config, + extract_tables=extract_tables, + extract_charts=extract_charts, + trace_info=trace_info, + ) + return extracted_data @@ -410,17 +513,19 @@ def _construct_text_metadata( # need to add block text to hierarchy/nearby_objects, including bbox def _construct_image_metadata( - shape, shape_idx, slide_idx, slide_count, source_metadata, base_unified_metadata, page_nearby_blocks + shape_idx: int, + slide_idx: int, + slide_count: int, + page_nearby_blocks: Dict, + base64_img: str, + source_metadata: Dict, + base_unified_metadata: Dict, ): - image_type = shape.image.ext - if ImageTypeEnum.has_value(image_type): - image_type = ImageTypeEnum[image_type.upper()] - - base64_img = bytetools.base64frombytes(shape.image.blob) - - bbox = get_bbox(shape_object=shape) - width = shape.width - height = shape.height + """ + Build standard PPTX image metadata. + """ + # Example bounding box + bbox = (0, 0, 0, 0) # or extract from shape.left, shape.top, shape.width, shape.height if desired content_metadata = { "type": ContentTypeEnum.IMAGE, @@ -437,17 +542,14 @@ def _construct_image_metadata( } image_metadata = { - "image_type": image_type, + "image_type": ImageTypeEnum.image_type_1, "structured_image_type": ImageTypeEnum.image_type_1, - "caption": "", + "caption": "", # could attempt to guess a caption from nearby text "text": "", "image_location": bbox, - "width": width, - "height": height, } - unified_metadata = base_unified_metadata.copy() - + unified_metadata = base_unified_metadata.copy() if base_unified_metadata else {} unified_metadata.update( { "content": base64_img, @@ -459,7 +561,11 @@ def _construct_image_metadata( validated_unified_metadata = validate_metadata(unified_metadata) - return [ContentTypeEnum.IMAGE, validated_unified_metadata.model_dump(), str(uuid.uuid4())] + return [ + ContentTypeEnum.IMAGE.value, + validated_unified_metadata.model_dump(), + str(uuid.uuid4()), + ] def _construct_table_metadata( @@ -492,12 +598,13 @@ def _construct_table_metadata( "caption": "", "table_format": TableFormatEnum.MARKDOWN, "table_location": bbox, + "table_content": df.to_markdown(index=False), } ext_unified_metadata = base_unified_metadata.copy() ext_unified_metadata.update( { - "content": df.to_markdown(index=False), + "content": "", "source_metadata": source_metadata, "content_metadata": content_metadata, "table_metadata": table_metadata, diff --git a/src/nv_ingest/schemas/docx_extractor_schema.py b/src/nv_ingest/schemas/docx_extractor_schema.py new file mode 100644 index 00000000..5204674e --- /dev/null +++ b/src/nv_ingest/schemas/docx_extractor_schema.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + + +import logging +from typing import Optional +from typing import Tuple + +from pydantic import model_validator, ConfigDict, BaseModel + +logger = logging.getLogger(__name__) + + +class DocxConfigSchema(BaseModel): + """ + Configuration schema for docx extraction endpoints and options. + + Parameters + ---------- + auth_token : Optional[str], default=None + Authentication token required for secure services. + + yolox_endpoints : Tuple[str, str] + A tuple containing the gRPC and HTTP services for the yolox endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + Methods + ------- + validate_endpoints(values) + Validates that at least one of the gRPC or HTTP services is provided for each endpoint. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + + Config + ------ + extra : str + Pydantic config option to forbid extra fields. + """ + + auth_token: Optional[str] = None + + yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + yolox_infer_protocol: str = "" + + @model_validator(mode="before") + @classmethod + def validate_endpoints(cls, values): + """ + Validates the gRPC and HTTP services for all endpoints. + + Parameters + ---------- + values : dict + Dictionary containing the values of the attributes for the class. + + Returns + ------- + dict + The validated dictionary of values. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + """ + + def clean_service(service): + """Set service to None if it's an empty string or contains only spaces or quotes.""" + if service is None or not service.strip() or service.strip(" \"'") == "": + return None + return service + + for model_name in ["yolox"]: + endpoint_name = f"{model_name}_endpoints" + grpc_service, http_service = values.get(endpoint_name) + grpc_service = clean_service(grpc_service) + http_service = clean_service(http_service) + + if not grpc_service and not http_service: + raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.") + + values[endpoint_name] = (grpc_service, http_service) + + protocol_name = f"{model_name}_infer_protocol" + protocol_value = values.get(protocol_name) + if not protocol_value: + protocol_value = "http" if http_service else "grpc" if grpc_service else "" + protocol_value = protocol_value.lower() + values[protocol_name] = protocol_value + + return values + + model_config = ConfigDict(extra="forbid") + + +class DocxExtractorSchema(BaseModel): + """ + Configuration schema for the PDF extractor settings. + + Parameters + ---------- + max_queue_size : int, default=1 + The maximum number of items allowed in the processing queue. + + n_workers : int, default=16 + The number of worker threads to use for processing. + + raise_on_failure : bool, default=False + A flag indicating whether to raise an exception on processing failure. + + image_extraction_config: Optional[ImageConfigSchema], default=None + Configuration schema for the image extraction stage. + """ + + max_queue_size: int = 1 + n_workers: int = 16 + raise_on_failure: bool = False + + docx_extraction_config: Optional[DocxConfigSchema] = None + model_config = ConfigDict(extra="forbid") diff --git a/src/nv_ingest/schemas/ingest_pipeline_config_schema.py b/src/nv_ingest/schemas/ingest_pipeline_config_schema.py index 1471a338..fe5debd6 100644 --- a/src/nv_ingest/schemas/ingest_pipeline_config_schema.py +++ b/src/nv_ingest/schemas/ingest_pipeline_config_schema.py @@ -22,7 +22,7 @@ from nv_ingest.schemas.otel_meter_schema import OpenTelemetryMeterSchema from nv_ingest.schemas.otel_tracer_schema import OpenTelemetryTracerSchema from nv_ingest.schemas.pdf_extractor_schema import PDFExtractorSchema -from nv_ingest.schemas.pptx_extractor_schema import PPTXExctractorSchema +from nv_ingest.schemas.pptx_extractor_schema import PPTXExtractorSchema from nv_ingest.schemas.table_extractor_schema import TableExtractorSchema logger = logging.getLogger(__name__) @@ -42,7 +42,7 @@ class PipelineConfigSchema(BaseModel): otel_meter_module: OpenTelemetryMeterSchema = OpenTelemetryMeterSchema() otel_tracer_module: OpenTelemetryTracerSchema = OpenTelemetryTracerSchema() pdf_extractor_module: PDFExtractorSchema = PDFExtractorSchema() - pptx_extractor_module: PPTXExctractorSchema = PPTXExctractorSchema() + pptx_extractor_module: PPTXExtractorSchema = PPTXExtractorSchema() redis_task_sink: MessageBrokerTaskSinkSchema = MessageBrokerTaskSinkSchema() redis_task_source: MessageBrokerTaskSourceSchema = MessageBrokerTaskSourceSchema() table_extractor_module: TableExtractorSchema = TableExtractorSchema() diff --git a/src/nv_ingest/schemas/pptx_extractor_schema.py b/src/nv_ingest/schemas/pptx_extractor_schema.py index 987ac671..d3897075 100644 --- a/src/nv_ingest/schemas/pptx_extractor_schema.py +++ b/src/nv_ingest/schemas/pptx_extractor_schema.py @@ -3,8 +3,122 @@ # SPDX-License-Identifier: Apache-2.0 -from nv_ingest.schemas.pdf_extractor_schema import PDFExtractorSchema +import logging +from typing import Optional +from typing import Tuple +from pydantic import model_validator, ConfigDict, BaseModel -class PPTXExctractorSchema(PDFExtractorSchema): - pass +logger = logging.getLogger(__name__) + + +class PPTXConfigSchema(BaseModel): + """ + Configuration schema for docx extraction endpoints and options. + + Parameters + ---------- + auth_token : Optional[str], default=None + Authentication token required for secure services. + + yolox_endpoints : Tuple[str, str] + A tuple containing the gRPC and HTTP services for the yolox endpoint. + Either the gRPC or HTTP service can be empty, but not both. + + Methods + ------- + validate_endpoints(values) + Validates that at least one of the gRPC or HTTP services is provided for each endpoint. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + + Config + ------ + extra : str + Pydantic config option to forbid extra fields. + """ + + auth_token: Optional[str] = None + + yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None) + yolox_infer_protocol: str = "" + + @model_validator(mode="before") + @classmethod + def validate_endpoints(cls, values): + """ + Validates the gRPC and HTTP services for all endpoints. + + Parameters + ---------- + values : dict + Dictionary containing the values of the attributes for the class. + + Returns + ------- + dict + The validated dictionary of values. + + Raises + ------ + ValueError + If both gRPC and HTTP services are empty for any endpoint. + """ + + def clean_service(service): + """Set service to None if it's an empty string or contains only spaces or quotes.""" + if service is None or not service.strip() or service.strip(" \"'") == "": + return None + return service + + for model_name in ["yolox"]: + endpoint_name = f"{model_name}_endpoints" + grpc_service, http_service = values.get(endpoint_name) + grpc_service = clean_service(grpc_service) + http_service = clean_service(http_service) + + if not grpc_service and not http_service: + raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.") + + values[endpoint_name] = (grpc_service, http_service) + + protocol_name = f"{model_name}_infer_protocol" + protocol_value = values.get(protocol_name) + if not protocol_value: + protocol_value = "http" if http_service else "grpc" if grpc_service else "" + protocol_value = protocol_value.lower() + values[protocol_name] = protocol_value + + return values + + model_config = ConfigDict(extra="forbid") + + +class PPTXExtractorSchema(BaseModel): + """ + Configuration schema for the PDF extractor settings. + + Parameters + ---------- + max_queue_size : int, default=1 + The maximum number of items allowed in the processing queue. + + n_workers : int, default=16 + The number of worker threads to use for processing. + + raise_on_failure : bool, default=False + A flag indicating whether to raise an exception on processing failure. + + image_extraction_config: Optional[ImageConfigSchema], default=None + Configuration schema for the image extraction stage. + """ + + max_queue_size: int = 1 + n_workers: int = 16 + raise_on_failure: bool = False + + pptx_extraction_config: Optional[PPTXConfigSchema] = None + model_config = ConfigDict(extra="forbid") diff --git a/src/nv_ingest/stages/docx_extractor_stage.py b/src/nv_ingest/stages/docx_extractor_stage.py index 7fcc434c..953eefc1 100644 --- a/src/nv_ingest/stages/docx_extractor_stage.py +++ b/src/nv_ingest/stages/docx_extractor_stage.py @@ -8,19 +8,70 @@ import io import logging import traceback +from typing import Optional, Dict, Any import pandas as pd from pydantic import BaseModel from morpheus.config import Config from nv_ingest.extraction_workflows import docx +from nv_ingest.schemas.docx_extractor_schema import DocxExtractorSchema from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage from nv_ingest.util.exception_handlers.pdf import create_exception_tag logger = logging.getLogger(f"morpheus.{__name__}") -def _process_docx_bytes(df, task_props): +def decode_and_extract(base64_row, task_props, validated_config: Any, trace_info: Dict, default="python_docx"): + if isinstance(task_props, BaseModel): + task_props = task_props.model_dump() + + # Base64 content to extract + base64_content = base64_row["content"] + # Row data to include in extraction + bool_index = base64_row.index.isin(("content",)) + row_data = base64_row[~bool_index] + task_props["params"]["row_data"] = row_data + # Get source_id + source_id = base64_row["source_id"] if "source_id" in base64_row.index else None + # Decode the base64 content + doc_bytes = base64.b64decode(base64_content) + + # Load the document + doc_stream = io.BytesIO(doc_bytes) + + # Type of extraction method to use + extract_method = task_props.get("method", "python_docx") + extract_params = task_props.get("params", {}) + try: + if validated_config.docx_extraction_config is not None: + extract_params["docx_extraction_config"] = validated_config.docx_extraction_config + + if trace_info is not None: + extract_params["trace_info"] = trace_info + + if not hasattr(docx, extract_method): + extract_method = default + + func = getattr(docx, extract_method, default) + logger.debug("Running extraction method: %s", extract_method) + extracted_data = func(doc_stream, **extract_params) + + return extracted_data + + except Exception as error: + traceback.print_exc() + log_error_message = f"Error loading extractor:{error}" + logger.error(log_error_message) + logger.error(f"Failed on file:{source_id}") + + # Propagate error back and tag message as failed. + exception_tag = create_exception_tag(error_message=log_error_message, source_id=source_id) + + return exception_tag + + +def _process_docx_bytes(df, task_props, validated_config: Any, trace_info: Optional[Dict[str, Any]] = None): """ Processes a cuDF DataFrame containing docx files in base64 encoding. Each document's content is replaced with its extracted text. @@ -33,51 +84,11 @@ def _process_docx_bytes(df, task_props): - A pandas DataFrame with the docx content replaced by the extracted text. """ - def decode_and_extract(base64_row, task_props, default="python_docx"): - if isinstance(task_props, BaseModel): - task_props = task_props.model_dump() - - # Base64 content to extract - base64_content = base64_row["content"] - # Row data to include in extraction - bool_index = base64_row.index.isin(("content",)) - row_data = base64_row[~bool_index] - task_props["params"]["row_data"] = row_data - # Get source_id - source_id = base64_row["source_id"] if "source_id" in base64_row.index else None - # Decode the base64 content - doc_bytes = base64.b64decode(base64_content) - - # Load the document - doc_stream = io.BytesIO(doc_bytes) - - # Type of extraction method to use - extract_method = task_props.get("method", "python_docx") - extract_params = task_props.get("params", {}) - if not hasattr(docx, extract_method): - extract_method = default - try: - func = getattr(docx, extract_method, default) - logger.debug("Running extraction method: %s", extract_method) - extracted_data = func(doc_stream, **extract_params) - - return extracted_data - - except Exception as e: - traceback.print_exc() - log_error_message = f"Error loading extractor:{e}" - logger.error(log_error_message) - logger.error(f"Failed on file:{source_id}") - - # Propagate error back and tag message as failed. - exception_tag = create_exception_tag(error_message=log_error_message, source_id=source_id) - - return exception_tag - try: # Apply the helper function to each row in the 'content' column - _decode_and_extract = functools.partial(decode_and_extract, task_props=task_props) - logger.debug(f"processing ({task_props.get('method', None)})") + _decode_and_extract = functools.partial( + decode_and_extract, task_props=task_props, validated_config=validated_config, trace_info=trace_info + ) sr_extraction = df.apply(_decode_and_extract, axis=1) sr_extraction = sr_extraction.explode().dropna() @@ -92,12 +103,14 @@ def decode_and_extract(base64_row, task_props, default="python_docx"): except Exception as e: traceback.print_exc() logger.error(f"Failed to extract text from document: {e}") + raise return df def generate_docx_extractor_stage( c: Config, + extractor_config: dict, task: str = "docx-extract", task_desc: str = "docx_content_extractor", pe_count: int = 24, @@ -109,6 +122,8 @@ def generate_docx_extractor_stage( ---------- c : Config Morpheus global configuration object + extractor_config : dict + Configuration parameters for document content extractor. task : str The task name to match for the stage worker function. task_desc : str @@ -121,7 +136,9 @@ def generate_docx_extractor_stage( MultiProcessingBaseStage A Morpheus stage with applied worker function. """ + validated_config = DocxExtractorSchema(**extractor_config) + _wrapped_process_fn = functools.partial(_process_docx_bytes, validated_config=validated_config) return MultiProcessingBaseStage( - c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_process_docx_bytes, document_type="docx" + c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_process_fn, document_type="docx" ) diff --git a/src/nv_ingest/stages/extractors/image_extractor_stage.py b/src/nv_ingest/stages/extractors/image_extractor_stage.py index 9bf97029..c0e90c28 100644 --- a/src/nv_ingest/stages/extractors/image_extractor_stage.py +++ b/src/nv_ingest/stages/extractors/image_extractor_stage.py @@ -81,8 +81,6 @@ def decode_and_extract( source_id = base64_row["source_id"] if "source_id" in base64_row.index else None # Decode the base64 content image_bytes = base64.b64decode(base64_content) - - # Load the PDF image_stream = io.BytesIO(image_bytes) # Type of extraction method to use diff --git a/src/nv_ingest/stages/nim/chart_extraction.py b/src/nv_ingest/stages/nim/chart_extraction.py index 127b7f85..3890c62e 100644 --- a/src/nv_ingest/stages/nim/chart_extraction.py +++ b/src/nv_ingest/stages/nim/chart_extraction.py @@ -66,6 +66,7 @@ def _update_metadata(row: pd.Series, cached_client: NimClient, deplot_client: Ni (content_metadata.get("type") != "structured") or (content_metadata.get("subtype") != "chart") or (chart_metadata is None) + or (base64_image in [None, ""]) ): return metadata diff --git a/src/nv_ingest/stages/nim/table_extraction.py b/src/nv_ingest/stages/nim/table_extraction.py index b64b4949..dd803af1 100644 --- a/src/nv_ingest/stages/nim/table_extraction.py +++ b/src/nv_ingest/stages/nim/table_extraction.py @@ -67,6 +67,7 @@ def _update_metadata(row: pd.Series, paddle_client: NimClient, trace_info: Dict) (content_metadata.get("type") != "structured") or (content_metadata.get("subtype") != "table") or (table_metadata is None) + or (base64_image in [None, ""]) ): return metadata diff --git a/src/nv_ingest/stages/pptx_extractor_stage.py b/src/nv_ingest/stages/pptx_extractor_stage.py index 9512a2f4..efbf848b 100644 --- a/src/nv_ingest/stages/pptx_extractor_stage.py +++ b/src/nv_ingest/stages/pptx_extractor_stage.py @@ -8,6 +8,7 @@ import io import logging import traceback +from typing import Any, Optional, Dict import pandas as pd from pydantic import BaseModel @@ -15,12 +16,61 @@ from nv_ingest.extraction_workflows import pptx from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage +from nv_ingest.schemas.pptx_extractor_schema import PPTXExtractorSchema from nv_ingest.util.exception_handlers.pdf import create_exception_tag logger = logging.getLogger(f"morpheus.{__name__}") -def _process_pptx_bytes(df, task_props): +def decode_and_extract(base64_row, task_props, validated_config: Any, trace_info: Dict, default="python_pptx"): + if isinstance(task_props, BaseModel): + task_props = task_props.model_dump() + + # Base64 content to extract + base64_content = base64_row["content"] + # Row data to include in extraction + bool_index = base64_row.index.isin(("content",)) + row_data = base64_row[~bool_index] + task_props["params"]["row_data"] = row_data + # Get source_id + source_id = base64_row["source_id"] if "source_id" in base64_row.index else None + # Decode the base64 content + pptx_bytes = base64.b64decode(base64_content) + + # Load the PPTX + pptx_stream = io.BytesIO(pptx_bytes) + + # Type of extraction method to use + extract_method = task_props.get("method", "python_pptx") + extract_params = task_props.get("params", {}) + if not hasattr(pptx, extract_method): + extract_method = default + try: + if validated_config.pptx_extraction_config is not None: + extract_params["pptx_extraction_config"] = validated_config.pptx_extraction_config + + if trace_info is not None: + extract_params["trace_info"] = trace_info + + func = getattr(pptx, extract_method, default) + logger.debug("Running extraction method: %s", extract_method) + extracted_data = func(pptx_stream, **extract_params) + + return extracted_data + + except Exception as e: + traceback.print_exc() + log_error_message = f"Error loading extractor:{e}" + logger.error(log_error_message) + logger.error(f"Failed on file:{source_id}") + + # Propagate error back and tag message as failed. + exception_tag = create_exception_tag(error_message=log_error_message, source_id=source_id) + + return exception_tag + + +def _process_pptx_bytes(df, task_props: dict, validated_config: Any, trace_info: Optional[Dict[str, Any]] = None): """ Processes a cuDF DataFrame containing PPTX files in base64 encoding. Each PPTX's content is replaced with its extracted text. @@ -32,52 +82,13 @@ def _process_pptx_bytes(df, task_props): Returns: - A pandas DataFrame with the PPTX content replaced by the extracted text. """ - - def decode_and_extract(base64_row, task_props, default="python_pptx"): - if isinstance(task_props, BaseModel): - task_props = task_props.model_dump() - - # Base64 content to extract - base64_content = base64_row["content"] - # Row data to include in extraction - bool_index = base64_row.index.isin(("content",)) - row_data = base64_row[~bool_index] - task_props["params"]["row_data"] = row_data - # Get source_id - source_id = base64_row["source_id"] if "source_id" in base64_row.index else None - # Decode the base64 content - pptx_bytes = base64.b64decode(base64_content) - - # Load the PPTX - pptx_stream = io.BytesIO(pptx_bytes) - - # Type of extraction method to use - extract_method = task_props.get("method", "python_pptx") - extract_params = task_props.get("params", {}) - if not hasattr(pptx, extract_method): - extract_method = default - try: - func = getattr(pptx, extract_method, default) - logger.debug("Running extraction method: %s", extract_method) - extracted_data = func(pptx_stream, **extract_params) - - return extracted_data - - except Exception as e: - traceback.print_exc() - log_error_message = f"Error loading extractor:{e}" - logger.error(log_error_message) - logger.error(f"Failed on file:{source_id}") - - # Propagate error back and tag message as failed. - exception_tag = create_exception_tag(error_message=log_error_message, source_id=source_id) - - return exception_tag - try: # Apply the helper function to each row in the 'content' column - _decode_and_extract = functools.partial(decode_and_extract, task_props=task_props) - logger.debug(f"processing ({task_props.get('method', None)})") + _decode_and_extract = functools.partial( + decode_and_extract, task_props=task_props, validated_config=validated_config, trace_info=trace_info + ) + + # logger.debug(f"processing ({task_props.get('method', None)})") sr_extraction = df.apply(_decode_and_extract, axis=1) sr_extraction = sr_extraction.explode().dropna() @@ -91,12 +102,14 @@ def decode_and_extract(base64_row, task_props, default="python_pptx"): except Exception as e: traceback.print_exc() logger.error(f"Failed to extract text from PPTX: {e}") + raise return df def generate_pptx_extractor_stage( c: Config, + extractor_config: dict, task: str = "pptx-extract", task_desc: str = "pptx_content_extractor", pe_count: int = 24, @@ -108,6 +121,8 @@ def generate_pptx_extractor_stage( ---------- c : Config Morpheus global configuration object + extractor_config : dict + Configuration parameters for document content extractor. task : str The task name to match for the stage worker function. task_desc : str @@ -121,6 +136,9 @@ def generate_pptx_extractor_stage( A Morpheus stage with applied worker function. """ + validated_config = PPTXExtractorSchema(**extractor_config) + _wrapped_process_fn = functools.partial(_process_pptx_bytes, validated_config=validated_config) + return MultiProcessingBaseStage( - c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_process_pptx_bytes, document_type="pptx" + c=c, pe_count=pe_count, task=task, task_desc=task_desc, process_fn=_wrapped_process_fn, document_type="pptx" ) diff --git a/src/nv_ingest/util/pdf/metadata_aggregators.py b/src/nv_ingest/util/pdf/metadata_aggregators.py index 8c6237f7..3fac696e 100644 --- a/src/nv_ingest/util/pdf/metadata_aggregators.py +++ b/src/nv_ingest/util/pdf/metadata_aggregators.py @@ -29,7 +29,6 @@ from nv_ingest.util.exception_handlers.pdf import pdfium_exception_handler -# TODO(Devin): Shift to this, since there is no difference between ImageTable and ImageChart @dataclass class CroppedImageWithContent: content: str diff --git a/src/nv_ingest/util/pipeline/pipeline_builders.py b/src/nv_ingest/util/pipeline/pipeline_builders.py index 842682f0..efeca97f 100644 --- a/src/nv_ingest/util/pipeline/pipeline_builders.py +++ b/src/nv_ingest/util/pipeline/pipeline_builders.py @@ -30,8 +30,8 @@ def setup_ingestion_pipeline( ######################################################################################################## pdf_extractor_stage = add_pdf_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) image_extractor_stage = add_image_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) - docx_extractor_stage = add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count) - pptx_extractor_stage = add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count) + docx_extractor_stage = add_docx_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) + pptx_extractor_stage = add_pptx_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count) ######################################################################################################## ######################################################################################################## diff --git a/src/nv_ingest/util/pipeline/stage_builders.py b/src/nv_ingest/util/pipeline/stage_builders.py index 352ed006..7dd80f3a 100644 --- a/src/nv_ingest/util/pipeline/stage_builders.py +++ b/src/nv_ingest/util/pipeline/stage_builders.py @@ -273,16 +273,28 @@ def add_image_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, def extractor_config=image_extractor_config, pe_count=8, task="extract", - task_desc="docx_content_extractor", + task_desc="image_content_extractor", ) ) return image_extractor_stage -def add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count): +def add_docx_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): + yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox") + docx_extractor_config = ingest_config.get( + "docx_extraction_module", + { + "docx_extraction_config": { + "yolox_endpoints": (yolox_grpc, yolox_http), + "yolox_infer_protocol": yolox_protocol, + "auth_token": yolox_auth, + } + }, + ) docx_extractor_stage = pipe.add_stage( generate_docx_extractor_stage( morpheus_pipeline_config, + extractor_config=docx_extractor_config, pe_count=1, task="extract", task_desc="docx_content_extractor", @@ -291,10 +303,22 @@ def add_docx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count): return docx_extractor_stage -def add_pptx_extractor_stage(pipe, morpheus_pipeline_config, default_cpu_count): +def add_pptx_extractor_stage(pipe, morpheus_pipeline_config, ingest_config, default_cpu_count): + yolox_grpc, yolox_http, yolox_auth, yolox_protocol = get_table_detection_service("yolox") + pptx_extractor_config = ingest_config.get( + "pptx_extraction_module", + { + "pptx_extraction_config": { + "yolox_endpoints": (yolox_grpc, yolox_http), + "yolox_infer_protocol": yolox_protocol, + "auth_token": yolox_auth, + } + }, + ) pptx_extractor_stage = pipe.add_stage( generate_pptx_extractor_stage( morpheus_pipeline_config, + extractor_config=pptx_extractor_config, pe_count=1, task="extract", task_desc="pptx_content_extractor", diff --git a/src/util/image_viewer.py b/src/util/image_viewer.py index b47ccbdd..cebac902 100644 --- a/src/util/image_viewer.py +++ b/src/util/image_viewer.py @@ -31,12 +31,33 @@ def load_images_from_json(json_file_path): with open(json_file_path, "r") as file: data = json.load(file) + def create_default_image(): + """Create a solid black 300×300 image.""" + width, height = 300, 300 + default_img = Image.new("RGB", (width, height), color="black") + return default_img + images = [] for item in data: # Assuming the JSON is a list of objects if item["document_type"] in ("image", "structured"): - image_data = base64.b64decode(item["metadata"]["content"]) - image = Image.open(BytesIO(image_data)) - images.append(image) + content = item.get("metadata", {}).get("content", "") + # Check if content is missing or empty + if not content: + images.append(create_default_image()) + continue + + # Attempt to decode and open the image + try: + image_data = base64.b64decode(content) + temp_image = Image.open(BytesIO(image_data)) + # Verify & re-open to ensure no corruption or errors + temp_image.verify() + temp_image = Image.open(BytesIO(image_data)) + images.append(temp_image) + except Exception: + # If there's any error decoding/reading the image, use the default + images.append(create_default_image()) + return images diff --git a/tests/nv_ingest/extraction_workflows/docx/test_docx_helper.py b/tests/nv_ingest/extraction_workflows/docx/test_docx_helper.py index e56d003d..341ea68c 100644 --- a/tests/nv_ingest/extraction_workflows/docx/test_docx_helper.py +++ b/tests/nv_ingest/extraction_workflows/docx/test_docx_helper.py @@ -9,6 +9,7 @@ import pytest from nv_ingest.extraction_workflows.docx.docx_helper import python_docx +from nv_ingest.schemas.metadata_schema import ImageTypeEnum @pytest.fixture @@ -37,6 +38,7 @@ def test_docx_all_text(doc_stream, document_df): extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], ) @@ -64,6 +66,7 @@ def test_docx_all_text(doc_stream, document_df): assert extracted_data[0][1]["source_metadata"]["source_id"] == "woods_frost" +@pytest.mark.xfail(reason="Table extract requires yolox, disabling for now") def test_docx_table(doc_stream, document_df): """ Validate text and table extraction. Table content is converted into markdown text. @@ -73,6 +76,7 @@ def test_docx_table(doc_stream, document_df): extract_text=True, extract_images=False, extract_tables=True, + extract_charts=False, row_data=document_df.iloc[0], ) @@ -108,11 +112,11 @@ def test_docx_image(doc_stream, document_df): doc_stream, extract_text=True, extract_images=True, - extract_tables=True, + extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], ) - expected_captions = ["*Figure 1: Snowy Woods*", "*Figure 2: Robert Frost*"] expected_text_cnt = 1 expected_image_cnt = 2 expected_entry_cnt = expected_image_cnt + expected_text_cnt @@ -133,11 +137,4 @@ def test_docx_image(doc_stream, document_df): assert extracted_data[idx][0] == "image" # validate image type - assert extracted_data[idx][1]["image_metadata"]["image_type"] == "jpeg" - - # validate captions - expected_caption = expected_captions[idx] - extracted_caption = extracted_data[idx][1]["image_metadata"]["caption"] - assert extracted_caption == expected_caption - - assert image_cnt == expected_image_cnt + assert extracted_data[idx][1]["image_metadata"]["image_type"] == ImageTypeEnum.image_type_1 diff --git a/tests/nv_ingest/extraction_workflows/pptx/test_pptx_helper.py b/tests/nv_ingest/extraction_workflows/pptx/test_pptx_helper.py index 1a85c95c..43e799d9 100644 --- a/tests/nv_ingest/extraction_workflows/pptx/test_pptx_helper.py +++ b/tests/nv_ingest/extraction_workflows/pptx/test_pptx_helper.py @@ -1,8 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. # All rights reserved. # SPDX-License-Identifier: Apache-2.0 - - +import json from io import BytesIO from textwrap import dedent @@ -220,6 +219,7 @@ def test_pptx(pptx_stream_with_text, document_df): extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], ) @@ -267,6 +267,7 @@ def test_pptx_with_multiple_runs_in_title(pptx_stream_with_multiple_runs_in_titl extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], ) @@ -299,6 +300,7 @@ def test_pptx_text_depth_presentation(pptx_stream_with_text, document_df): extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], text_depth="document", ) @@ -343,6 +345,7 @@ def test_pptx_text_depth_shape(pptx_stream_with_text, document_df): extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], text_depth="block", ) @@ -397,6 +400,7 @@ def test_pptx_text_depth_para_run(pptx_stream_with_text, document_df, text_depth extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], text_depth=text_depth, ) @@ -441,6 +445,7 @@ def test_pptx_bullet(pptx_stream_with_bullet, document_df): extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], ) @@ -473,6 +478,7 @@ def test_pptx_group(pptx_stream_with_group, document_df): extract_text=True, extract_images=False, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], ) @@ -502,6 +508,7 @@ def test_pptx_table(pptx_stream_with_table, document_df): extract_text=True, extract_images=False, extract_tables=True, + extract_charts=False, row_data=document_df.iloc[0], ) @@ -524,7 +531,7 @@ def test_pptx_table(pptx_stream_with_table, document_df): | Baz | Qux | """ ) - assert extracted_data[0][1]["content"].rstrip() == expected_content.rstrip() + assert extracted_data[0][1]["table_metadata"]["table_content"].rstrip() == expected_content.rstrip() def test_pptx_image(pptx_stream_with_image, document_df): @@ -533,14 +540,17 @@ def test_pptx_image(pptx_stream_with_image, document_df): extract_text=True, extract_images=True, extract_tables=False, + extract_charts=False, row_data=document_df.iloc[0], ) assert isinstance(extracted_data, list) assert len(extracted_data) == 2 assert len(extracted_data[0]) == 3 - assert extracted_data[0][0] == "image" + + assert extracted_data[0][0] == "text" assert extracted_data[0][1]["source_metadata"]["source_id"] == "source1" assert isinstance(extracted_data[0][2], str) - assert extracted_data[0][1]["content"][:10] == "iVBORw0KGg" # PNG format header + assert extracted_data[1][0] == "image" + assert extracted_data[1][1]["content"][:10] == "iVBORw0KGg" # PNG format header