Skip to content

Commit

Permalink
Merge branch 'main' into feat/chipper-repetitions
Browse files Browse the repository at this point in the history
  • Loading branch information
ajjimeno authored Jan 31, 2024
2 parents a965f70 + ed5f2c2 commit ca82299
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 9 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:
branches: [ main ]

env:
PYTHON_VERSION: 3.8
PYTHON_VERSION: 3.9

jobs:
setup:
Expand Down Expand Up @@ -107,7 +107,7 @@ jobs:
test_ingest:
strategy:
matrix:
python-version: ["3.8","3.9","3.10"]
python-version: ["3.9","3.10"]
runs-on: ubuntu-latest
env:
NLTK_DATA: ${{ github.workspace }}/nltk_data
Expand Down
11 changes: 10 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
## 0.7.22
## 0.7.24

* Revised repetitions for Chipper

## 0.7.23

* fix: added handling in `UnstructuredTableTransformerModel` for if `recognize` returns an empty
list in `run_prediction`.

## 0.7.22

* fix: add logic to handle computation of intersections betwen 2 `Rectangle`s when a `Rectangle` has `None` value in its coordinates

## 0.7.21

* fix: fix a bug where chipper, or any element extraction model based `PageLayout` object, lack `image_metadata` and other attributes that are required for downstream processing; this fix also reduces the memory overhead of using chipper model
Expand Down
15 changes: 15 additions & 0 deletions test_unstructured_inference/models/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,21 @@ def test_table_prediction_output_format(
assert expectation in result


def test_table_prediction_runs_with_empty_recognize(
table_transformer,
example_image,
mocker,
mocked_ocr_tokens,
):
mocker.patch.object(tables, "recognize", return_value=[])
mocker.patch.object(
tables.UnstructuredTableTransformerModel,
"get_structure",
return_value=None,
)
assert table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens) == ""


def test_table_prediction_with_ocr_tokens(table_transformer, example_image, mocked_ocr_tokens):
prediction = table_transformer.predict(example_image, ocr_tokens=mocked_ocr_tokens)
assert '<table><thead><th rowspan="2">' in prediction
Expand Down
23 changes: 19 additions & 4 deletions test_unstructured_inference/test_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

from unstructured_inference.constants import ElementType
from unstructured_inference.inference import elements
from unstructured_inference.inference.elements import TextRegion
from unstructured_inference.inference.elements import Rectangle, TextRegion
from unstructured_inference.inference.layoutelement import (
LayoutElement,
merge_inferred_layout_with_extracted_layout,
partition_groups_from_regions,
separate,
merge_inferred_layout_with_extracted_layout,
LayoutElement,
)

skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}
Expand All @@ -31,6 +31,18 @@ def rand_rect(size=10):
return elements.Rectangle(x1, y1, x1 + size, y1 + size)


@pytest.mark.parametrize(
("rect1", "rect2", "expected"),
[
(Rectangle(0, 0, 1, 1), Rectangle(0, 0, None, None), None),
(Rectangle(0, 0, None, None), Rectangle(0, 0, 1, 1), None),
],
)
def test_unhappy_intersection(rect1, rect2, expected):
assert rect1.intersection(rect2) == expected
assert not rect1.intersects(rect2)


@pytest.mark.parametrize("second_size", [10, 20])
def test_intersects(second_size):
for _ in range(1000):
Expand Down Expand Up @@ -198,7 +210,10 @@ def test_intersection_over_min(


def test_grow_region_to_match_region():
from unstructured_inference.inference.elements import Rectangle, grow_region_to_match_region
from unstructured_inference.inference.elements import (
Rectangle,
grow_region_to_match_region,
)

a = Rectangle(1, 1, 2, 2)
b = Rectangle(1, 1, 5, 5)
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.22" # pragma: no cover
__version__ = "0.7.24" # pragma: no cover
8 changes: 8 additions & 0 deletions unstructured_inference/inference/elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def is_disjoint(self, other: Rectangle) -> bool:

def intersects(self, other: Rectangle) -> bool:
"""Checks whether this rectangle intersects another rectangle."""
if self._has_none() or other._has_none():
return False
return intersections(self, other)[0, 1]

def is_in(self, other: Rectangle, error_margin: Optional[Union[int, float]] = None) -> bool:
Expand All @@ -81,6 +83,10 @@ def is_in(self, other: Rectangle, error_margin: Optional[Union[int, float]] = No
],
)

def _has_none(self) -> bool:
"""return false when one of the coord is nan"""
return any((self.x1 is None, self.x2 is None, self.y1 is None, self.y2 is None))

@property
def coordinates(self):
"""Gets coordinates of the rectangle"""
Expand All @@ -89,6 +95,8 @@ def coordinates(self):
def intersection(self, other: Rectangle) -> Optional[Rectangle]:
"""Gives the rectangle that is the intersection of two rectangles, or None if the
rectangles are disjoint."""
if self._has_none() or other._has_none():
return None
x1 = max(self.x1, other.x1)
x2 = min(self.x2, other.x2)
y1 = max(self.y1, other.y1)
Expand Down
9 changes: 8 additions & 1 deletion unstructured_inference/models/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,14 @@ def run_prediction(
outputs_structure = self.get_structure(x, pad_for_structure_detection)
if ocr_tokens is None:
raise ValueError("Cannot predict table structure with no OCR tokens")
prediction = recognize(outputs_structure, x, tokens=ocr_tokens)[0]

recognized_table = recognize(outputs_structure, x, tokens=ocr_tokens)
if len(recognized_table) > 0:
prediction = recognized_table[0]
# NOTE(robinson) - This means that the table was not recognized
else:
return ""

if result_format == "html":
# Convert cells to HTML
prediction = cells_to_html(prediction) or ""
Expand Down

0 comments on commit ca82299

Please sign in to comment.