Commit
Introduce a "Convert" endpoint for directly handling HTTP fileupload …
Browse files Browse the repository at this point in the history
…documents (#313)

Co-authored-by: Julio Perez <[email protected]>
Co-authored-by: Devin Robison <[email protected]>
Co-authored-by: tmonty12 <[email protected]>
5 people authored Jan 16, 2025
1 parent cf8c5f5 commit 0e44b01
Showing 8 changed files with 406 additions and 4 deletions.
1 change: 1 addition & 0 deletions client/src/nv_ingest_client/primitives/tasks/caption.py
@@ -24,6 +24,7 @@ class CaptionTaskSchema(BaseModel):
model_name: Optional[str] = None

model_config = ConfigDict(extra="forbid")
model_config["protected_namespaces"] = ()


class CaptionTask(Task):
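Note: this one-line addition clears Pydantic v2's protected namespaces so that the existing model_name field no longer triggers the "conflict with protected namespace model_" UserWarning. A minimal standalone sketch of the same pattern (ExampleSchema is a hypothetical stand-in, not a class from the repository):

from typing import Optional

from pydantic import BaseModel, ConfigDict


class ExampleSchema(BaseModel):
    # Field names starting with "model_" collide with Pydantic v2's reserved
    # "model_" namespace and emit a UserWarning by default.
    model_name: Optional[str] = None

    model_config = ConfigDict(extra="forbid")
    # Clearing protected_namespaces suppresses that warning, which is what the
    # change to CaptionTaskSchema above does.
    model_config["protected_namespaces"] = ()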
171 changes: 167 additions & 4 deletions src/nv_ingest/api/v1/ingest.py
@@ -10,27 +10,33 @@

# pylint: skip-file

from io import BytesIO
from typing import Annotated, Dict, List
import base64
import json
import logging
import time
import traceback
from io import BytesIO
from typing import Annotated
import uuid

from fastapi import APIRouter, Request, Response
from fastapi import Depends
from fastapi import File
from fastapi import File, UploadFile, Form
from fastapi import HTTPException
from fastapi import UploadFile
from fastapi.responses import JSONResponse
from nv_ingest_client.primitives.jobs.job_spec import JobSpec
from nv_ingest_client.primitives.tasks.extract import ExtractTask
from opentelemetry import trace
from redis import RedisError

from nv_ingest.util.converters.formats import ingest_json_results_to_blob

from nv_ingest.schemas.message_wrapper_schema import MessageWrapper
from nv_ingest.schemas.processing_job_schema import ConversionStatus, ProcessingJob
from nv_ingest.service.impl.ingest.redis_ingest_service import RedisIngestService
from nv_ingest.service.meta.ingest.ingest_service_meta import IngestServiceMeta
from nv_ingest_client.primitives.tasks.table_extraction import TableExtractionTask
from nv_ingest_client.primitives.tasks.chart_extraction import ChartExtractionTask

logger = logging.getLogger("uvicorn")
tracer = trace.get_tracer(__name__)
@@ -184,3 +190,160 @@ async def fetch_job(job_id: str, ingest_service: INGEST_SERVICE_T):
# Catch-all for other exceptions, returning a 500 Internal Server Error
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"Nv-Ingest Internal Server Error: {str(ex)}")


@router.post("/convert")
async def convert_pdf(
ingest_service: INGEST_SERVICE_T,
files: List[UploadFile] = File(...),
job_id: str = Form(...),
extract_text: bool = Form(True),
extract_images: bool = Form(True),
extract_tables: bool = Form(True),
extract_charts: bool = Form(False),
) -> Dict[str, str]:
try:

if job_id is None:
job_id = str(uuid.uuid4())
logger.debug(f"JobId is None, Created JobId: {job_id}")

submitted_jobs: List[ProcessingJob] = []
for file in files:
file_stream = BytesIO(file.file.read())
doc_content = base64.b64encode(file_stream.read()).decode("utf-8")

try:
content_type = file.content_type.split("/")[1]
except Exception:
err_message = f"Unsupported content_type: {file.content_type}"
logger.error(err_message)
raise HTTPException(status_code=500, detail=err_message)

job_spec = JobSpec(
document_type=content_type,
payload=doc_content,
source_id=file.filename,
source_name=file.filename,
extended_options={
"tracing_options": {
"trace": True,
"ts_send": time.time_ns(),
}
},
)

extract_task = ExtractTask(
document_type=content_type,
extract_text=extract_text,
extract_images=extract_images,
extract_tables=extract_tables,
extract_charts=extract_charts,
)

job_spec.add_task(extract_task)

# Conditionally add tasks as needed.
if extract_tables:
table_data_extract = TableExtractionTask()
job_spec.add_task(table_data_extract)

if extract_charts:
chart_data_extract = ChartExtractionTask()
job_spec.add_task(chart_data_extract)

submitted_job_id = await ingest_service.submit_job(
MessageWrapper(payload=json.dumps(job_spec.to_dict())), job_id
)

processing_job = ProcessingJob(
submitted_job_id=submitted_job_id,
filename=file.filename,
status=ConversionStatus.IN_PROGRESS,
)

submitted_jobs.append(processing_job)

await ingest_service.set_processing_cache(job_id, submitted_jobs)

logger.debug(f"Submitted: {len(submitted_jobs)} documents of type: '{content_type}' for processing")

return {
"task_id": job_id,
"status": "processing",
"status_url": f"/status/{job_id}",
}

except Exception as e:
logger.error(f"Error starting conversion: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
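For orientation, a minimal client-side sketch of calling the new endpoint with the requests library. The base URL and the /v1 prefix are assumptions about how this router is mounted, not something defined by this commit; adjust them to your deployment.

import requests

BASE_URL = "http://localhost:7670/v1"  # assumed host, port, and prefix

with open("example.pdf", "rb") as pdf:
    response = requests.post(
        f"{BASE_URL}/convert",
        # "files" matches the List[UploadFile] parameter name above.
        files=[("files", ("example.pdf", pdf, "application/pdf"))],
        data={
            "job_id": "my-conversion-job",
            "extract_text": True,
            "extract_images": True,
            "extract_tables": True,
            "extract_charts": False,
        },
    )

response.raise_for_status()
task = response.json()
print(task["task_id"], task["status_url"])  # e.g. my-conversion-job /status/my-conversion-job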


@router.get("/status/{job_id}")
async def get_status(ingest_service: INGEST_SERVICE_T, job_id: str):
t_start = time.time()
try:
processing_jobs = await ingest_service.get_processing_cache(job_id)
except Exception as e:
logger.error(f"Error getting status: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))

updated_cache: List[ProcessingJob] = []
num_ready_docs = 0

for processing_job in processing_jobs:
logger.debug(f"submitted_job_id: {processing_job.submitted_job_id} - Status: {processing_job.status}")

if processing_job.status == ConversionStatus.IN_PROGRESS:
# Attempt to fetch the job from the ingest service
try:
job_response = await ingest_service.fetch_job(processing_job.submitted_job_id)

job_response = json.dumps(job_response)

# Convert JSON into pseudo markdown format
blob_response = ingest_json_results_to_blob(job_response)

processing_job.raw_result = job_response
processing_job.content = blob_response
processing_job.status = ConversionStatus.SUCCESS
num_ready_docs = num_ready_docs + 1
updated_cache.append(processing_job)

except TimeoutError:
logger.error(f"TimeoutError getting result for job_id: {processing_job.submitted_job_id}")
updated_cache.append(processing_job)
continue
except RedisError:
logger.error(f"RedisError getting result for job_id: {processing_job.submitted_job_id}")
updated_cache.append(processing_job)
continue
else:
logger.debug(f"{processing_job.submitted_job_id} has already finished successfully ....")
num_ready_docs = num_ready_docs + 1
updated_cache.append(processing_job)

await ingest_service.set_processing_cache(job_id, updated_cache)

logger.debug(f"{num_ready_docs}/{len(updated_cache)} complete")
if num_ready_docs == len(updated_cache):
results = []
raw_results = []
for result in updated_cache:
results.append(
{
"filename": result.filename,
"status": "success",
"content": result.content,
}
)
raw_results.append(result.raw_result)

return JSONResponse(
content={"status": "completed", "result": results},
status_code=200,
)
else:
# Not yet ready ...
logger.debug(f"/status/{job_id} endpoint execution time: {time.time() - t_start}")
raise HTTPException(status_code=202, detail="Job is not ready yet. Retry later.")
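A caller would then poll the status endpoint until it stops returning 202. A hedged sketch, reusing the assumed base URL from the /convert example above:

import time

import requests

BASE_URL = "http://localhost:7670/v1"  # assumed, as above


def wait_for_conversion(job_id: str, poll_interval: float = 2.0, timeout: float = 300.0) -> list:
    """Poll /status/{job_id}; 200 means completed, 202 means not ready yet."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        resp = requests.get(f"{BASE_URL}/status/{job_id}")
        if resp.status_code == 200:
            # Matches the handler's {"status": "completed", "result": [...]} payload.
            return resp.json()["result"]
        if resp.status_code != 202:
            resp.raise_for_status()
        time.sleep(poll_interval)
    raise TimeoutError(f"Conversion {job_id} did not finish within {timeout}s")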
31 changes: 31 additions & 0 deletions src/nv_ingest/schemas/processing_job_schema.py
@@ -0,0 +1,31 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from pydantic import BaseModel, ConfigDict
from enum import Enum


class ConversionStatus(str, Enum):
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"

model_config = ConfigDict(extra="forbid")


class ProcessingJob(BaseModel):
submitted_job_id: str
filename: str
raw_result: str = ""
content: str = ""
status: ConversionStatus
error: str | None = None

model_config = ConfigDict(extra="forbid")
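For context, a short sketch of how these models round-trip through JSON, which is how the Redis-backed cache below stores them; the id and filename are invented values:

import json

from nv_ingest.schemas.processing_job_schema import ConversionStatus, ProcessingJob

job = ProcessingJob(
    submitted_job_id="example-submitted-id",  # hypothetical value
    filename="example.pdf",
    status=ConversionStatus.IN_PROGRESS,
)

payload = json.dumps([job.dict()])                            # what set_processing_cache writes
restored = [ProcessingJob(**j) for j in json.loads(payload)]  # what get_processing_cache reads
assert restored[0].status is ConversionStatus.IN_PROGRESS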
25 changes: 25 additions & 0 deletions src/nv_ingest/service/impl/ingest/redis_ingest_service.py
@@ -14,10 +14,12 @@
from json import JSONDecodeError
from typing import Any

from typing import List
from nv_ingest.schemas import validate_ingest_job
from nv_ingest.schemas.message_wrapper_schema import MessageWrapper
from nv_ingest.service.meta.ingest.ingest_service_meta import IngestServiceMeta
from nv_ingest.util.message_brokers.redis.redis_client import RedisClient
from nv_ingest.schemas.processing_job_schema import ProcessingJob

logger = logging.getLogger("uvicorn")

@@ -46,6 +48,8 @@ def __init__(self, redis_hostname: str, redis_port: int, redis_task_queue: str):
self._redis_hostname = redis_hostname
self._redis_port = redis_port
self._redis_task_queue = redis_task_queue
self._cache_prefix = "processing_cache:"
self._bulk_vdb_cache_prefix = "vdb_bulk_upload_cache:"

self._ingest_client = RedisClient(
host=self._redis_hostname, port=self._redis_port, max_pool_size=self._concurrency_level
@@ -89,3 +93,24 @@ async def fetch_job(self, job_id: str) -> Any:
raise TimeoutError()

return message

async def set_processing_cache(self, job_id: str, jobs_data: List[ProcessingJob]) -> None:
"""Store processing jobs data using simple key-value"""
cache_key = f"{self._cache_prefix}{job_id}"
try:
self._ingest_client.get_client().set(cache_key, json.dumps([job.dict() for job in jobs_data]), ex=3600)
except Exception as err:
logger.error(f"Error setting cache for {cache_key}: {err}")
raise

async def get_processing_cache(self, job_id: str) -> List[ProcessingJob]:
"""Retrieve processing jobs data using simple key-value"""
cache_key = f"{self._cache_prefix}{job_id}"
try:
data = self._ingest_client.get_client().get(cache_key)
if data is None:
return []
return [ProcessingJob(**job) for job in json.loads(data)]
except Exception as err:
logger.error(f"Error getting cache for {cache_key}: {err}")
raise
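The cache entries are plain string keys with a one-hour TTL. A small sketch of inspecting one directly with redis-py; the hostname, port, and job id are assumptions:

import json

import redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)  # assumed connection details

cache_key = "processing_cache:my-conversion-job"  # cache prefix + job_id, per set_processing_cache
raw = r.get(cache_key)
if raw is not None:
    for job in json.loads(raw):
        print(job["filename"], job["status"])
    print("expires in", r.ttl(cache_key), "seconds")  # ex=3600 above gives the one-hour TTL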
10 changes: 10 additions & 0 deletions src/nv_ingest/service/meta/ingest/ingest_service_meta.py
@@ -10,8 +10,10 @@

from abc import ABC
from abc import abstractmethod
from typing import List

from nv_ingest.schemas.message_wrapper_schema import MessageWrapper
from nv_ingest.schemas.processing_job_schema import ProcessingJob


class IngestServiceMeta(ABC):
@@ -22,3 +24,11 @@ async def submit_job(self, job_spec: MessageWrapper, trace_id: str) -> str:
@abstractmethod
async def fetch_job(self, job_id: str):
"""Abstract method for fetching job from ingestion service based on job_id"""

@abstractmethod
async def set_processing_cache(self, job_id: str, jobs_data: List[ProcessingJob]) -> None:
"""Abstract method for setting processing cache"""

@abstractmethod
async def get_processing_cache(self, job_id: str) -> List[ProcessingJob]:
"""Abstract method for getting processing cache"""
70 changes: 70 additions & 0 deletions src/nv_ingest/util/converters/formats.py
@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

# pylint: skip-file

import json


def ingest_json_results_to_blob(result_content):
"""
Parse a JSON string or BytesIO object, combine and sort entries, and create a blob string.
Returns:
str: The generated blob string.
"""
try:
# Load the JSON data
data = json.loads(result_content) if isinstance(result_content, str) else json.loads(result_content)
data = data["data"]

# Smarter sorting: by page, then structured objects by x0, y0
def sorting_key(entry):
page = entry["metadata"]["content_metadata"]["page_number"]
if entry["document_type"] == "structured":
# Use table location's x0 and y0 as secondary keys
x0 = entry["metadata"]["table_metadata"]["table_location"][0]
y0 = entry["metadata"]["table_metadata"]["table_location"][1]
else:
# Non-structured objects are sorted after structured ones
x0 = float("inf")
y0 = float("inf")
return page, x0, y0

data.sort(key=sorting_key)

# Initialize the blob string
blob = []

for entry in data:
document_type = entry.get("document_type", "")

if document_type == "structured":
# Add table content to the blob
blob.append(entry["metadata"]["table_metadata"]["table_content"])
blob.append("\n")

elif document_type == "text":
# Add content to the blob
blob.append(entry["metadata"]["content"])
blob.append("\n")

elif document_type == "image":
# Add image caption to the blob
caption = entry["metadata"]["image_metadata"].get("caption", "")
blob.append(f"image_caption:[{caption}]")
blob.append("\n")

# Join all parts of the blob into a single string
return "".join(blob)

except Exception as e:
print(f"[ERROR] An error occurred while processing JSON content: {e}")
return ""

Large diffs are not rendered by default.

