Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Yolox-table-structure to extract tables as markdown #444

Merged
merged 16 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ services:
capabilities: [gpu]
runtime: nvidia

yolox-table-structure:
image: ${YOLOX_TABLE_STRUCTURE_IMAGE:-set-image-name-in-dot-env}:${YOLOX_TABLE_STRUCTURE_TAG:-set-image-tag-in-dot-env}
ports:
- "8006:8000"
- "8007:8001"
- "8008:8002"
user: root
environment:
- NIM_HTTP_API_PORT=8000
- NIM_TRITON_LOG_VERBOSE=1
- NGC_API_KEY=${STAGING_NIM_NGC_API_KEY}
- CUDA_VISIBLE_DEVICES=0
deploy:
resources:
reservations:
devices:
- driver: nvidia
device_ids: ["1"]
capabilities: [gpu]
runtime: nvidia
profiles:
- yolox-table-structure

paddle:
image: ${PADDLE_IMAGE:-nvcr.io/nvidia/nemo-microservices/paddleocr}:${PADDLE_TAG:-1.0.0}
shm_size: 2gb
Expand Down Expand Up @@ -155,6 +178,9 @@ services:
- YOLOX_GRAPHIC_ELEMENTS_GRPC_ENDPOINT=yolox-graphic-elements:8001
- YOLOX_GRAPHIC_ELEMENTS_HTTP_ENDPOINT=http://yolox-graphic-elements:8000/v1/infer
- YOLOX_GRAPHIC_ELEMENTS_INFER_PROTOCOL=grpc
- YOLOX_TABLE_STRUCTURE_GRPC_ENDPOINT=yolox-table-structure:8001
- YOLOX_TABLE_STRUCTURE_HTTP_ENDPOINT=http://yolox-table-structure:8000/v1/infer
- YOLOX_TABLE_STRUCTURE_INFER_PROTOCOL=grpc
- VLM_CAPTION_ENDPOINT=https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct/chat/completions
- VLM_CAPTION_MODEL_NAME=meta/llama-3.2-11b-vision-instruct
healthcheck:
Expand Down
16 changes: 10 additions & 6 deletions src/nv_ingest/schemas/table_extractor_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class TableExtractorConfigSchema(BaseModel):

auth_token: Optional[str] = None

yolox_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
yolox_infer_protocol: str = ""

paddle_endpoints: Tuple[Optional[str], Optional[str]] = (None, None)
paddle_infer_protocol: str = ""

Expand Down Expand Up @@ -78,14 +81,15 @@ def clean_service(service):
return None
return service

grpc_service, http_service = values.get("paddle_endpoints", (None, None))
grpc_service = clean_service(grpc_service)
http_service = clean_service(http_service)
for endpoint_name in ["yolox_endpoints", "paddle_endpoints"]:
grpc_service, http_service = values.get(endpoint_name, (None, None))
grpc_service = clean_service(grpc_service)
http_service = clean_service(http_service)

if not grpc_service and not http_service:
raise ValueError("Both gRPC and HTTP services cannot be empty for paddle_endpoints.")
if not grpc_service and not http_service:
raise ValueError(f"Both gRPC and HTTP services cannot be empty for {endpoint_name}.")

values["paddle_endpoints"] = (grpc_service, http_service)
values[endpoint_name] = (grpc_service, http_service)

return values

Expand Down
4 changes: 2 additions & 2 deletions src/nv_ingest/stages/nim/chart_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from nv_ingest.schemas.chart_extractor_schema import ChartExtractorSchema
from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
from nv_ingest.util.image_processing.table_and_chart import join_yolox_and_paddle_output
from nv_ingest.util.image_processing.table_and_chart import join_yolox_graphic_elements_and_paddle_output
from nv_ingest.util.image_processing.table_and_chart import process_yolox_graphic_elements
from nv_ingest.util.image_processing.transforms import base64_to_numpy
from nv_ingest.util.nim.helpers import NimClient
Expand Down Expand Up @@ -106,7 +106,7 @@ def _update_metadata(
results = []
for img_str, yolox_res, paddle_res in zip(base64_images, yolox_results, paddle_results):
bounding_boxes, text_predictions = paddle_res
yolox_elements = join_yolox_and_paddle_output(yolox_res, bounding_boxes, text_predictions)
yolox_elements = join_yolox_graphic_elements_and_paddle_output(yolox_res, bounding_boxes, text_predictions)
chart_content = process_yolox_graphic_elements(yolox_elements)
results.append((img_str, chart_content))

Expand Down
155 changes: 104 additions & 51 deletions src/nv_ingest/stages/nim/table_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,28 @@

import functools
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import numpy as np
import pandas as pd
from morpheus.config import Config

from nv_ingest.schemas.metadata_schema import TableFormatEnum
from nv_ingest.schemas.table_extractor_schema import TableExtractorSchema
from nv_ingest.stages.multiprocessing_stage import MultiProcessingBaseStage
from nv_ingest.util.image_processing.table_and_chart import convert_paddle_response_to_psuedo_markdown
from nv_ingest.util.image_processing.transforms import base64_to_numpy
from nv_ingest.util.nim.helpers import NimClient
from nv_ingest.util.nim.helpers import create_inference_client
from nv_ingest.util.nim.helpers import get_version
from nv_ingest.util.nim.paddle import PaddleOCRModelInterface
from nv_ingest.util.nim.yolox import YoloxTableStructureModelInterface
from nv_ingest.util.image_processing.table_and_chart import join_yolox_table_structure_and_paddle_output
from nv_ingest.util.image_processing.table_and_chart import convert_paddle_response_to_psuedo_markdown

logger = logging.getLogger(__name__)

Expand All @@ -31,8 +35,10 @@

def _update_metadata(
base64_images: List[str],
yolox_client: NimClient,
paddle_client: NimClient,
worker_pool_size: int = 8, # Not currently used
enable_yolox: bool = False,
trace_info: Dict = None,
) -> List[Tuple[str, Tuple[Any, Any]]]:
"""
Expand All @@ -48,79 +54,110 @@ def _update_metadata(
logger.debug(f"Running table extraction using protocol {paddle_client.protocol}")

# Initialize the results list in the same order as base64_images.
results: List[Optional[Tuple[str, Tuple[Any, Any]]]] = [None] * len(base64_images)
results: List[Optional[Tuple[str, Tuple[Any, Any, Any]]]] = ["", (None, None, None)] * len(base64_images)

valid_images: List[str] = []
valid_indices: List[int] = []
valid_arrays: List[np.ndarray] = []

_ = worker_pool_size
# Pre-decode image dimensions and filter valid images.
for i, img in enumerate(base64_images):
array = base64_to_numpy(img)
height, width = array.shape[0], array.shape[1]
if width >= PADDLE_MIN_WIDTH and height >= PADDLE_MIN_HEIGHT:
valid_images.append(img)
valid_arrays.append(array)
valid_indices.append(i)
else:
# Image is too small; mark as skipped.
results[i] = (img, (None, None))
results[i] = ("", None, None, None)

if valid_images:
data = {"base64_images": valid_images}
try:
# Call infer once for all valid images. The NimClient will handle batching internally.
paddle_result = paddle_client.infer(
data=data,
model_name="paddle",
if not valid_images:
return results

# Prepare data payloads for both clients.
if enable_yolox:
data_yolox = {"images": valid_arrays}
data_paddle = {"base64_images": valid_images}

_ = worker_pool_size
with ThreadPoolExecutor(max_workers=2) as executor:
if enable_yolox:
future_yolox = executor.submit(
yolox_client.infer,
data=data_yolox,
model_name="yolox",
stage_name="table_data_extraction",
max_batch_size=1 if paddle_client.protocol == "grpc" else 2,
max_batch_size=8,
trace_info=trace_info,
)
future_paddle = executor.submit(
paddle_client.infer,
data=data_paddle,
model_name="paddle",
stage_name="table_data_extraction",
max_batch_size=1 if paddle_client.protocol == "grpc" else 2,
trace_info=trace_info,
)

if not isinstance(paddle_result, list):
raise ValueError(f"Expected a list of tuples, got {type(paddle_result)}")
if len(paddle_result) != len(valid_images):
raise ValueError(f"Expected {len(valid_images)} results, got {len(paddle_result)}")

# Assign each result back to its original position.
for idx, result in enumerate(paddle_result):
original_index = valid_indices[idx]
results[original_index] = (base64_images[original_index], result)
if enable_yolox:
try:
yolox_results = future_yolox.result()
except Exception as e:
logger.error(f"Error calling yolox_client.infer: {e}", exc_info=True)
raise
else:
yolox_results = [None] * len(valid_images)

try:
paddle_results = future_paddle.result()
except Exception as e:
logger.error(f"Error processing images. Error: {e}", exc_info=True)
for i in valid_indices:
results[i] = (base64_images[i], (None, None))
logger.error(f"Error calling yolox_client.infer: {e}", exc_info=True)
raise

return results
# Ensure both clients returned lists of results matching the number of input images.
if not (isinstance(yolox_results, list) and isinstance(paddle_results, list)):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want this to be a full failure of the job? It indicates a lack of agreement between one yolox model and another, but it's not clear that it should cause document processing to fail.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about something like 34d7ce1?

raise ValueError("Expected list results from both yolox_client and paddle_client infer calls.")

if len(yolox_results) != len(valid_arrays):
raise ValueError(f"Expected {len(valid_arrays)} yolox results, got {len(yolox_results)}")
if len(paddle_results) != len(valid_images):
raise ValueError(f"Expected {len(valid_images)} paddle results, got {len(paddle_results)}")

def _create_paddle_client(stage_config) -> NimClient:
"""
Helper to create a NimClient for PaddleOCR, retrieving the paddle version from the endpoint.
"""
# Attempt to obtain PaddleOCR version from the second endpoint
paddle_endpoint = stage_config.paddle_endpoints[1]
try:
paddle_version = get_version(paddle_endpoint)
if not paddle_version:
logger.warning("Failed to obtain PaddleOCR version from the endpoint. Falling back to the latest version.")
paddle_version = None
except Exception:
logger.warning("Failed to get PaddleOCR version after 30 seconds. Falling back to the latest version.")
paddle_version = None
for idx, (yolox_res, paddle_res) in enumerate(zip(yolox_results, paddle_results)):
original_index = valid_indices[idx]
results[original_index] = (base64_images[original_index], yolox_res, paddle_res[0], paddle_res[1])

return results


def _create_clients(
yolox_endpoints: Tuple[str, str],
yolox_protocol: str,
paddle_endpoints: Tuple[str, str],
paddle_protocol: str,
auth_token: str,
) -> Tuple[NimClient, NimClient]:
yolox_model_interface = YoloxTableStructureModelInterface()
paddle_model_interface = PaddleOCRModelInterface()

logger.debug(f"Inference protocols: yolox={yolox_protocol}, paddle={paddle_protocol}")

yolox_client = create_inference_client(
endpoints=yolox_endpoints,
model_interface=yolox_model_interface,
auth_token=auth_token,
infer_protocol=yolox_protocol,
)

paddle_client = create_inference_client(
endpoints=stage_config.paddle_endpoints,
endpoints=paddle_endpoints,
model_interface=paddle_model_interface,
auth_token=stage_config.auth_token,
infer_protocol=stage_config.paddle_infer_protocol,
auth_token=auth_token,
infer_protocol=paddle_protocol,
)

return paddle_client
return yolox_client, paddle_client


def _extract_table_data(
Expand Down Expand Up @@ -157,7 +194,13 @@ def _extract_table_data(
return df, trace_info

stage_config = validated_config.stage_config
paddle_client = _create_paddle_client(stage_config)
yolox_client, paddle_client = _create_clients(
stage_config.yolox_endpoints,
stage_config.yolox_infer_protocol,
stage_config.paddle_endpoints,
stage_config.paddle_infer_protocol,
stage_config.auth_token,
)

try:
# 1) Identify rows that meet criteria (structured, subtype=table, table_metadata != None, content not empty)
Expand Down Expand Up @@ -189,26 +232,34 @@ def meets_criteria(row):
base64_images.append(meta["content"])

# 3) Call our bulk _update_metadata to get all results
table_content_format = (
df.at[valid_indices[0], "metadata"]["table_metadata"].get("table_content_format")
or TableFormatEnum.PSEUDO_MARKDOWN
)
enable_yolox = True if table_content_format in (TableFormatEnum.MARKDOWN,) else False

bulk_results = _update_metadata(
base64_images=base64_images,
yolox_client=yolox_client,
paddle_client=paddle_client,
worker_pool_size=stage_config.workers_per_progress_engine,
enable_yolox=enable_yolox,
trace_info=trace_info,
)

# 4) Write the results (bounding_boxes, text_predictions) back
table_content_format = df.at[valid_indices[0], "metadata"]["table_metadata"].get(
"table_content_format", TableFormatEnum.PSEUDO_MARKDOWN
)

for row_id, idx in enumerate(valid_indices):
# unpack (base64_image, (bounding boxes, text_predictions))
_, (bounding_boxes, text_predictions) = bulk_results[row_id]
# unpack (base64_image, (yolox_predictions, paddle_bounding boxes, paddle_text_predictions))
_, cell_predictions, bounding_boxes, text_predictions = bulk_results[row_id]

if table_content_format == TableFormatEnum.SIMPLE:
table_content = " ".join(text_predictions)
elif table_content_format == TableFormatEnum.PSEUDO_MARKDOWN:
table_content = convert_paddle_response_to_psuedo_markdown(bounding_boxes, text_predictions)
elif table_content_format == TableFormatEnum.MARKDOWN:
table_content = join_yolox_table_structure_and_paddle_output(
cell_predictions, bounding_boxes, text_predictions
)
else:
raise ValueError(f"Unexpected table format: {table_content_format}")

Expand All @@ -219,8 +270,10 @@ def meets_criteria(row):

except Exception:
logger.error("Error occurred while extracting table data.", exc_info=True)
traceback.print_exc()
raise
finally:
yolox_client.close()
paddle_client.close()


Expand Down
Loading
Loading