From 45ebc1f0758dafb725d781ad86f405583a72cf39 Mon Sep 17 00:00:00 2001 From: Steve Canny Date: Thu, 18 Apr 2024 20:28:10 -0700 Subject: [PATCH] packaging: update onnxruntime pin (#337) **Summary** Update pin on onnxruntime dependency. The prior constraint is no longer necessary and now produces conflicts on some installations. --- CHANGELOG.md | 4 + pyproject.toml | 2 + requirements/base.in | 3 +- requirements/base.txt | 93 +++++----- requirements/dev.txt | 190 ++++++++++---------- requirements/test.txt | 61 +++---- unstructured_inference/__version__.py | 2 +- unstructured_inference/inference/layout.py | 4 +- unstructured_inference/models/base.py | 26 ++- unstructured_inference/models/chipper.py | 8 +- unstructured_inference/models/detectron2.py | 6 +- unstructured_inference/models/donut.py | 6 +- unstructured_inference/models/tables.py | 14 +- unstructured_inference/models/yolox.py | 6 +- 14 files changed, 218 insertions(+), 207 deletions(-) create mode 100644 pyproject.toml diff --git a/CHANGELOG.md b/CHANGELOG.md index 78ad2510..ef93fc15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.27 + +* fix: remove pin from `onnxruntime` dependency. + ## 0.7.26 * feat: add a set of new `ElementType`s to extend future element types recognition diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..aa4949aa --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 100 diff --git a/requirements/base.in b/requirements/base.in index 5f4a8f5d..9f321b26 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -4,8 +4,7 @@ python-multipart huggingface-hub opencv-python!=4.7.0.68 onnx -# NOTE(benjamin): Pinned because onnxruntime changed the way quantization is done, and we need to update our code to support it -onnxruntime<1.16 +onnxruntime>=1.17.0 # NOTE(alan): Pinned because this is when the most recent module we import appeared transformers>=4.25.1 rapidfuzz diff --git a/requirements/base.txt b/requirements/base.txt index 28857d7a..e4512b00 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,12 +1,12 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements/base.in # antlr4-python3-runtime==4.9.3 # via omegaconf -certifi==2023.7.22 +certifi==2024.2.2 # via requests cffi==1.16.0 # via cryptography @@ -16,28 +16,28 @@ charset-normalizer==3.3.2 # requests coloredlogs==15.0.1 # via onnxruntime -contourpy==1.1.1 +contourpy==1.2.1 # via matplotlib -cryptography==41.0.5 +cryptography==42.0.5 # via pdfminer-six cycler==0.12.1 # via matplotlib effdet==0.4.1 # via layoutparser -filelock==3.13.1 +filelock==3.13.4 # via # huggingface-hub # torch # transformers -flatbuffers==23.5.26 +flatbuffers==24.3.25 # via onnxruntime -fonttools==4.44.3 +fonttools==4.51.0 # via matplotlib -fsspec==2023.10.0 +fsspec==2024.3.1 # via # huggingface-hub # torch -huggingface-hub==0.19.4 +huggingface-hub==0.22.2 # via # -r requirements/base.in # timm @@ -45,27 +45,27 @@ huggingface-hub==0.19.4 # transformers humanfriendly==10.0 # via coloredlogs -idna==3.4 +idna==3.7 # via requests -importlib-resources==6.1.1 +importlib-resources==6.4.0 # via matplotlib iopath==0.1.10 # via layoutparser -jinja2==3.1.2 +jinja2==3.1.3 # via torch kiwisolver==1.4.5 # via matplotlib layoutparser[layoutmodels,tesseract]==0.3.4 # via -r requirements/base.in -markupsafe==2.1.3 +markupsafe==2.1.5 # via jinja2 -matplotlib==3.7.3 
+matplotlib==3.8.4 # via pycocotools mpmath==1.3.0 # via sympy -networkx==3.1 +networkx==3.2.1 # via torch -numpy==1.24.4 +numpy==1.26.4 # via # contourpy # layoutparser @@ -80,30 +80,30 @@ numpy==1.24.4 # transformers omegaconf==2.3.0 # via effdet -onnx==1.15.0 +onnx==1.16.0 # via -r requirements/base.in -onnxruntime==1.15.1 +onnxruntime==1.17.3 # via -r requirements/base.in -opencv-python==4.8.1.78 +opencv-python==4.9.0.80 # via # -r requirements/base.in # layoutparser -packaging==23.2 +packaging==24.0 # via # huggingface-hub # matplotlib # onnxruntime # pytesseract # transformers -pandas==2.0.3 +pandas==2.2.2 # via layoutparser -pdf2image==1.16.3 +pdf2image==1.17.0 # via layoutparser -pdfminer-six==20221105 +pdfminer-six==20231228 # via pdfplumber -pdfplumber==0.10.3 +pdfplumber==0.11.0 # via layoutparser -pillow==10.1.0 +pillow==10.3.0 # via # layoutparser # matplotlib @@ -113,27 +113,27 @@ pillow==10.1.0 # torchvision portalocker==2.8.2 # via iopath -protobuf==4.25.1 +protobuf==5.26.1 # via # onnx # onnxruntime pycocotools==2.0.7 # via effdet -pycparser==2.21 +pycparser==2.22 # via cffi -pyparsing==3.1.1 +pyparsing==3.1.2 # via matplotlib -pypdfium2==4.24.0 +pypdfium2==4.29.0 # via pdfplumber pytesseract==0.3.10 # via layoutparser -python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via # matplotlib # pandas -python-multipart==0.0.6 +python-multipart==0.0.9 # via -r requirements/base.in -pytz==2023.3.post1 +pytz==2024.1 # via pandas pyyaml==6.0.1 # via @@ -142,20 +142,19 @@ pyyaml==6.0.1 # omegaconf # timm # transformers -rapidfuzz==3.5.2 +rapidfuzz==3.8.1 # via -r requirements/base.in -regex==2023.10.3 +regex==2024.4.16 # via transformers requests==2.31.0 # via # huggingface-hub - # torchvision # transformers -safetensors==0.4.0 +safetensors==0.4.3 # via # timm # transformers -scipy==1.10.1 +scipy==1.13.0 # via layoutparser six==1.16.0 # via python-dateutil @@ -163,36 +162,36 @@ sympy==1.12 # via # onnxruntime # torch -timm==0.9.10 +timm==0.9.16 # via effdet -tokenizers==0.15.0 +tokenizers==0.19.1 # via transformers -torch==2.1.1 +torch==2.2.2 # via # effdet # layoutparser # timm # torchvision -torchvision==0.16.1 +torchvision==0.17.2 # via # effdet # layoutparser # timm -tqdm==4.66.1 +tqdm==4.66.2 # via # huggingface-hub # iopath # transformers -transformers==4.35.2 +transformers==4.40.0 # via -r requirements/base.in -typing-extensions==4.8.0 +typing-extensions==4.11.0 # via # huggingface-hub # iopath # torch -tzdata==2023.3 +tzdata==2024.1 # via pandas -urllib3==2.1.0 +urllib3==2.2.1 # via requests -zipp==3.17.0 +zipp==3.18.1 # via importlib-resources diff --git a/requirements/dev.txt b/requirements/dev.txt index 8f45a80c..b065f8f4 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,17 +1,16 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements/dev.in # -anyio==4.0.0 +anyio==4.3.0 # via # -c requirements/test.txt + # httpx # jupyter-server -appnope==0.1.3 - # via - # ipykernel - # ipython +appnope==0.1.4 + # via ipykernel argon2-cffi==23.1.0 # via jupyter-server argon2-cffi-bindings==21.2.0 @@ -22,24 +21,24 @@ asttokens==2.4.1 # via stack-data async-lru==2.0.4 # via jupyterlab -attrs==23.1.0 +attrs==23.2.0 # via # jsonschema # referencing -babel==2.13.1 +babel==2.14.0 # via jupyterlab-server -backcall==0.2.0 - # via ipython -beautifulsoup4==4.12.2 +beautifulsoup4==4.12.3 # via nbconvert bleach==6.1.0 # via nbconvert -build==1.0.3 
+build==1.2.1 # via pip-tools -certifi==2023.7.22 +certifi==2024.2.2 # via # -c requirements/base.txt # -c requirements/test.txt + # httpcore + # httpx # requests cffi==1.16.0 # via @@ -54,11 +53,11 @@ click==8.1.7 # via # -c requirements/test.txt # pip-tools -comm==0.2.0 +comm==0.2.2 # via # ipykernel # ipywidgets -contourpy==1.1.1 +contourpy==1.2.1 # via # -c requirements/base.txt # matplotlib @@ -66,34 +65,48 @@ cycler==0.12.1 # via # -c requirements/base.txt # matplotlib -debugpy==1.8.0 +debugpy==1.8.1 # via ipykernel decorator==5.1.1 # via ipython defusedxml==0.7.1 # via nbconvert -exceptiongroup==1.1.3 +exceptiongroup==1.2.1 # via # -c requirements/test.txt # anyio + # ipython executing==2.0.1 # via stack-data -fastjsonschema==2.19.0 +fastjsonschema==2.19.1 # via nbformat -fonttools==4.44.3 +fonttools==4.51.0 # via # -c requirements/base.txt # matplotlib fqdn==1.5.1 # via jsonschema -idna==3.4 +h11==0.14.0 + # via + # -c requirements/test.txt + # httpcore +httpcore==1.0.5 + # via + # -c requirements/test.txt + # httpx +httpx==0.27.0 + # via + # -c requirements/test.txt + # jupyterlab +idna==3.7 # via # -c requirements/base.txt # -c requirements/test.txt # anyio + # httpx # jsonschema # requests -importlib-metadata==6.8.0 +importlib-metadata==7.1.0 # via # build # jupyter-client @@ -101,52 +114,49 @@ importlib-metadata==6.8.0 # jupyterlab # jupyterlab-server # nbconvert -importlib-resources==6.1.1 +importlib-resources==6.4.0 # via # -c requirements/base.txt - # jsonschema - # jsonschema-specifications - # jupyterlab # matplotlib -ipykernel==6.26.0 +ipykernel==6.29.4 # via # jupyter # jupyter-console # jupyterlab # qtconsole -ipython==8.12.3 +ipython==8.18.1 # via # -r requirements/dev.in # ipykernel # ipywidgets # jupyter-console -ipywidgets==8.1.1 +ipywidgets==8.1.2 # via jupyter isoduration==20.11.0 # via jsonschema jedi==0.19.1 # via ipython -jinja2==3.1.2 +jinja2==3.1.3 # via # -c requirements/base.txt # jupyter-server # jupyterlab # jupyterlab-server # nbconvert -json5==0.9.14 +json5==0.9.25 # via jupyterlab-server jsonpointer==2.4 # via jsonschema -jsonschema[format-nongpl]==4.20.0 +jsonschema[format-nongpl]==4.21.1 # via # jupyter-events # jupyterlab-server # nbformat -jsonschema-specifications==2023.11.1 +jsonschema-specifications==2023.12.1 # via jsonschema jupyter==1.0.0 # via -r requirements/dev.in -jupyter-client==8.6.0 +jupyter-client==8.6.1 # via # ipykernel # jupyter-console @@ -155,7 +165,7 @@ jupyter-client==8.6.0 # qtconsole jupyter-console==6.6.3 # via jupyter -jupyter-core==5.5.0 +jupyter-core==5.7.2 # via # ipykernel # jupyter-client @@ -166,75 +176,75 @@ jupyter-core==5.5.0 # nbconvert # nbformat # qtconsole -jupyter-events==0.9.0 +jupyter-events==0.10.0 # via jupyter-server -jupyter-lsp==2.2.0 +jupyter-lsp==2.2.5 # via jupyterlab -jupyter-server==2.10.1 +jupyter-server==2.14.0 # via # jupyter-lsp # jupyterlab # jupyterlab-server # notebook # notebook-shim -jupyter-server-terminals==0.4.4 +jupyter-server-terminals==0.5.3 # via jupyter-server -jupyterlab==4.0.8 +jupyterlab==4.1.6 # via notebook -jupyterlab-pygments==0.2.2 +jupyterlab-pygments==0.3.0 # via nbconvert -jupyterlab-server==2.25.1 +jupyterlab-server==2.26.0 # via # jupyterlab # notebook -jupyterlab-widgets==3.0.9 +jupyterlab-widgets==3.0.10 # via ipywidgets kiwisolver==1.4.5 # via # -c requirements/base.txt # matplotlib -markupsafe==2.1.3 +markupsafe==2.1.5 # via # -c requirements/base.txt # jinja2 # nbconvert -matplotlib==3.7.3 +matplotlib==3.8.4 # via # -c requirements/base.txt # -r 
requirements/dev.in -matplotlib-inline==0.1.6 +matplotlib-inline==0.1.7 # via # ipykernel # ipython mistune==3.0.2 # via nbconvert -nbclient==0.9.0 +nbclient==0.10.0 # via nbconvert -nbconvert==7.11.0 +nbconvert==7.16.3 # via # jupyter # jupyter-server -nbformat==5.9.2 +nbformat==5.10.4 # via # jupyter-server # nbclient # nbconvert -nest-asyncio==1.5.8 +nest-asyncio==1.6.0 # via ipykernel -notebook==7.0.6 +notebook==7.1.3 # via jupyter -notebook-shim==0.2.3 +notebook-shim==0.2.4 # via # jupyterlab # notebook -numpy==1.24.4 +numpy==1.26.4 # via # -c requirements/base.txt # contourpy # matplotlib -overrides==7.4.0 +overrides==7.7.0 # via jupyter-server -packaging==23.2 +packaging==24.0 # via # -c requirements/base.txt # -c requirements/test.txt @@ -247,34 +257,30 @@ packaging==23.2 # nbconvert # qtconsole # qtpy -pandocfilters==1.5.0 +pandocfilters==1.5.1 # via nbconvert -parso==0.8.3 +parso==0.8.4 # via jedi -pexpect==4.8.0 - # via ipython -pickleshare==0.7.5 +pexpect==4.9.0 # via ipython -pillow==10.1.0 +pillow==10.3.0 # via # -c requirements/base.txt # -c requirements/test.txt # matplotlib -pip-tools==7.3.0 +pip-tools==7.4.1 # via -r requirements/dev.in -pkgutil-resolve-name==1.3.10 - # via jsonschema -platformdirs==4.0.0 +platformdirs==4.2.0 # via # -c requirements/test.txt # jupyter-core -prometheus-client==0.18.0 +prometheus-client==0.20.0 # via jupyter-server -prompt-toolkit==3.0.41 +prompt-toolkit==3.0.43 # via # ipython # jupyter-console -psutil==5.9.6 +psutil==5.9.8 # via ipykernel ptyprocess==0.7.0 # via @@ -282,23 +288,25 @@ ptyprocess==0.7.0 # terminado pure-eval==0.2.2 # via stack-data -pycparser==2.21 +pycparser==2.22 # via # -c requirements/base.txt # cffi -pygments==2.16.1 +pygments==2.17.2 # via # ipython # jupyter-console # nbconvert # qtconsole -pyparsing==3.1.1 +pyparsing==3.1.2 # via # -c requirements/base.txt # matplotlib pyproject-hooks==1.0.0 - # via build -python-dateutil==2.8.2 + # via + # build + # pip-tools +python-dateutil==2.9.0.post0 # via # -c requirements/base.txt # arrow @@ -306,16 +314,12 @@ python-dateutil==2.8.2 # matplotlib python-json-logger==2.0.7 # via jupyter-events -pytz==2023.3.post1 - # via - # -c requirements/base.txt - # babel pyyaml==6.0.1 # via # -c requirements/base.txt # -c requirements/test.txt # jupyter-events -pyzmq==25.1.1 +pyzmq==26.0.0 # via # ipykernel # jupyter-client @@ -326,7 +330,7 @@ qtconsole==5.5.1 # via jupyter qtpy==2.4.1 # via qtconsole -referencing==0.31.0 +referencing==0.34.0 # via # jsonschema # jsonschema-specifications @@ -344,11 +348,11 @@ rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events -rpds-py==0.13.0 +rpds-py==0.18.0 # via # jsonschema # referencing -send2trash==1.8.2 +send2trash==1.8.3 # via jupyter-server six==1.16.0 # via @@ -357,15 +361,16 @@ six==1.16.0 # bleach # python-dateutil # rfc3339-validator -sniffio==1.3.0 +sniffio==1.3.1 # via # -c requirements/test.txt # anyio + # httpx soupsieve==2.5 # via beautifulsoup4 stack-data==0.6.3 # via ipython -terminado==0.18.0 +terminado==0.18.1 # via # jupyter-server # jupyter-server-terminals @@ -378,7 +383,7 @@ tomli==2.0.1 # jupyterlab # pip-tools # pyproject-hooks -tornado==6.3.3 +tornado==6.4 # via # ipykernel # jupyter-client @@ -386,7 +391,7 @@ tornado==6.3.3 # jupyterlab # notebook # terminado -traitlets==5.13.0 +traitlets==5.14.2 # via # comm # ipykernel @@ -403,22 +408,23 @@ traitlets==5.13.0 # nbconvert # nbformat # qtconsole -types-python-dateutil==2.8.19.14 +types-python-dateutil==2.9.0.20240316 # via arrow -typing-extensions==4.8.0 
+typing-extensions==4.11.0 # via # -c requirements/base.txt # -c requirements/test.txt + # anyio # async-lru # ipython uri-template==1.3.0 # via jsonschema -urllib3==2.1.0 +urllib3==2.2.1 # via # -c requirements/base.txt # -c requirements/test.txt # requests -wcwidth==0.2.10 +wcwidth==0.2.13 # via prompt-toolkit webcolors==1.13 # via jsonschema @@ -426,13 +432,13 @@ webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.6.4 +websocket-client==1.7.0 # via jupyter-server -wheel==0.41.3 +wheel==0.43.0 # via pip-tools -widgetsnbextension==4.0.9 +widgetsnbextension==4.0.10 # via ipywidgets -zipp==3.17.0 +zipp==3.18.1 # via # -c requirements/base.txt # importlib-metadata diff --git a/requirements/test.txt b/requirements/test.txt index 4e5a2adb..864b5fe3 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,14 +1,14 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements/test.in # -anyio==4.0.0 +anyio==4.3.0 # via httpx -black==23.11.0 +black==24.4.0 # via -r requirements/test.in -certifi==2023.7.22 +certifi==2024.2.2 # via # -c requirements/base.txt # httpcore @@ -22,39 +22,39 @@ click==8.1.7 # via # -r requirements/test.in # black -coverage[toml]==7.3.2 +coverage[toml]==7.4.4 # via # -r requirements/test.in # pytest-cov -exceptiongroup==1.1.3 +exceptiongroup==1.2.1 # via # anyio # pytest -filelock==3.13.1 +filelock==3.13.4 # via # -c requirements/base.txt # huggingface-hub -flake8==6.1.0 +flake8==7.0.0 # via # -r requirements/test.in # flake8-docstrings flake8-docstrings==1.7.0 # via -r requirements/test.in -fsspec==2023.10.0 +fsspec==2024.3.1 # via # -c requirements/base.txt # huggingface-hub h11==0.14.0 # via httpcore -httpcore==1.0.2 +httpcore==1.0.5 # via httpx -httpx==0.25.1 +httpx==0.27.0 # via -r requirements/test.in -huggingface-hub==0.19.4 +huggingface-hub==0.22.2 # via # -c requirements/base.txt # -r requirements/test.in -idna==3.4 +idna==3.7 # via # -c requirements/base.txt # anyio @@ -64,45 +64,45 @@ iniconfig==2.0.0 # via pytest mccabe==0.7.0 # via flake8 -mypy==1.7.0 +mypy==1.9.0 # via -r requirements/test.in mypy-extensions==1.0.0 # via # black # mypy -packaging==23.2 +packaging==24.0 # via # -c requirements/base.txt # black # huggingface-hub # pytest -pathspec==0.11.2 +pathspec==0.12.1 # via black -pdf2image==1.16.3 +pdf2image==1.17.0 # via # -c requirements/base.txt # -r requirements/test.in -pillow==10.1.0 +pillow==10.3.0 # via # -c requirements/base.txt # pdf2image -platformdirs==4.0.0 +platformdirs==4.2.0 # via black -pluggy==1.3.0 +pluggy==1.4.0 # via pytest pycodestyle==2.11.1 # via flake8 pydocstyle==6.3.0 # via flake8-docstrings -pyflakes==3.1.0 +pyflakes==3.2.0 # via flake8 -pytest==7.4.3 +pytest==8.1.1 # via # pytest-cov # pytest-mock -pytest-cov==4.1.0 +pytest-cov==5.0.0 # via -r requirements/test.in -pytest-mock==3.12.0 +pytest-mock==3.14.0 # via -r requirements/test.in pyyaml==6.0.1 # via @@ -112,9 +112,9 @@ requests==2.31.0 # via # -c requirements/base.txt # huggingface-hub -ruff==0.1.5 +ruff==0.4.0 # via -r requirements/test.in -sniffio==1.3.0 +sniffio==1.3.1 # via # anyio # httpx @@ -126,19 +126,20 @@ tomli==2.0.1 # coverage # mypy # pytest -tqdm==4.66.1 +tqdm==4.66.2 # via # -c requirements/base.txt # huggingface-hub -types-pyyaml==6.0.12.12 +types-pyyaml==6.0.12.20240311 # via -r requirements/test.in -typing-extensions==4.8.0 +typing-extensions==4.11.0 # via # -c requirements/base.txt + # anyio # black # 
huggingface-hub # mypy -urllib3==2.1.0 +urllib3==2.2.1 # via # -c requirements/base.txt # requests diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 04863d10..023f1e0a 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.26" # pragma: no cover +__version__ = "0.7.27" # pragma: no cover diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index 27c3eefe..a22d68b8 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py @@ -140,7 +140,7 @@ def __init__( ): if detection_model is not None and element_extraction_model is not None: raise ValueError("Only one of detection_model and extraction_model should be passed.") - self.image = image + self.image: Optional[Image.Image] = image if image_metadata is None: image_metadata = {} self.image_metadata = image_metadata @@ -167,6 +167,7 @@ def get_elements_using_image_extraction( raise ValueError( "Cannot get elements using image extraction, no image extraction model defined", ) + assert self.image is not None elements = self.element_extraction_model(self.image) if inplace: self.elements = elements @@ -188,6 +189,7 @@ def get_elements_with_detection_model( # NOTE(mrobinson) - We'll want make this model inference step some kind of # remote call in the future. + assert self.image is not None inferred_layout: List[LayoutElement] = self.detection_model(self.image) inferred_layout = self.detection_model.deduplicate_detected_elements( inferred_layout, diff --git a/unstructured_inference/models/base.py b/unstructured_inference/models/base.py index 59c3c704..fd8eaded 100644 --- a/unstructured_inference/models/base.py +++ b/unstructured_inference/models/base.py @@ -6,13 +6,9 @@ from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES from unstructured_inference.models.chipper import UnstructuredChipperModel -from unstructured_inference.models.detectron2 import ( - MODEL_TYPES as DETECTRON2_MODEL_TYPES, -) +from unstructured_inference.models.detectron2 import MODEL_TYPES as DETECTRON2_MODEL_TYPES from unstructured_inference.models.detectron2 import UnstructuredDetectronModel -from unstructured_inference.models.detectron2onnx import ( - MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES, -) +from unstructured_inference.models.detectron2onnx import MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES from unstructured_inference.models.detectron2onnx import UnstructuredDetectronONNXModel from unstructured_inference.models.unstructuredmodel import UnstructuredModel from unstructured_inference.models.yolox import MODEL_TYPES as YOLOX_MODEL_TYPES @@ -24,12 +20,10 @@ models: Dict[str, UnstructuredModel] = {} -def get_default_model_mappings() -> ( - Tuple[ - Dict[str, Type[UnstructuredModel]], - Dict[str, dict | LazyDict], - ] -): +def get_default_model_mappings() -> Tuple[ + Dict[str, Type[UnstructuredModel]], + Dict[str, dict | LazyDict], +]: """default model mappings for models that are in `unstructured_inference` repo""" return { **{name: UnstructuredDetectronModel for name in DETECTRON2_MODEL_TYPES}, @@ -48,8 +42,10 @@ def get_default_model_mappings() -> ( def register_new_model(model_config: dict, model_class: UnstructuredModel): - """registering a new model by updating the model_config_map and model_class_map with the new - model class information""" + """Register this model in model_config_map and model_class_map. 
+
+    Those maps are updated with the new model class information.
+    """
     model_config_map.update(model_config)
     model_class_map.update({name: model_class for name in model_config})


@@ -90,6 +86,6 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:


 class UnknownModelException(Exception):
-    """Exception for the case where a model is called for with an unrecognized identifier."""
+    """A model was requested with an unrecognized identifier."""

     pass
diff --git a/unstructured_inference/models/chipper.py b/unstructured_inference/models/chipper.py
index ed0fd8ad..9128897c 100644
--- a/unstructured_inference/models/chipper.py
+++ b/unstructured_inference/models/chipper.py
@@ -463,7 +463,7 @@ def get_bounding_box(
             np.asarray(
                 [
                     agg_heatmap,
-                    cv2.resize(
+                    cv2.resize(  # type: ignore
                         hmap,
                         (final_w, final_h),
                         interpolation=cv2.INTER_LINEAR_EXACT,  # cv2.INTER_CUBIC
@@ -620,7 +620,7 @@ def reduce_bbox_no_overlap(
         ):
             return input_bbox

-        nimage = np.array(image.crop(input_bbox))
+        nimage = np.array(image.crop(input_bbox))  # type: ignore

         nimage = self.remove_horizontal_lines(nimage)

@@ -669,7 +669,7 @@ def reduce_bbox_overlap(
         ):
             return input_bbox

-        nimage = np.array(image.crop(input_bbox))
+        nimage = np.array(image.crop(input_bbox))  # type: ignore

         nimage = self.remove_horizontal_lines(nimage)

@@ -773,7 +773,7 @@ def largest_margin(
         ):
             return None

-        nimage = np.array(image.crop(input_bbox))
+        nimage = np.array(image.crop(input_bbox))  # type: ignore

         if nimage.shape[0] * nimage.shape[1] == 0:
             return None
diff --git a/unstructured_inference/models/detectron2.py b/unstructured_inference/models/detectron2.py
index 98939f88..c38f6848 100644
--- a/unstructured_inference/models/detectron2.py
+++ b/unstructured_inference/models/detectron2.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from pathlib import Path
 from typing import Any, Dict, Final, List, Optional, Union

@@ -7,7 +9,7 @@
     is_detectron2_available,
 )
 from layoutparser.models.model_config import LayoutModelConfig
-from PIL import Image
+from PIL import Image as PILImage

 from unstructured_inference.constants import ElementType
 from unstructured_inference.inference.layoutelement import LayoutElement
@@ -65,7 +67,7 @@ class UnstructuredDetectronModel(UnstructuredObjectDetectionModel):

     """Unstructured model wrapper for Detectron2LayoutModel."""

-    def predict(self, x: Image):
+    def predict(self, x: PILImage.Image):
         """Makes a prediction using detectron2 model."""
         super().predict(x)
         prediction = self.model.detect(x)
diff --git a/unstructured_inference/models/donut.py b/unstructured_inference/models/donut.py
index 1f753f56..bc60d2c6 100644
--- a/unstructured_inference/models/donut.py
+++ b/unstructured_inference/models/donut.py
@@ -3,7 +3,7 @@
 from typing import Optional, Union

 import torch
-from PIL import Image
+from PIL import Image as PILImage
 from transformers import (
     DonutProcessor,
     VisionEncoderDecoderConfig,
@@ -16,7 +16,7 @@
 class UnstructuredDonutModel(UnstructuredModel):
     """Unstructured model wrapper for Donut image transformer."""

-    def predict(self, x: Image):
+    def predict(self, x: PILImage.Image):
         """Make prediction using donut model"""
         super().predict(x)
         return self.run_prediction(x)
@@ -50,7 +50,7 @@ def initialize(
             raise ImportError("Review the parameters to initialize a UnstructuredDonutModel obj")
         self.model.to(device)

-    def run_prediction(self, x: Image):
+    def run_prediction(self, x: PILImage.Image):
         """Internal prediction method."""
         pixel_values = self.processor(x, return_tensors="pt").pixel_values 
decoder_input_ids = self.processor.tokenizer( diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index ca7e75a4..b6607812 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -8,7 +8,7 @@ import cv2 import numpy as np import torch -from PIL import Image +from PIL import Image as PILImage from transformers import DetrImageProcessor, TableTransformerForObjectDetection from unstructured_inference.config import inference_config @@ -27,7 +27,7 @@ class UnstructuredTableTransformerModel(UnstructuredModel): def __init__(self): pass - def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None): + def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None): """Predict table structure deferring to run_prediction with ocr tokens Note: @@ -70,7 +70,7 @@ def initialize( def get_structure( self, - x: Image, + x: PILImage.Image, pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD, ) -> dict: """get the table structure as a dictionary contaning different types of elements as @@ -87,7 +87,7 @@ def get_structure( def run_prediction( self, - x: Image, + x: PILImage.Image, pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD, ocr_tokens: Optional[List[Dict]] = None, result_format: Optional[str] = "html", @@ -155,7 +155,7 @@ def get_class_map(data_type: str): } -def recognize(outputs: dict, img: Image, tokens: list): +def recognize(outputs: dict, img: PILImage.Image, tokens: list): """Recognize table elements.""" str_class_name2idx = get_class_map("structure") str_class_idx2name = {v: k for k, v in str_class_name2idx.items()} @@ -655,7 +655,7 @@ def cells_to_html(cells): return str(ET.tostring(table, encoding="unicode", short_empty_elements=False)) -def zoom_image(image: Image, zoom: float) -> Image: +def zoom_image(image: PILImage.Image, zoom: float) -> PILImage.Image: """scale an image based on the zoom factor using cv2; the scaled image is post processed by dilation then erosion to improve edge sharpness for OCR tasks""" if zoom <= 0: @@ -673,4 +673,4 @@ def zoom_image(image: Image, zoom: float) -> Image: new_image = cv2.dilate(new_image, kernel, iterations=1) new_image = cv2.erode(new_image, kernel, iterations=1) - return Image.fromarray(new_image) + return PILImage.fromarray(new_image) diff --git a/unstructured_inference/models/yolox.py b/unstructured_inference/models/yolox.py index 47455cf4..852e15b3 100644 --- a/unstructured_inference/models/yolox.py +++ b/unstructured_inference/models/yolox.py @@ -10,7 +10,7 @@ import onnxruntime from huggingface_hub import hf_hub_download from onnxruntime.capi import _pybind_state as C -from PIL import Image +from PIL import Image as PILImage from unstructured_inference.constants import ElementType, Source from unstructured_inference.inference.layoutelement import LayoutElement @@ -60,7 +60,7 @@ class UnstructuredYoloXModel(UnstructuredObjectDetectionModel): - def predict(self, x: Image): + def predict(self, x: PILImage.Image): """Predict using YoloX model.""" super().predict(x) return self.image_processing(x) @@ -86,7 +86,7 @@ def initialize(self, model_path: str, label_map: dict): def image_processing( self, - image: Image = None, + image: PILImage.Image, ) -> List[LayoutElement]: """Method runing YoloX for layout detection, returns a PageLayout parameters