Skip to content

Commit

Permalink
added error handling to image_to_text as well as updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
arinkulshi-skylight committed Nov 26, 2024
1 parent 44f6cef commit 7b4c014
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 20 deletions.
50 changes: 40 additions & 10 deletions OCR/ocr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,17 @@


def data_uri_to_image(data_uri: str):
    """Decode a Base64 image, optionally wrapped in a data URI, with OpenCV.

    Args:
        data_uri: Either a full data URI (``data:image/png;base64,<payload>``)
            or the bare Base64 payload on its own.

    Returns:
        The image produced by ``cv.imdecode(..., cv.IMREAD_COLOR)``.

    Raises:
        HTTPException: 422 when the payload is not valid Base64 or when the
            decoded bytes are not an image OpenCV can read.
    """
    # Everything after the first comma is the payload. A string without a
    # comma is treated as a bare Base64 payload instead of failing with an
    # unrelated "list index out of range" error.
    _, separator, payload = data_uri.partition(",")
    base64_data = payload if separator else data_uri

    try:
        image_data = base64.b64decode(base64_data)
        image_np = np.frombuffer(image_data, np.uint8)
        image = cv.imdecode(image_np, cv.IMREAD_COLOR)
    except Exception as e:
        raise HTTPException(
            status_code=422,
            detail=f"Failed to decode source image. Ensure the file is a valid Base64 image format. Error: {str(e)}",
        ) from e

    # cv.imdecode reports unreadable image data by returning None rather than
    # raising, so surface that as the same client error instead of letting
    # callers crash on `None.shape`.
    if image is None:
        raise HTTPException(
            status_code=422,
            detail=(
                "Failed to decode source image. Ensure the file is a valid Base64 image format. "
                "Error: decoded data is not a readable image"
            ),
        )
    return image


def image_to_data_uri(image: np.ndarray):
Expand Down Expand Up @@ -106,15 +114,37 @@ async def image_file_to_text(source_image: UploadFile, segmentation_template: Up
return results


@app.post("/image_to_text")
async def image_to_text(
    source_image: str = Form(...), segmentation_template: str = Form(...), labels: str = Form(...)
):
    """Run segmentation and OCR on a Base64-encoded image and template.

    Args:
        source_image: Form field holding the source image, passed to
            ``data_uri_to_image``.
        segmentation_template: Form field holding the segmentation template,
            passed to ``data_uri_to_image``.
        labels: Form field holding the labels as a JSON string.

    Returns:
        The value returned by ``ocr.image_to_text`` for the computed segments.

    Raises:
        HTTPException: 422 for an undecodable image or invalid labels JSON,
            400 when the two images differ in height/width, 504 on
            ``asyncio.TimeoutError``, and 500 for any other failure.
    """
    try:
        source_image_img = data_uri_to_image(source_image)
        segmentation_template_img = data_uri_to_image(segmentation_template)

        # A decoder that yields None (unreadable image bytes) would otherwise
        # fail on `.shape` below and be reported as a 500 server error.
        if source_image_img is None or segmentation_template_img is None:
            raise HTTPException(
                status_code=422,
                detail="Failed to decode source image. Ensure the file is a valid Base64 image format.",
            )

        # Only height and width are compared; the channel axis is ignored.
        if source_image_img.shape[:2] != segmentation_template_img.shape[:2]:
            raise HTTPException(
                status_code=400,
                detail="Dimension mismatch between source image and segmentation template. Both images must have the same width and height.",
            )

        try:
            loaded_json = json.loads(labels)
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=422, detail="Failed to parse labels JSON. Ensure the labels are in valid JSON format."
            ) from e

        segments = segmenter.segment(source_image_img, segmentation_template_img, loaded_json)
        results = ocr.image_to_text(segments)

    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail="The request timed out. Please try again.")
    except HTTPException:
        # Deliberate client/server errors raised above pass through untouched.
        raise
    except Exception as e:
        print(f"Unexpected error occurred: {str(e)}")
        raise HTTPException(status_code=500, detail="An unexpected server error occurred.")
    return results


Expand Down
102 changes: 92 additions & 10 deletions OCR/tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def test_health_check(self):
assert response.json() == {"status": "UP"}

def test_image_file_to_text(self):
# load the files
with (
open(segmentation_template_path, "rb") as segmentation_template_file,
open(source_image_path, "rb") as source_image_file,
Expand All @@ -40,7 +39,6 @@ def test_image_file_to_text(self):
("segmentation_template", segmentation_template_file),
]

# call ocr api
response = client.post(
url="/image_file_to_text", files=files_to_send, data={"labels": json.dumps(label_data)}
)
Expand All @@ -53,36 +51,36 @@ def test_image_file_to_text(self):
assert response_json["nbs_cas_id"][0] == "123555"

def test_image_to_text(self):
    """Posting valid data-URI images and labels returns the expected OCR fields."""
    with (
        open(segmentation_template_path, "rb") as segmentation_template_file,
        open(source_image_path, "rb") as source_image_file,
        open(labels_path, "rb") as labels_file,
    ):
        label_data = json.load(labels_file)

        source_image_base64 = base64.b64encode(source_image_file.read()).decode("ascii")
        segmentation_template_base64 = base64.b64encode(segmentation_template_file.read()).decode("ascii")

        # The endpoint is exercised with full data URIs, not bare Base64.
        source_image_data_uri = f"data:image/png;base64,{source_image_base64}"
        segmentation_template_data_uri = f"data:image/png;base64,{segmentation_template_base64}"

        response = client.post(
            url="/image_to_text",
            data={
                "labels": json.dumps(label_data),
                "source_image": source_image_data_uri,
                "segmentation_template": segmentation_template_data_uri,
            },
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        # Printed before the assertions so a failing run shows the payload.
        print(response.json())
        assert response.status_code == 200

        response_json = response.json()
        assert response_json["nbs_patient_id"][0] == "SIENNA HAMPTON"
        assert response_json["nbs_cas_id"][0] == "123555"

def test_image_to_text_with_padding(self):
# load the files
with (
open(segmentation_template_path, "rb") as segmentation_template_file,
open(source_image_path, "rb") as source_image_file,
Expand Down Expand Up @@ -218,3 +216,87 @@ def test_timeout_error_simulation(self):

assert response.status_code == 504
assert response.json()["detail"] == "The request timed out. Please try again."

def test_image_to_text_invalid_image_format(self):
    """A source image that is not valid Base64 is rejected with a 422."""
    with (
        open(segmentation_template_path, "rb") as template_fh,
        open(labels_path, "r") as labels_fh,
    ):
        template_b64 = base64.b64encode(template_fh.read()).decode("ascii")
        parsed_labels = json.load(labels_fh)

        form_fields = {
            "labels": json.dumps(parsed_labels),
            "source_image": "_is_not_a_valid_base64_encoded_string",
            "segmentation_template": f"data:image/png;base64,{template_b64}",
        }
        response = client.post(
            url="/image_to_text",
            data=form_fields,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        expected_detail = (
            "Failed to decode source image. Ensure the file is a valid Base64 image format. Error: Incorrect padding"
        )
        assert response.status_code == 422
        assert response.json()["detail"] == expected_detail

def test_image_to_text_dimension_mismatch(self):
    """A template whose size differs from the source image yields a 400."""
    with (
        open(source_image_path, "rb") as source_fh,
        open(invalid_dimension_path, "rb") as mismatched_fh,
        open(labels_path, "r") as labels_fh,
    ):
        source_b64 = base64.b64encode(source_fh.read()).decode("ascii")
        mismatched_b64 = base64.b64encode(mismatched_fh.read()).decode("ascii")
        parsed_labels = json.load(labels_fh)

        form_fields = {
            "labels": json.dumps(parsed_labels),
            "source_image": f"data:image/png;base64,{source_b64}",
            "segmentation_template": f"data:image/png;base64,{mismatched_b64}",
        }
        response = client.post(
            url="/image_to_text",
            data=form_fields,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        expected_detail = "Dimension mismatch between source image and segmentation template. Both images must have the same width and height."
        assert response.status_code == 400
        assert response.json()["detail"] == expected_detail

def test_image_to_text_invalid_json_labels(self):
    """Labels that are not parseable JSON are rejected with a 422."""
    with (
        open(source_image_path, "rb") as source_fh,
        open(segmentation_template_path, "rb") as template_fh,
    ):
        source_b64 = base64.b64encode(source_fh.read()).decode("ascii")
        template_b64 = base64.b64encode(template_fh.read()).decode("ascii")

        form_fields = {
            "labels": "{invalid: json}",
            "source_image": f"data:image/png;base64,{source_b64}",
            "segmentation_template": f"data:image/png;base64,{template_b64}",
        }
        response = client.post(
            url="/image_to_text",
            data=form_fields,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )

        expected_detail = "Failed to parse labels JSON. Ensure the labels are in valid JSON format."
        assert response.status_code == 422
        assert response.json()["detail"] == expected_detail

0 comments on commit 7b4c014

Please sign in to comment.