Skip to content

Commit

Permalink
Add exception handler decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
gabriel-piles committed Jan 23, 2025
1 parent 871778f commit cfbc022
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 89 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ sentry-sdk==2.8.0
redis==5.0.7
requests==2.32.3
git+https://github.com/huridocs/queue-processor@1f0294f17074ac6d52b0e48d49a32272f121cd9c
git+https://github.com/huridocs/trainable-entity-extractor@4e79693f42db5abc5d77be1d66cf172181d2590f
git+https://github.com/huridocs/trainable-entity-extractor@2ce855361399006045828af1cbff8e0aa706642f
112 changes: 39 additions & 73 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from os.path import join

import pymongo
from catch_exceptions import catch_exceptions
from fastapi import FastAPI, HTTPException, UploadFile, File
import sys

Expand All @@ -19,7 +20,6 @@
from trainable_entity_extractor.send_logs import send_logs

from config import MONGO_HOST, MONGO_PORT, DATA_PATH
from data.Options import Options


@asynccontextmanager
Expand All @@ -45,6 +45,7 @@ async def lifespan(app: FastAPI):


@app.get("/")
@app.get("/info")
async def info():
config_logger.info("PDF information extraction endpoint")
return sys.version
Expand All @@ -57,101 +58,66 @@ async def error():


@app.post("/xml_to_train/{tenant}/{extraction_id}")
@app.post("/extract_paragraphs_xml/{tenant}/{extraction_id}")
@catch_exceptions
async def to_train_xml_file(tenant, extraction_id, file: UploadFile = File(...)):
filename = '"No file name! Probably an error about the file in the request"'
try:
filename = file.filename
xml_file = XmlFile(
extraction_identifier=ExtractionIdentifier(
run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH
),
to_train=True,
xml_file_name=filename,
)
xml_file.save(file=file.file.read())
return "xml_to_train saved"
except Exception:
config_logger.error(f"Error adding task {filename}", exc_info=1)
raise HTTPException(status_code=422, detail=f"Error adding task {filename}")
filename = file.filename
xml_file = XmlFile(
extraction_identifier=ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH),
to_train=True,
xml_file_name=filename,
)
xml_file.save(file=file.file.read())
return "xml_to_train saved"


@app.post("/xml_to_predict/{tenant}/{extraction_id}")
@catch_exceptions
async def to_predict_xml_file(tenant, extraction_id, file: UploadFile = File(...)):
filename = '"No file name! Probably an error about the file in the request"'
try:
filename = file.filename
xml_file = XmlFile(
extraction_identifier=ExtractionIdentifier(
run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH
),
to_train=False,
xml_file_name=filename,
)
xml_file.save(file=file.file.read())
return "xml_to_train saved"
except Exception:
config_logger.error(f"Error adding task {filename}", exc_info=1)
raise HTTPException(status_code=422, detail=f"Error adding task {filename}")
filename = file.filename
xml_file = XmlFile(
extraction_identifier=ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH),
to_train=False,
xml_file_name=filename,
)
xml_file.save(file=file.file.read())
return "xml_to_train saved"


@app.post("/labeled_data")
@catch_exceptions
async def labeled_data_post(labeled_data: LabeledData):
try:
pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"]
pdf_metadata_extraction_db.labeled_data.insert_one(labeled_data.scale_down_labels().to_dict())
return "labeled data saved"
except Exception:
config_logger.error("Error", exc_info=1)
raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info")
pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"]
pdf_metadata_extraction_db.labeled_data.insert_one(labeled_data.scale_down_labels().to_dict())
return "labeled data saved"


@app.post("/prediction_data")
@catch_exceptions
async def prediction_data_post(prediction_data: PredictionData):
try:
pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"]
pdf_metadata_extraction_db.prediction_data.insert_one(prediction_data.to_dict())
return "prediction data saved"
except Exception:
config_logger.error("Error", exc_info=1)
raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info")
pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"]
pdf_metadata_extraction_db.prediction_data.insert_one(prediction_data.to_dict())
return "prediction data saved"


@app.get("/get_suggestions/{tenant}/{extraction_id}")
@catch_exceptions
async def get_suggestions(tenant: str, extraction_id: str):
try:
pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"]
suggestions_filter = {"tenant": tenant, "id": extraction_id}
suggestions_list: list[str] = list()
pdf_metadata_extraction_db = app.mongodb_client["pdf_metadata_extraction"]
suggestions_filter = {"tenant": tenant, "id": extraction_id}
suggestions_list: list[str] = list()

for document in pdf_metadata_extraction_db.suggestions.find(suggestions_filter):
suggestions_list.append(Suggestion(**document).scale_up().to_output())
for document in pdf_metadata_extraction_db.suggestions.find(suggestions_filter):
suggestions_list.append(Suggestion(**document).scale_up().to_output())

pdf_metadata_extraction_db.suggestions.delete_many(suggestions_filter)
extraction_identifier = ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH)
send_logs(extraction_identifier, f"{len(suggestions_list)} suggestions queried")
pdf_metadata_extraction_db.suggestions.delete_many(suggestions_filter)
extraction_identifier = ExtractionIdentifier(run_name=tenant, extraction_name=extraction_id, output_path=DATA_PATH)
send_logs(extraction_identifier, f"{len(suggestions_list)} suggestions queried")

return json.dumps(suggestions_list)
except Exception:
config_logger.error("Error", exc_info=1)
raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info")
return json.dumps(suggestions_list)


@app.delete("/{tenant}/{extraction_id}")
async def get_suggestions(tenant: str, extraction_id: str):
shutil.rmtree(join(DATA_PATH, tenant, extraction_id), ignore_errors=True)
return True


@app.post("/options")
def save_options(options: Options):
try:
extraction_identifier = ExtractionIdentifier(
run_name=options.tenant, extraction_name=options.extraction_id, output_path=DATA_PATH
)
extraction_identifier.save_options(options.options)
os.utime(extraction_identifier.get_options_path().parent)
config_logger.info(f"Options {options.options[:150]} saved for {extraction_identifier}")
return True
except Exception:
config_logger.error("Error", exc_info=1)
raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info")
15 changes: 15 additions & 0 deletions src/catch_exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from functools import wraps
from fastapi import HTTPException
from trainable_entity_extractor.config import config_logger


def catch_exceptions(func):
@wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except Exception:
config_logger.error("Error see traceback", exc_info=1)
raise HTTPException(status_code=422, detail="An error has occurred. Check graylog for more info")

return wrapper
5 changes: 5 additions & 0 deletions src/data/TaskType.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pydantic import BaseModel


class TaskType(BaseModel):
task: str
25 changes: 21 additions & 4 deletions src/start_queue_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from data.ExtractionTask import ExtractionTask
from data.ResultsMessage import ResultsMessage
from Extractor import Extractor
from data.TaskType import TaskType


def restart_condition(message: dict[str, any]) -> bool:
Expand All @@ -27,14 +28,31 @@ def restart_condition(message: dict[str, any]) -> bool:

def process(message: dict[str, any]) -> dict[str, any] | None:
try:
task = ExtractionTask(**message)
task_type = TaskType(**message)
config_logger.info(f"New task {message}")
except ValidationError:
config_logger.error(f"Not a valid Redis message: {message}")
return None

task_calculated, error_message = Extractor.calculate_task(task)
if task_type.task in [Extractor.CREATE_MODEL_TASK_NAME, Extractor.SUGGESTIONS_TASK_NAME]:
result_message = extraction_task(message)
else:
task = ExtractionTask(**message)
result_message = ResultsMessage(
tenant=task.tenant,
task=task.task,
params=task.params,
success=False,
error_message="Task not found",
)
config_logger.error(f"Task not found: {task_to_string(task_type)}")

return result_message.model_dump()


def extraction_task(message):
task = ExtractionTask(**message)
task_calculated, error_message = Extractor.calculate_task(task)
if task_calculated:
data_url = None

Expand All @@ -58,12 +76,11 @@ def process(message: dict[str, any]) -> dict[str, any] | None:
success=False,
error_message=error_message,
)

extraction_identifier = ExtractionIdentifier(
run_name=task.tenant, extraction_name=task.params.id, metadata=task.params.metadata, output_path=DATA_PATH
)
send_logs(extraction_identifier, f"Result message: {model_results_message.to_string()}")
return model_results_message.model_dump()
return model_results_message


def task_to_string(extraction_task: ExtractionTask):
Expand Down
15 changes: 4 additions & 11 deletions src/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_pdf_to_text(self):
self.assertEqual(len(suggestion.segments_boxes), 2)
self.assertEqual(529, suggestion.segments_boxes[0].left)
self.assertEqual(120, suggestion.segments_boxes[0].top)
self.assertEqual(105, suggestion.segments_boxes[0].width)
self.assertEqual(100, suggestion.segments_boxes[0].width)
self.assertEqual(15, suggestion.segments_boxes[0].height)
self.assertEqual(1, suggestion.segments_boxes[0].page_number)

Expand Down Expand Up @@ -207,14 +207,6 @@ def test_pdf_to_multi_option(self):
}
requests.post(f"{SERVER_URL}/labeled_data", json=labeled_data_json)

options = {
"tenant": tenant,
"extraction_id": extraction_id,
"options": [Option(id="1", label="United Nations").model_dump(), Option(id="2", label="Other").model_dump()],
}

requests.post(f"{SERVER_URL}/options", json=options)

with open(test_xml_path, mode="rb") as stream:
files = {"file": stream}
requests.post(f"{SERVER_URL}/xml_to_predict/{tenant}/{extraction_id}", files=files)
Expand All @@ -230,10 +222,11 @@ def test_pdf_to_multi_option(self):

requests.post(f"{SERVER_URL}/prediction_data", json=predict_data_json)

options = [Option(id="1", label="United Nations"), Option(id="2", label="Other")]
task = ExtractionTask(
tenant=tenant,
task="create_model",
params=Params(id=extraction_id, multi_value=False, metadata={"name": "test"}),
params=Params(id=extraction_id, multi_value=False, metadata={"name": "test"}, options=options),
)

QUEUE.sendMessage(delay=0).message(task.model_dump_json()).execute()
Expand Down Expand Up @@ -265,7 +258,7 @@ def test_pdf_to_multi_option(self):
SegmentBox(
left=164.0,
top=60.0,
width=116.0,
width=109.0,
height=21.0,
page_width=0,
page_height=0,
Expand Down

0 comments on commit cfbc022

Please sign in to comment.