Skip to content

Commit

Permalink
fix: validation and typechecks in TF post processing and OTSL to HTML…
Browse files Browse the repository at this point in the history
… conversion function (#18)

* - Improved TF bbox post-processing step that corrected bboxes overlap, now should produce more correct bboxes
- Bbox overlap correction step made optional, as an extra parameter in multi_table_predict
- Corrected otsl_to_html to be fault tolerant to empty rs_list

Signed-off-by: Maxim Lysak <[email protected]>

* Safety check for outputs_class, outputs_coord for proper torch tensor before sending to device

Signed-off-by: Maxim Lysak <[email protected]>

* Proper fix for empty outputs from bbox decoder

---------

Signed-off-by: Maxim Lysak <[email protected]>
Co-authored-by: Maxim Lysak <[email protected]>
Co-authored-by: Christoph Auer <[email protected]>
  • Loading branch information
3 people authored Sep 5, 2024
1 parent cb0b2d4 commit d607914
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#
import json
import logging
import math
import statistics

import docling_ibm_models.tableformer.settings as s
Expand Down Expand Up @@ -403,45 +404,63 @@ def correct_overlap(box1, box2):
# Push horizontally
if x1_min < x2_min:
# Move box1 to the left and box2 to the right
box1["bbox"][2] -= overlap_x
box2["bbox"][0] += overlap_x
box1["bbox"][2] -= math.ceil(overlap_x / 2) + 2
box2["bbox"][0] += math.floor(overlap_x / 2)
else:
# Move box2 to the left and box1 to the right
box2["bbox"][2] -= overlap_x
box1["bbox"][0] += overlap_x
box2["bbox"][2] -= math.ceil(overlap_x / 2) + 2
box1["bbox"][0] += math.floor(overlap_x / 2)
else:
# Push vertically
if y1_min < y2_min:
# Move box1 up and box2 down
box1["bbox"][3] -= overlap_y
box2["bbox"][1] += overlap_y
box1["bbox"][3] -= math.ceil(overlap_y / 2) + 2
box2["bbox"][1] += math.floor(overlap_y / 2)
else:
# Move box2 up and box1 down
box2["bbox"][3] -= overlap_y
box1["bbox"][1] += overlap_y
box2["bbox"][3] -= math.ceil(overlap_y / 2) + 2
box1["bbox"][1] += math.floor(overlap_y / 2)

# Will flip coordinates in proper order, if previous operations reversed it
box1["bbox"] = [
min(box1["bbox"][0], box1["bbox"][2]),
min(box1["bbox"][1], box1["bbox"][3]),
max(box1["bbox"][0], box1["bbox"][2]),
max(box1["bbox"][1], box1["bbox"][3]),
]
box2["bbox"] = [
min(box2["bbox"][0], box2["bbox"][2]),
min(box2["bbox"][1], box2["bbox"][3]),
max(box2["bbox"][0], box2["bbox"][2]),
max(box2["bbox"][1], box2["bbox"][3]),
]

return box1, box2

def do_boxes_overlap(box1, box2):
# print("{} - {}".format(box1["bbox"], box2["bbox"]))
# Extract coordinates from the bounding boxes
x1_min, y1_min, x1_max, y1_max = box1["bbox"]
x2_min, y2_min, x2_max, y2_max = box2["bbox"]
# Check if one box is to the left of the other
if x1_max < x2_min or x2_max < x1_min:
B1 = box1["bbox"]
B2 = box2["bbox"]
if (
(B1[0] >= B2[2])
or (B1[2] <= B2[0])
or (B1[3] <= B2[1])
or (B1[1] >= B2[3])
):
return False
# Check if one box is above the other
if y1_max < y2_min or y2_max < y1_min:
return False
return True
else:
return True

def find_overlapping_pairs_indexes(bboxes):
overlapping_indexes = []
# Compare each box with every other box (combinations)
for i in range(len(bboxes)):
for j in range(i + 1, len(bboxes)):
if do_boxes_overlap(bboxes[i], bboxes[j]):
bboxes[i], bboxes[j] = correct_overlap(bboxes[i], bboxes[j])
if i != j:
if bboxes[i] != bboxes[j]:
if do_boxes_overlap(bboxes[i], bboxes[j]):
bboxes[i], bboxes[j] = correct_overlap(
bboxes[i], bboxes[j]
)

return overlapping_indexes, bboxes

Expand Down Expand Up @@ -1144,7 +1163,7 @@ def _clear_pdf_cells(self, pdf_cells):
new_pdf_cells.append(pdf_cells[i])
return new_pdf_cells

def process(self, matching_details):
def process(self, matching_details, correct_overlapping_cells=False):
r"""
Do post processing, see details in the comments below
Expand Down Expand Up @@ -1348,9 +1367,10 @@ def process(self, matching_details):
table_cells_wo = po2
max_cell_id = po3

# As the last step - correct cell bboxes in a way that they don't overlap:
if len(table_cells_wo) <= 300: # For performance reasons
table_cells_wo = self._find_overlapping(table_cells_wo)
if correct_overlapping_cells:
# As the last step - correct cell bboxes in a way that they don't overlap:
if len(table_cells_wo) <= 300: # For performance reasons
table_cells_wo = self._find_overlapping(table_cells_wo)

self._log().debug("*** final_matches_wo")
self._log().debug(final_matches_wo)
Expand Down
26 changes: 21 additions & 5 deletions docling_ibm_models/tableformer/data_management/tf_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,8 +523,9 @@ def resize_img(self, image, width=None, height=None, inter=cv2.INTER_AREA):
# return the resized image
return resized, sf

def multi_table_predict(self, iocr_page, table_bboxes, do_matching=True):
# def multi_table_predict(self, iocr_page, page_image, table_bboxes):
def multi_table_predict(
self, iocr_page, table_bboxes, do_matching=True, correct_overlapping_cells=False
):
multi_tf_output = []
page_image = iocr_page["image"]

Expand All @@ -546,7 +547,12 @@ def multi_table_predict(self, iocr_page, table_bboxes, do_matching=True):
# Predict
if do_matching:
tf_responses, predict_details = self.predict(
iocr_page, table_bbox, table_image, scale_factor, None
iocr_page,
table_bbox,
table_image,
scale_factor,
None,
correct_overlapping_cells,
)
else:
tf_responses, predict_details = self.predict_dummy(
Expand Down Expand Up @@ -733,7 +739,13 @@ def predict_dummy(
return tf_output, matching_details

def predict(
self, iocr_page, table_bbox, table_image, scale_factor, eval_res_preds=None
self,
iocr_page,
table_bbox,
table_image,
scale_factor,
eval_res_preds=None,
correct_overlapping_cells=False,
):
r"""
Predict the table out of an image in memory
Expand All @@ -744,6 +756,8 @@ def predict(
Docling provided table data
eval_res_preds : dict
Ready predictions provided by the evaluation results
correct_overlapping_cells : boolean
Enables or disables last post-processing step, that fixes cell bboxes to remove overlap
Returns
-------
Expand Down Expand Up @@ -834,7 +848,9 @@ def predict(
): # There are at least some pdf cells to match with
if self.enable_post_process:
AggProfiler().begin("post_process", self._prof)
matching_details = self._post_processor.process(matching_details)
matching_details = self._post_processor.process(
matching_details, correct_overlapping_cells
)
AggProfiler().end("post_process", self._prof)

# Generate the expected Docling responses
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,12 @@ def inference(self, encoder_out, tag_H):
predictions_classes.append(self._class_embed(h))
if len(predictions_bboxes) > 0:
predictions_bboxes = torch.stack([x[0] for x in predictions_bboxes])
else:
predictions_bboxes = torch.empty(0)

if len(predictions_classes) > 0:
predictions_classes = torch.stack([x[0] for x in predictions_classes])
else:
predictions_classes = torch.empty(0)

return predictions_classes, predictions_bboxes
3 changes: 3 additions & 0 deletions docling_ibm_models/tableformer/otsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def otsl_check_right(rs_split, x, y):


def otsl_to_html(rs_list, logdebug):
if len(rs_list) == 0:
return []

if rs_list[0] not in ["fcel", "ched", "rhed", "srow", "ecel"]:
# Most likely already HTML...
return rs_list
Expand Down

0 comments on commit d607914

Please sign in to comment.