Skip to content

Commit

Permalink
Add ability to group collections of image bounding boxes (#386)
Browse files Browse the repository at this point in the history
  • Loading branch information
edknv authored Feb 27, 2025
1 parent ecb66b6 commit aa015bc
Show file tree
Hide file tree
Showing 9 changed files with 707 additions and 49 deletions.
36 changes: 33 additions & 3 deletions client/src/nv_ingest_client/primitives/tasks/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@

import logging
import os
from typing import Any
from typing import Dict
from typing import Literal
from typing import Optional
from typing import get_args

from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import field_validator
from pydantic import model_validator

from .task_base import Task

Expand Down Expand Up @@ -82,12 +86,16 @@
"pptx": get_args(_Type_Extract_Tables_Method_PPTX),
}

_Type_Extract_Images_Method = Literal["simple", "group"]


class ExtractTaskSchema(BaseModel):
document_type: str
extract_method: str = None # Initially allow None to set a smart default
extract_text: bool = True
extract_images: bool = True
extract_images_method: str = "group"
extract_images_params: Optional[Dict[str, Any]] = None
extract_tables: bool = True
extract_tables_method: str = "yolox"
extract_charts: Optional[bool] = None # Initially allow None to set a smart default
Expand Down Expand Up @@ -149,8 +157,15 @@ def extract_tables_method_must_be_valid(cls, v, values, **kwargs):
raise ValueError(f"extract_method must be one of {valid_methods}")
return v

class Config:
extra = "forbid"
@field_validator("extract_images_method")
def extract_images_method_must_be_valid(cls, v):
    """
    Validate and normalize the image extraction method.

    The value is compared case-insensitively against the options declared in
    `_Type_Extract_Images_Method` and returned in lowercase.

    Raises
    ------
    ValueError
        If the value is not one of the supported image extraction methods.
    """
    # `get_args` unpacks the Literal into a tuple of strings. The Literal alias itself is not an
    # iterable of strings, so it cannot be passed to `str.join` directly when building the message.
    valid_methods = get_args(_Type_Extract_Images_Method)
    method = v.lower()
    if method not in valid_methods:
        raise ValueError(
            f"Unsupported extract_images_method '{v}'. Supported methods are: {', '.join(valid_methods)}"
        )
    return method

model_config = ConfigDict(extra="forbid")


class ExtractTask(Task):
Expand All @@ -166,6 +181,8 @@ def __init__(
extract_images: bool = False,
extract_tables: bool = False,
extract_charts: Optional[bool] = None,
extract_images_method: _Type_Extract_Images_Method = "group",
extract_images_params: Optional[Dict[str, Any]] = None,
extract_tables_method: _Type_Extract_Tables_Method_PDF = "yolox",
extract_infographics: bool = False,
text_depth: str = "document",
Expand All @@ -180,6 +197,8 @@ def __init__(
self._extract_images = extract_images
self._extract_method = extract_method
self._extract_tables = extract_tables
self._extract_images_method = extract_images_method
self._extract_images_params = extract_images_params
self._extract_tables_method = extract_tables_method
# `extract_charts` is initially set to None for backward compatibility.
# {extract_tables: true, extract_charts: None} or {extract_tables: true, extract_charts: true} enables both
Expand All @@ -204,9 +223,13 @@ def __str__(self) -> str:
info += f" extract tables: {self._extract_tables}\n"
info += f" extract charts: {self._extract_charts}\n"
info += f" extract infographics: {self._extract_infographics}\n"
info += f" extract images method: {self._extract_images_method}\n"
info += f" extract tables method: {self._extract_tables_method}\n"
info += f" text depth: {self._text_depth}\n"
info += f" paddle_output_format: {self._paddle_output_format}\n"

if self._extract_images_params:
info += f" extract images params: {self._extract_images_params}\n"
return info

def to_dict(self) -> Dict:
Expand All @@ -217,12 +240,19 @@ def to_dict(self) -> Dict:
"extract_text": self._extract_text,
"extract_images": self._extract_images,
"extract_tables": self._extract_tables,
"extract_images_method": self._extract_images_method,
"extract_tables_method": self._extract_tables_method,
"extract_charts": self._extract_charts,
"extract_infographics": self._extract_infographics,
"text_depth": self._text_depth,
"paddle_output_format": self._paddle_output_format,
}
if self._extract_images_params:
extract_params.update(
{
"extract_images_params": self._extract_images_params,
}
)

task_properties = {
"method": self._extract_method,
Expand Down
Binary file added data/test-page-form.pdf
Binary file not shown.
Binary file added data/test-shapes.pdf
Binary file not shown.
63 changes: 32 additions & 31 deletions src/nv_ingest/extraction_workflows/pdf/pdfium_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,14 @@
from nv_ingest.util.nim.yolox import YOLOX_PAGE_IMAGE_PREPROC_HEIGHT
from nv_ingest.util.nim.yolox import YOLOX_PAGE_IMAGE_PREPROC_WIDTH
from nv_ingest.util.nim.yolox import get_yolox_model_name
from nv_ingest.util.pdf.metadata_aggregators import Base64Image
from nv_ingest.util.pdf.metadata_aggregators import CroppedImageWithContent
from nv_ingest.util.pdf.metadata_aggregators import construct_image_metadata_from_pdf_image
from nv_ingest.util.pdf.metadata_aggregators import construct_page_element_metadata
from nv_ingest.util.pdf.metadata_aggregators import construct_text_metadata
from nv_ingest.util.pdf.metadata_aggregators import extract_pdf_metadata
from nv_ingest.util.pdf.pdfium import PDFIUM_PAGEOBJ_MAPPING
from nv_ingest.util.pdf.pdfium import pdfium_pages_to_numpy
from nv_ingest.util.pdf.pdfium import pdfium_try_get_bitmap_as_numpy
from nv_ingest.util.pdf.pdfium import extract_nested_simple_images_from_pdfium_page
from nv_ingest.util.pdf.pdfium import extract_image_like_objects_from_pdfium_page

YOLOX_MAX_BATCH_SIZE = 8

Expand Down Expand Up @@ -191,6 +190,9 @@ def extract_page_element_images(
w1, h1, w2, h2 = bbox

cropped = crop_image(original_image, (int(w1), int(h1), int(w2), int(h2)))
if cropped is None:
continue

base64_img = numpy_to_base64(cropped)

bbox_in_orig_coord = (
Expand Down Expand Up @@ -223,47 +225,39 @@ def _extract_page_text(page) -> str:


def _extract_page_images(
extract_images_method: str,
page,
page_idx: int,
page_width: float,
page_height: float,
page_count: int,
source_metadata: dict,
base_unified_metadata: dict,
**extract_images_params,
) -> list:
"""
Always extract images from the given page and return a list of image metadata items.
The caller decides whether to call this based on a flag.
"""
extracted_images = []
for obj in page.get_objects():
obj_type = PDFIUM_PAGEOBJ_MAPPING.get(obj.type, "UNKNOWN")
if obj_type == "IMAGE":
try:
image_numpy = pdfium_try_get_bitmap_as_numpy(obj)
image_base64 = numpy_to_base64(image_numpy)
image_bbox = obj.get_pos()
image_size = obj.get_size()

image_data = Base64Image(
image=image_base64,
bbox=image_bbox,
width=image_size[0],
height=image_size[1],
max_width=page_width,
max_height=page_height,
)
if extract_images_method == "simple":
extracted_image_data = extract_nested_simple_images_from_pdfium_page(page)
else: # if extract_images_method == "group"
extracted_image_data = extract_image_like_objects_from_pdfium_page(page, merge=True, **extract_images_params)

image_meta = construct_image_metadata_from_pdf_image(
image_data,
page_idx,
page_count,
source_metadata,
base_unified_metadata,
)
extracted_images.append(image_meta)
except Exception as e:
logger.error(f"Unhandled error extracting image on page {page_idx}: {e}")
extracted_images = []
for image_data in extracted_image_data:
try:
image_meta = construct_image_metadata_from_pdf_image(
image_data,
page_idx,
page_count,
source_metadata,
base_unified_metadata,
)
extracted_images.append(image_meta)
except Exception as e:
logger.error(f"Unhandled error extracting image on page {page_idx}: {e}")
# continue extracting other images

return extracted_images

Expand Down Expand Up @@ -332,6 +326,8 @@ def pdfium_extractor(
extract_infographics = kwargs.get("extract_infographics", False)
paddle_output_format = kwargs.get("paddle_output_format", "pseudo_markdown")
paddle_output_format = TableFormatEnum[paddle_output_format.upper()]
extract_images_method = kwargs.get("extract_images_method", "group")
extract_images_params = kwargs.get("extract_images_params") or {}

# Basic config
metadata_col = kwargs.get("metadata_column", "metadata")
Expand Down Expand Up @@ -412,13 +408,15 @@ def pdfium_extractor(
# If we want images, extract images now.
if extract_images:
image_data = _extract_page_images(
extract_images_method,
page,
page_idx,
page_width,
page_height,
page_count,
source_metadata,
base_unified_metadata,
**extract_images_params,
)
extracted_data.extend(image_data)

Expand Down Expand Up @@ -491,4 +489,7 @@ def pdfium_extractor(
)
extracted_data.append(doc_text_meta)

doc.close()

logger.debug(f"Extracted {len(extracted_data)} items from PDF.")
return extracted_data
Loading

0 comments on commit aa015bc

Please sign in to comment.