Skip to content

Commit

Permalink
fix - table transformer predictions are filtered if confidence is bel…
Browse files Browse the repository at this point in the history
…ow threshold (#338)

Add usage of table transformer related thresholds. 
The predictions with low confidence score are filtered out
  • Loading branch information
plutasnyy authored Apr 24, 2024
1 parent 4304c83 commit 9c3f644
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 24 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
## 0.7.29

* fix: table transformer predictions are now removed if confidence is below threshold


## 0.7.28

* feat: allow table transformer agent to return table prediction in not parsed format
Expand Down
50 changes: 50 additions & 0 deletions test_unstructured_inference/models/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import unstructured_inference.models.table_postprocess as postprocess
from unstructured_inference.models import tables
from unstructured_inference.models.tables import apply_thresholds_on_objects

skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}

Expand Down Expand Up @@ -977,6 +978,55 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
table_transformer.predict(example_image)


@pytest.mark.parametrize(
("thresholds", "expected_object_number"),
[
({"0": 0.5}, 1),
({"0": 0.1}, 3),
({"0": 0.9}, 0),
],
)
def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold(
thresholds, expected_object_number
):
objects = [
{"label": "0", "score": 0.2},
{"label": "0", "score": 0.4},
{"label": "0", "score": 0.55},
]
assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number


@pytest.mark.parametrize(
("thresholds", "expected_object_number"),
[
({"0": 0.5, "1": 0.1}, 4),
({"0": 0.1, "1": 0.9}, 3),
({"0": 0.9, "1": 0.5}, 1),
],
)
def test_objects_are_filtered_based_on_class_thresholds_when_two_classes(
thresholds, expected_object_number
):
objects = [
{"label": "0", "score": 0.2},
{"label": "0", "score": 0.4},
{"label": "0", "score": 0.55},
{"label": "1", "score": 0.2},
{"label": "1", "score": 0.4},
{"label": "1", "score": 0.55},
]
assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number


def test_objects_filtering_when_missing_threshold():
class_name = "class_name"
objects = [{"label": class_name, "score": 0.2}]
thresholds = {"1": 0.5}
with pytest.raises(KeyError, match=class_name):
apply_thresholds_on_objects(objects, thresholds)


def test_intersect():
a = postprocess.Rect()
b = postprocess.Rect([1, 2, 3, 4])
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.28" # pragma: no cover
__version__ = "0.7.29" # pragma: no cover
18 changes: 0 additions & 18 deletions unstructured_inference/models/table_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,6 @@ def apply_threshold(objects, threshold):
return [obj for obj in objects if obj["score"] >= threshold]


# def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
# """
# Filter out bounding boxes whose confidence is below the confidence threshold for
# its associated class label.
# """
# # Apply class-specific thresholds
# indices_above_threshold = [
# idx
# for idx, (score, label) in enumerate(zip(scores, labels))
# if score >= class_thresholds[class_names[label]]
# ]
# bboxes = [bboxes[idx] for idx in indices_above_threshold]
# scores = [scores[idx] for idx in indices_above_threshold]
# labels = [labels[idx] for idx in indices_above_threshold]

# return bboxes, scores, labels


def refine_rows(rows, tokens, score_threshold):
"""
Apply operations to the detected rows, such as
Expand Down
43 changes: 38 additions & 5 deletions unstructured_inference/models/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import cv2
import numpy as np
import torch
from PIL import Image as PILImage
from transformers import DetrImageProcessor, TableTransformerForObjectDetection
from transformers.models.table_transformer.modeling_table_transformer import (
TableTransformerObjectDetectionOutput,
)

from unstructured_inference.config import inference_config
from unstructured_inference.inference.layoutelement import table_cells_to_dataframe
Expand Down Expand Up @@ -172,18 +175,22 @@ def recognize(outputs: dict, img: PILImage.Image, tokens: list):
"""Recognize table elements."""
str_class_name2idx = get_class_map("structure")
str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}
str_class_thresholds = structure_class_thresholds
class_thresholds = structure_class_thresholds

# Post-process detected objects, assign class labels
objects = outputs_to_objects(outputs, img.size, str_class_idx2name)

high_confidence_objects = apply_thresholds_on_objects(objects, class_thresholds)
# Further process the detected objects so they correspond to a consistent table
tables_structure = objects_to_structures(objects, tokens, str_class_thresholds)
tables_structure = objects_to_structures(high_confidence_objects, tokens, class_thresholds)
# Enumerate all table cells: grid cells and spanning cells
return [structure_to_cells(structure, tokens)[0] for structure in tables_structure]


def outputs_to_objects(outputs, img_size, class_idx2name):
def outputs_to_objects(
outputs: TableTransformerObjectDetectionOutput,
img_size: tuple[int, int],
class_idx2name: Mapping[int, str],
):
"""Output table element types."""
m = outputs["logits"].softmax(-1).max(-1)
pred_labels = list(m.indices.detach().cpu().numpy())[0]
Expand Down Expand Up @@ -212,6 +219,32 @@ def outputs_to_objects(outputs, img_size, class_idx2name):
return objects


def apply_thresholds_on_objects(
objects: Sequence[Mapping[str, Any]], thresholds: Mapping[str, float]
) -> Sequence[Mapping[str, Any]]:
"""
Filters predicted objects which the confidence scores below the thresholds
Args:
objects: Sequence of mappings for example:
[
{
"label": "table row",
"score": 0.55,
"bbox": [...],
},
...,
]
thresholds: Mapping from labels to thresholds
Returns:
Filtered list of objects
"""
objects = [obj for obj in objects if obj["score"] >= thresholds[obj["label"]]]
return objects


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
"""Convert rectangle format from center-x, center-y, width, height to
Expand Down

0 comments on commit 9c3f644

Please sign in to comment.