diff --git a/CHANGELOG.md b/CHANGELOG.md index ef93fc15..8b1fee96 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.28 + +* feat: allow table transformer agent to return table prediction in not parsed format + ## 0.7.27 * fix: remove pin from `onnxruntime` dependency. diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 2df6e157..88bb02bc 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -932,6 +932,25 @@ def test_table_prediction_output_format( assert expectation in result +def test_table_prediction_output_format_when_wrong_type_then_value_error( + table_transformer, + example_image, + mocker, + example_table_cells, + mocked_ocr_tokens, +): + mocker.patch.object(tables, "recognize", return_value=example_table_cells) + mocker.patch.object( + tables.UnstructuredTableTransformerModel, + "get_structure", + return_value=None, + ) + with pytest.raises(ValueError): + table_transformer.run_prediction( + example_image, result_format="Wrong format", ocr_tokens=mocked_ocr_tokens + ) + + def test_table_prediction_runs_with_empty_recognize( table_transformer, example_image, diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 023f1e0a..20ef40a7 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.27" # pragma: no cover +__version__ = "0.7.28" # pragma: no cover diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index b6607812..2cf622e0 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -27,7 +27,12 @@ class UnstructuredTableTransformerModel(UnstructuredModel): def __init__(self): pass - def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None): + def predict( + self, + x: PILImage.Image, + ocr_tokens: Optional[List[Dict]] = None, + result_format: str = "html", + ): """Predict table structure deferring to run_prediction with ocr tokens Note: @@ -44,7 +49,7 @@ def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None): FIXME: refactor token data into a dataclass so we have clear expectations of the fields """ super().predict(x) - return self.run_prediction(x, ocr_tokens=ocr_tokens) + return self.run_prediction(x, ocr_tokens=ocr_tokens, result_format=result_format) def initialize( self, @@ -109,6 +114,14 @@ def run_prediction( prediction = cells_to_html(prediction) or "" elif result_format == "dataframe": prediction = table_cells_to_dataframe(prediction) + elif result_format == "cells": + prediction = prediction + else: + raise ValueError( + f"result_format {result_format} is not a valid format. " + f'Valid formats are: "html", "dataframe", "cells"' + ) + return prediction