diff --git a/docling_ibm_models/tableformer/data_management/matching_post_processor.py b/docling_ibm_models/tableformer/data_management/matching_post_processor.py index 39e420f..7d94dd9 100644 --- a/docling_ibm_models/tableformer/data_management/matching_post_processor.py +++ b/docling_ibm_models/tableformer/data_management/matching_post_processor.py @@ -4,6 +4,7 @@ # import json import logging +import math import statistics import docling_ibm_models.tableformer.settings as s @@ -403,45 +404,63 @@ def correct_overlap(box1, box2): # Push horizontally if x1_min < x2_min: # Move box1 to the left and box2 to the right - box1["bbox"][2] -= overlap_x - box2["bbox"][0] += overlap_x + box1["bbox"][2] -= math.ceil(overlap_x / 2) + 2 + box2["bbox"][0] += math.floor(overlap_x / 2) else: # Move box2 to the left and box1 to the right - box2["bbox"][2] -= overlap_x - box1["bbox"][0] += overlap_x + box2["bbox"][2] -= math.ceil(overlap_x / 2) + 2 + box1["bbox"][0] += math.floor(overlap_x / 2) else: # Push vertically if y1_min < y2_min: # Move box1 up and box2 down - box1["bbox"][3] -= overlap_y - box2["bbox"][1] += overlap_y + box1["bbox"][3] -= math.ceil(overlap_y / 2) + 2 + box2["bbox"][1] += math.floor(overlap_y / 2) else: # Move box2 up and box1 down - box2["bbox"][3] -= overlap_y - box1["bbox"][1] += overlap_y + box2["bbox"][3] -= math.ceil(overlap_y / 2) + 2 + box1["bbox"][1] += math.floor(overlap_y / 2) + + # Will flip coordinates in proper order, if previous operations reversed it + box1["bbox"] = [ + min(box1["bbox"][0], box1["bbox"][2]), + min(box1["bbox"][1], box1["bbox"][3]), + max(box1["bbox"][0], box1["bbox"][2]), + max(box1["bbox"][1], box1["bbox"][3]), + ] + box2["bbox"] = [ + min(box2["bbox"][0], box2["bbox"][2]), + min(box2["bbox"][1], box2["bbox"][3]), + max(box2["bbox"][0], box2["bbox"][2]), + max(box2["bbox"][1], box2["bbox"][3]), + ] return box1, box2 def do_boxes_overlap(box1, box2): - # print("{} - {}".format(box1["bbox"], box2["bbox"])) - # Extract coordinates from the bounding boxes - x1_min, y1_min, x1_max, y1_max = box1["bbox"] - x2_min, y2_min, x2_max, y2_max = box2["bbox"] - # Check if one box is to the left of the other - if x1_max < x2_min or x2_max < x1_min: + B1 = box1["bbox"] + B2 = box2["bbox"] + if ( + (B1[0] >= B2[2]) + or (B1[2] <= B2[0]) + or (B1[3] <= B2[1]) + or (B1[1] >= B2[3]) + ): return False - # Check if one box is above the other - if y1_max < y2_min or y2_max < y1_min: - return False - return True + else: + return True def find_overlapping_pairs_indexes(bboxes): overlapping_indexes = [] # Compare each box with every other box (combinations) for i in range(len(bboxes)): for j in range(i + 1, len(bboxes)): - if do_boxes_overlap(bboxes[i], bboxes[j]): - bboxes[i], bboxes[j] = correct_overlap(bboxes[i], bboxes[j]) + if i != j: + if bboxes[i] != bboxes[j]: + if do_boxes_overlap(bboxes[i], bboxes[j]): + bboxes[i], bboxes[j] = correct_overlap( + bboxes[i], bboxes[j] + ) return overlapping_indexes, bboxes @@ -1144,7 +1163,7 @@ def _clear_pdf_cells(self, pdf_cells): new_pdf_cells.append(pdf_cells[i]) return new_pdf_cells - def process(self, matching_details): + def process(self, matching_details, correct_overlapping_cells=False): r""" Do post processing, see details in the comments below @@ -1348,9 +1367,10 @@ def process(self, matching_details): table_cells_wo = po2 max_cell_id = po3 - # As the last step - correct cell bboxes in a way that they don't overlap: - if len(table_cells_wo) <= 300: # For performance reasons - table_cells_wo = self._find_overlapping(table_cells_wo) + if correct_overlapping_cells: + # As the last step - correct cell bboxes in a way that they don't overlap: + if len(table_cells_wo) <= 300: # For performance reasons + table_cells_wo = self._find_overlapping(table_cells_wo) self._log().debug("*** final_matches_wo") self._log().debug(final_matches_wo) diff --git a/docling_ibm_models/tableformer/data_management/tf_predictor.py b/docling_ibm_models/tableformer/data_management/tf_predictor.py index 8324d0c..48e28ac 100644 --- a/docling_ibm_models/tableformer/data_management/tf_predictor.py +++ b/docling_ibm_models/tableformer/data_management/tf_predictor.py @@ -523,8 +523,9 @@ def resize_img(self, image, width=None, height=None, inter=cv2.INTER_AREA): # return the resized image return resized, sf - def multi_table_predict(self, iocr_page, table_bboxes, do_matching=True): - # def multi_table_predict(self, iocr_page, page_image, table_bboxes): + def multi_table_predict( + self, iocr_page, table_bboxes, do_matching=True, correct_overlapping_cells=False + ): multi_tf_output = [] page_image = iocr_page["image"] @@ -546,7 +547,12 @@ def multi_table_predict(self, iocr_page, table_bboxes, do_matching=True): # Predict if do_matching: tf_responses, predict_details = self.predict( - iocr_page, table_bbox, table_image, scale_factor, None + iocr_page, + table_bbox, + table_image, + scale_factor, + None, + correct_overlapping_cells, ) else: tf_responses, predict_details = self.predict_dummy( @@ -733,7 +739,13 @@ def predict_dummy( return tf_output, matching_details def predict( - self, iocr_page, table_bbox, table_image, scale_factor, eval_res_preds=None + self, + iocr_page, + table_bbox, + table_image, + scale_factor, + eval_res_preds=None, + correct_overlapping_cells=False, ): r""" Predict the table out of an image in memory @@ -744,6 +756,8 @@ def predict( Docling provided table data eval_res_preds : dict Ready predictions provided by the evaluation results + correct_overlapping_cells : boolean + Enables or disables last post-processing step, that fixes cell bboxes to remove overlap Returns ------- @@ -834,7 +848,9 @@ def predict( ): # There are at least some pdf cells to match with if self.enable_post_process: AggProfiler().begin("post_process", self._prof) - matching_details = self._post_processor.process(matching_details) + matching_details = self._post_processor.process( + matching_details, correct_overlapping_cells + ) AggProfiler().end("post_process", self._prof) # Generate the expected Docling responses diff --git a/docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py b/docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py index 3c055e0..2124d53 100644 --- a/docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +++ b/docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py @@ -157,7 +157,12 @@ def inference(self, encoder_out, tag_H): predictions_classes.append(self._class_embed(h)) if len(predictions_bboxes) > 0: predictions_bboxes = torch.stack([x[0] for x in predictions_bboxes]) + else: + predictions_bboxes = torch.empty(0) + if len(predictions_classes) > 0: predictions_classes = torch.stack([x[0] for x in predictions_classes]) + else: + predictions_classes = torch.empty(0) return predictions_classes, predictions_bboxes diff --git a/docling_ibm_models/tableformer/otsl.py b/docling_ibm_models/tableformer/otsl.py index 3c52fc7..1a14387 100644 --- a/docling_ibm_models/tableformer/otsl.py +++ b/docling_ibm_models/tableformer/otsl.py @@ -123,6 +123,9 @@ def otsl_check_right(rs_split, x, y): def otsl_to_html(rs_list, logdebug): + if len(rs_list) == 0: + return [] + if rs_list[0] not in ["fcel", "ched", "rhed", "srow", "ecel"]: # Most likely already HTML... return rs_list