
Commit

Add PDF OCR
gabriel-piles committed Dec 10, 2024
1 parent acc5387 · commit fc89b40
Showing 7 changed files with 75 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -156,4 +156,5 @@ cython_debug/
.DS_Store
.idea
.vscode
data/
data/
/ocr/
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,8 +1,8 @@
git+https://github.com/huridocs/pdf-document-layout-analysis@c6813b052df5ba3df454860b11a8d77630e14edb
git+https://github.com/huridocs/queue-processor@716ddf050c59035583b0852dc0b78a7860ce5c05
graypy==2.1.0
PyYAML==6.0.1
pymongo==4.8.0
httpx==0.27.0
sentry-sdk==2.8.0
git+https://github.com/huridocs/queue-processor@d528b7d37c6d9b8f7ec352eec4a8ebed64ffb248
git+https://github.com/huridocs/pdf-document-layout-analysis@ac8d850dfb294f8b8f828920d5d1ac9189be1e9f

22 changes: 20 additions & 2 deletions src/app.py
@@ -1,5 +1,6 @@
import os
from contextlib import asynccontextmanager
from os.path import join

import pymongo
from fastapi import FastAPI, HTTPException, File, UploadFile
@@ -8,10 +9,10 @@
from sentry_sdk.integrations.asgi import SentryAsgiMiddleware
import sentry_sdk
from starlette.concurrency import run_in_threadpool
from starlette.responses import PlainTextResponse
from starlette.responses import PlainTextResponse, FileResponse

from catch_exceptions import catch_exceptions
from configuration import MONGO_HOST, MONGO_PORT, service_logger
from configuration import MONGO_HOST, MONGO_PORT, service_logger, OCR_OUTPUT
from PdfFile import PdfFile
from get_paragraphs import get_paragraphs
from get_xml import get_xml
@@ -77,3 +78,20 @@ async def get_paragraphs_endpoint(tenant: str, pdf_file_name: str):
@catch_exceptions
async def get_xml_by_name(xml_file_name: str):
return await run_in_threadpool(get_xml, xml_file_name)


@app.post("/upload/{namespace}")
async def upload_pdf(namespace, file: UploadFile = File(...)):
filename = file.filename
pdf_file = PdfFile(namespace)
pdf_file.save(pdf_file_name=filename, file=file.file.read())
return "File uploaded"


@app.get("/processed_pdf/{namespace}/{pdf_file_name}", response_class=FileResponse)
async def processed_pdf(namespace: str, pdf_file_name: str):
return FileResponse(
path=join(OCR_OUTPUT, namespace, pdf_file_name),
media_type="application/pdf",
filename=pdf_file_name,
)
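
A minimal sketch of how a client might exercise the two new endpoints, using the httpx dependency already pinned in requirements.txt. The base URL, namespace, and file name below are assumptions for illustration; the processed PDF only becomes available after the queue worker has run the "ocr" task for that file.

import httpx

BASE_URL = "http://localhost:5051"  # assumed host/port; not defined in this commit


def upload_and_fetch_ocr(pdf_path: str, namespace: str = "demo") -> bytes:
    file_name = pdf_path.rsplit("/", 1)[-1]

    # Send the source PDF to the new /upload/{namespace} endpoint
    with open(pdf_path, "rb") as pdf:
        response = httpx.post(f"{BASE_URL}/upload/{namespace}", files={"file": pdf})
    response.raise_for_status()

    # After the "ocr" queue task has run, download the OCRed PDF
    result = httpx.get(f"{BASE_URL}/processed_pdf/{namespace}/{file_name}")
    result.raise_for_status()
    return result.content
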
2 changes: 2 additions & 0 deletions src/configuration.py
@@ -26,6 +26,8 @@
ROOT_PATH = Path(__file__).parent.parent.absolute()
DATA_PATH = join(ROOT_PATH, "data")

OCR_OUTPUT = Path(DATA_PATH, "ocr_output")

handlers = [logging.StreamHandler()]
if GRAYLOG_IP:
handlers.append(graypy.GELFUDPHandler(GRAYLOG_IP, 12201, localname="pdf_paragraphs_extraction"))
1 change: 1 addition & 0 deletions src/data_model/Params.py
@@ -3,3 +3,4 @@

class Params(BaseModel):
filename: str
language: str = "en"
25 changes: 24 additions & 1 deletion src/extract_segments.py
@@ -1,4 +1,6 @@
from configuration import DOCUMENT_LAYOUT_ANALYSIS_URL, service_logger, USE_FAST
from pathlib import Path

from configuration import DOCUMENT_LAYOUT_ANALYSIS_URL, service_logger, USE_FAST, OCR_OUTPUT
from data_model.SegmentBox import SegmentBox
from PdfFile import PdfFile
from data_model.ExtractionData import ExtractionData
@@ -39,3 +41,24 @@ def extract_segments(task: Task, xml_file_name: str = "") -> ExtractionData:
page_height=0 if not segments else segments[0].page_height,
page_width=0 if not segments else segments[0].page_width,
)


def ocr_pdf(task: Task) -> bool:
pdf_file = PdfFile(task.tenant)
path = pdf_file.get_path(task.params.filename)

if not path.exists():
raise FileNotFoundError(f"No PDF to OCR")

data = {"language": task.params.language}
for i in range(RETRIES):
with open(path, "rb") as stream:
files = {"file": stream}
results = requests.post(f"{DOCUMENT_LAYOUT_ANALYSIS_URL}/ocr", files=files, data=data)

if results and results.status_code == 200:
results_path = Path(OCR_OUTPUT, task.tenant, task.params.filename)
results_path.write_bytes(results.content)
return True

raise RuntimeError(f"Error OCR document: {results.status_code} - {results.text}")
28 changes: 24 additions & 4 deletions src/start_queue_processor.py
@@ -20,7 +20,7 @@
)
from data_model.ResultMessage import ResultMessage
from data_model.Task import Task
from extract_segments import get_xml_name, extract_segments
from extract_segments import get_xml_name, extract_segments, ocr_pdf


def get_failed_results_message(task: Task, message: str) -> ResultMessage:
@@ -33,6 +33,22 @@ def get_failed_results_message(task: Task, message: str) -> ResultMessage:
)


def ocr_pdf_task(task):
ocr_pdf(task)

processed_pdf_url = f"{SERVICE_HOST}:{SERVICE_PORT}/processed_pdf/{task.tenant}/{task.params.filename}"
extraction_message = ResultMessage(
tenant=task.tenant,
task=task.task,
params=task.params,
success=True,
file_url=processed_pdf_url,
)

service_logger.info(f"OCR success: {extraction_message.model_dump_json()}")
return extraction_message.model_dump_json()


def process(message):
try:
task = Task(**message)
@@ -42,15 +58,19 @@ def process(message):

try:
service_logger.info(f"Processing Redis message: {message}")

if task.task == "ocr":
return ocr_pdf_task(task)

return process_task(task).model_dump_json()
except RuntimeError:
extraction_message = get_failed_results_message(task, "Error processing PDF document")
extraction_message = get_failed_results_message(task, "Error processing PDF")
service_logger.error(extraction_message.model_dump_json(), exc_info=True)
except FileNotFoundError:
extraction_message = get_failed_results_message(task, "Error getting the xml from the pdf")
extraction_message = get_failed_results_message(task, "Error FileNotFoundError")
service_logger.error(extraction_message.model_dump_json(), exc_info=True)
except Exception:
extraction_message = get_failed_results_message(task, "Error getting segments")
extraction_message = get_failed_results_message(task, "Error")
service_logger.error(extraction_message.model_dump_json(), exc_info=True)

return extraction_message.model_dump_json()
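
For context, a sketch of the kind of queue message that would take the new "ocr" branch in process(). The exact Task schema is not shown in this diff; the field names below are inferred from how task.tenant, task.task, and task.params are used above, so treat them as assumptions.

ocr_message = {
    "tenant": "my_tenant",  # assumed tenant name; also used as the output namespace
    "task": "ocr",  # routes the message to ocr_pdf_task()
    "params": {
        "filename": "document.pdf",  # PDF previously stored via /upload/{namespace}
        "language": "en",  # optional; Params defaults to "en"
    },
}

# process(ocr_message) would call ocr_pdf(), write the result under
# OCR_OUTPUT/<tenant>/<filename>, and return a ResultMessage whose
# file_url points at /processed_pdf/<tenant>/<filename>.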
