Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save table prediction in cells format #335

Merged
merged 8 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## 0.7.28

* feat: allow the table transformer agent to return the table prediction in an unparsed format (as cells)

## 0.7.27

* fix: remove pin from `onnxruntime` dependency.
Expand Down
19 changes: 19 additions & 0 deletions test_unstructured_inference/models/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,25 @@ def test_table_prediction_output_format(
assert expectation in result


def test_table_prediction_output_format_when_wrong_type_then_value_error(
    table_transformer,
    example_image,
    mocker,
    example_table_cells,
    mocked_ocr_tokens,
):
    """Check that an unsupported ``result_format`` makes ``run_prediction`` raise ``ValueError``."""
    # Replace ``recognize`` and ``get_structure`` with stubs so the test does
    # not depend on what the real implementations would return.
    mocker.patch.object(tables, "recognize", return_value=example_table_cells)
    mocker.patch.object(
        tables.UnstructuredTableTransformerModel,
        "get_structure",
        return_value=None,
    )
    # Match on the message so that an unrelated ValueError raised somewhere
    # else in the call cannot satisfy the test by accident.
    with pytest.raises(ValueError, match="is not a valid format"):
        table_transformer.run_prediction(
            example_image,
            result_format="Wrong format",
            ocr_tokens=mocked_ocr_tokens,
        )


def test_table_prediction_runs_with_empty_recognize(
table_transformer,
example_image,
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.27" # pragma: no cover
__version__ = "0.7.28" # pragma: no cover
17 changes: 15 additions & 2 deletions unstructured_inference/models/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,12 @@ class UnstructuredTableTransformerModel(UnstructuredModel):
def __init__(self):
pass

def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None):
def predict(
self,
x: PILImage.Image,
ocr_tokens: Optional[List[Dict]] = None,
result_format: str = "html",
):
"""Predict table structure deferring to run_prediction with ocr tokens

Note:
Expand All @@ -44,7 +49,7 @@ def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None):
FIXME: refactor token data into a dataclass so we have clear expectations of the fields
"""
super().predict(x)
return self.run_prediction(x, ocr_tokens=ocr_tokens)
return self.run_prediction(x, ocr_tokens=ocr_tokens, result_format=result_format)

def initialize(
self,
Expand Down Expand Up @@ -109,6 +114,14 @@ def run_prediction(
prediction = cells_to_html(prediction) or ""
elif result_format == "dataframe":
prediction = table_cells_to_dataframe(prediction)
elif result_format == "cells":
prediction = prediction
else:
raise ValueError(
f"result_format {result_format} is not a valid format. "
f'Valid formats are: "html", "dataframe", "cells"'
)

return prediction


Expand Down
Loading