Skip to content

Commit

Permalink
Save table prediction in cells format (#335)
Browse files Browse the repository at this point in the history
This pull request allows the table transformer to return predictions in a raw cell
representation. It will be used to
save predictions in a cells format for simpler metrics calculation

---------

Co-authored-by: Yao You <[email protected]>
  • Loading branch information
plutasnyy and badGarnet authored Apr 22, 2024
1 parent 45ebc1f commit 4304c83
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.7.28

* feat: allow table transformer agent to return table predictions in an unparsed (cells) format

## 0.7.27

* fix: remove pin from `onnxruntime` dependency.
Expand Down
19 changes: 19 additions & 0 deletions test_unstructured_inference/models/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,25 @@ def test_table_prediction_output_format(
assert expectation in result


def test_table_prediction_output_format_when_wrong_type_then_value_error(
    table_transformer,
    example_image,
    mocker,
    example_table_cells,
    mocked_ocr_tokens,
):
    """Check that ``run_prediction`` raises ``ValueError`` for an unknown ``result_format``."""
    # Replace `recognize` and `get_structure` with stubs returning fixed values.
    mocker.patch.object(tables, "recognize", return_value=example_table_cells)
    mocker.patch.object(
        tables.UnstructuredTableTransformerModel,
        "get_structure",
        return_value=None,
    )
    unsupported_format = "Wrong format"

    with pytest.raises(ValueError):
        table_transformer.run_prediction(
            example_image,
            result_format=unsupported_format,
            ocr_tokens=mocked_ocr_tokens,
        )


def test_table_prediction_runs_with_empty_recognize(
table_transformer,
example_image,
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.27" # pragma: no cover
__version__ = "0.7.28" # pragma: no cover
17 changes: 15 additions & 2 deletions unstructured_inference/models/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ class UnstructuredTableTransformerModel(UnstructuredModel):
def __init__(self):
pass

def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None):
def predict(
self,
x: PILImage.Image,
ocr_tokens: Optional[List[Dict]] = None,
result_format: str = "html",
):
"""Predict table structure deferring to run_prediction with ocr tokens
Note:
Expand All @@ -44,7 +49,7 @@ def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None):
FIXME: refactor token data into a dataclass so we have clear expectations of the fields
"""
super().predict(x)
return self.run_prediction(x, ocr_tokens=ocr_tokens)
return self.run_prediction(x, ocr_tokens=ocr_tokens, result_format=result_format)

def initialize(
self,
Expand Down Expand Up @@ -109,6 +114,14 @@ def run_prediction(
prediction = cells_to_html(prediction) or ""
elif result_format == "dataframe":
prediction = table_cells_to_dataframe(prediction)
elif result_format == "cells":
prediction = prediction
else:
raise ValueError(
f"result_format {result_format} is not a valid format. "
f'Valid formats are: "html", "dataframe", "cells"'
)

return prediction


Expand Down

0 comments on commit 4304c83

Please sign in to comment.