Skip to content

Commit

Permalink
fix: CI mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
scanny committed Apr 19, 2024
1 parent 8bb9ffc commit a3e4027
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 20 deletions.
4 changes: 3 additions & 1 deletion unstructured_inference/inference/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
):
if detection_model is not None and element_extraction_model is not None:
raise ValueError("Only one of detection_model and extraction_model should be passed.")
self.image = image
self.image: Optional[Image.Image] = image
if image_metadata is None:
image_metadata = {}
self.image_metadata = image_metadata
Expand All @@ -167,6 +167,7 @@ def get_elements_using_image_extraction(
raise ValueError(
"Cannot get elements using image extraction, no image extraction model defined",
)
assert self.image is not None
elements = self.element_extraction_model(self.image)
if inplace:
self.elements = elements
Expand All @@ -188,6 +189,7 @@ def get_elements_with_detection_model(

# NOTE(mrobinson) - We'll want make this model inference step some kind of
# remote call in the future.
assert self.image is not None
inferred_layout: List[LayoutElement] = self.detection_model(self.image)
inferred_layout = self.detection_model.deduplicate_detected_elements(
inferred_layout,
Expand Down
8 changes: 4 additions & 4 deletions unstructured_inference/models/chipper.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def get_bounding_box(
np.asarray(
[
agg_heatmap,
cv2.resize(
cv2.resize( # type: ignore
hmap,
(final_w, final_h),
interpolation=cv2.INTER_LINEAR_EXACT, # cv2.INTER_CUBIC
Expand Down Expand Up @@ -620,7 +620,7 @@ def reduce_bbox_no_overlap(
):
return input_bbox

nimage = np.array(image.crop(input_bbox))
nimage = np.array(image.crop(input_bbox)) # type: ignore

nimage = self.remove_horizontal_lines(nimage)

Expand Down Expand Up @@ -669,7 +669,7 @@ def reduce_bbox_overlap(
):
return input_bbox

nimage = np.array(image.crop(input_bbox))
nimage = np.array(image.crop(input_bbox)) # type: ignore

nimage = self.remove_horizontal_lines(nimage)

Expand Down Expand Up @@ -773,7 +773,7 @@ def largest_margin(
):
return None

nimage = np.array(image.crop(input_bbox))
nimage = np.array(image.crop(input_bbox)) # type: ignore

if nimage.shape[0] * nimage.shape[1] == 0:
return None
Expand Down
6 changes: 4 additions & 2 deletions unstructured_inference/models/detectron2.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Final, List, Optional, Union

Expand All @@ -7,7 +9,7 @@
is_detectron2_available,
)
from layoutparser.models.model_config import LayoutModelConfig
from PIL import Image
from PIL import Image as PILImage

from unstructured_inference.constants import ElementType
from unstructured_inference.inference.layoutelement import LayoutElement
Expand Down Expand Up @@ -65,7 +67,7 @@
class UnstructuredDetectronModel(UnstructuredObjectDetectionModel):
"""Unstructured model wrapper for Detectron2LayoutModel."""

def predict(self, x: Image):
def predict(self, x: PILImage.Image):
"""Makes a prediction using detectron2 model."""
super().predict(x)
prediction = self.model.detect(x)
Expand Down
6 changes: 3 additions & 3 deletions unstructured_inference/models/donut.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, Union

import torch
from PIL import Image
from PIL import Image as PILImage
from transformers import (
DonutProcessor,
VisionEncoderDecoderConfig,
Expand All @@ -16,7 +16,7 @@
class UnstructuredDonutModel(UnstructuredModel):
"""Unstructured model wrapper for Donut image transformer."""

def predict(self, x: Image):
def predict(self, x: PILImage.Image):
"""Make prediction using donut model"""
super().predict(x)
return self.run_prediction(x)
Expand Down Expand Up @@ -50,7 +50,7 @@ def initialize(
raise ImportError("Review the parameters to initialize a UnstructuredDonutModel obj")
self.model.to(device)

def run_prediction(self, x: Image):
def run_prediction(self, x: PILImage.Image):
"""Internal prediction method."""
pixel_values = self.processor(x, return_tensors="pt").pixel_values
decoder_input_ids = self.processor.tokenizer(
Expand Down
14 changes: 7 additions & 7 deletions unstructured_inference/models/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import cv2
import numpy as np
import torch
from PIL import Image
from PIL import Image as PILImage
from transformers import DetrImageProcessor, TableTransformerForObjectDetection

from unstructured_inference.config import inference_config
Expand All @@ -27,7 +27,7 @@ class UnstructuredTableTransformerModel(UnstructuredModel):
def __init__(self):
pass

def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None):
def predict(self, x: PILImage.Image, ocr_tokens: Optional[List[Dict]] = None):
"""Predict table structure deferring to run_prediction with ocr tokens
Note:
Expand Down Expand Up @@ -70,7 +70,7 @@ def initialize(

def get_structure(
self,
x: Image,
x: PILImage.Image,
pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
) -> dict:
"""get the table structure as a dictionary contaning different types of elements as
Expand All @@ -87,7 +87,7 @@ def get_structure(

def run_prediction(
self,
x: Image,
x: PILImage.Image,
pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
ocr_tokens: Optional[List[Dict]] = None,
result_format: Optional[str] = "html",
Expand Down Expand Up @@ -155,7 +155,7 @@ def get_class_map(data_type: str):
}


def recognize(outputs: dict, img: Image, tokens: list):
def recognize(outputs: dict, img: PILImage.Image, tokens: list):
"""Recognize table elements."""
str_class_name2idx = get_class_map("structure")
str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}
Expand Down Expand Up @@ -655,7 +655,7 @@ def cells_to_html(cells):
return str(ET.tostring(table, encoding="unicode", short_empty_elements=False))


def zoom_image(image: Image, zoom: float) -> Image:
def zoom_image(image: PILImage.Image, zoom: float) -> PILImage.Image:
"""scale an image based on the zoom factor using cv2; the scaled image is post processed by
dilation then erosion to improve edge sharpness for OCR tasks"""
if zoom <= 0:
Expand All @@ -673,4 +673,4 @@ def zoom_image(image: Image, zoom: float) -> Image:
new_image = cv2.dilate(new_image, kernel, iterations=1)
new_image = cv2.erode(new_image, kernel, iterations=1)

return Image.fromarray(new_image)
return PILImage.fromarray(new_image)
6 changes: 3 additions & 3 deletions unstructured_inference/models/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import onnxruntime
from huggingface_hub import hf_hub_download
from onnxruntime.capi import _pybind_state as C
from PIL import Image
from PIL import Image as PILImage

from unstructured_inference.constants import ElementType, Source
from unstructured_inference.inference.layoutelement import LayoutElement
Expand Down Expand Up @@ -60,7 +60,7 @@


class UnstructuredYoloXModel(UnstructuredObjectDetectionModel):
def predict(self, x: Image):
def predict(self, x: PILImage.Image):
"""Predict using YoloX model."""
super().predict(x)
return self.image_processing(x)
Expand All @@ -86,7 +86,7 @@ def initialize(self, model_path: str, label_map: dict):

def image_processing(
self,
image: Image = None,
image: PILImage.Image,
) -> List[LayoutElement]:
"""Method runing YoloX for layout detection, returns a PageLayout
parameters
Expand Down

0 comments on commit a3e4027

Please sign in to comment.