Skip to content

Commit

Permalink
feat: migration from onnx to pytorch script (#37)
Browse files (browse the repository at this point in the history)

---------

Signed-off-by: Ahmed <[email protected]>
Signed-off-by: Maxim Lysak <[email protected]>
Co-authored-by: Ahmed <[email protected]>
Co-authored-by: Maxim Lysak <[email protected]>
  • Loading branch information
3 people authored Oct 3, 2024
1 parent 2a8be46 commit 59e2941
Show file tree
Hide file tree
Showing 4 changed files with 705 additions and 457 deletions.
78 changes: 31 additions & 47 deletions docling_ibm_models/layoutmodel/layout_predictor.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,7 +7,8 @@
from typing import Union

import numpy as np
import onnxruntime as ort
import torch
import torchvision.transforms as T
from PIL import Image

MODEL_CHECKPOINT_FN = "model.pt"
Expand All @@ -16,14 +17,14 @@

class LayoutPredictor:
r"""
Document layout prediction using ONNX
Document layout prediction using torch
"""

def __init__(
self, artifact_path: str, num_threads: int = None, use_cpu_only: bool = False
):
r"""
Provide the artifact path that contains the LayoutModel ONNX file
Provide the artifact path that contains the LayoutModel file
The number of threads is decided, in the following order, by:
1. The init method parameter `num_threads`, if it is set.
Expand All @@ -38,13 +39,13 @@ def __init__(
Parameters
----------
artifact_path: Path for the model ONNX file.
artifact_path: Path for the model torch file.
num_threads: (Optional) Number of threads to run the inference.
use_cpu_only: (Optional) If True, it forces CPU as the execution provider.
Raises
------
FileNotFoundError when the model's ONNX file is missing
FileNotFoundError when the model's torch file is missing
"""
# Initialize classes map:
self._classes_map = {
Expand Down Expand Up @@ -75,46 +76,27 @@ def __init__(
self._threshold = 0.6 # Score threshold
self._image_size = 640
self._size = np.asarray([[self._image_size, self._image_size]], dtype=np.int64)
self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)

# Model file
self._torch_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
if not os.path.isfile(self._torch_fn):
raise FileNotFoundError("Missing torch file: {}".format(self._torch_fn))

# Get env vars
self._use_cpu_only = use_cpu_only or ("USE_CPU_ONLY" in os.environ)
if num_threads is None:
num_threads = int(os.environ.get("OMP_NUM_THREADS", DEFAULT_NUM_THREADS))
self._num_threads = num_threads

# Decide the execution providers
if (
not self._use_cpu_only
and "CUDAExecutionProvider" in ort.get_available_providers()
):
providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
providers = ["CPUExecutionProvider"]
self._providers = providers

# Model ONNX file
self._onnx_fn = os.path.join(artifact_path, MODEL_CHECKPOINT_FN)
if not os.path.isfile(self._onnx_fn):
raise FileNotFoundError("Missing ONNX file: {}".format(self._onnx_fn))

# ONNX options
self._options = ort.SessionOptions()
self._options.intra_op_num_threads = self._num_threads
self.sess = ort.InferenceSession(
self._onnx_fn,
sess_options=self._options,
providers=self._providers,
)
self.model = torch.jit.load(self._torch_fn)

def info(self) -> dict:
r"""
Get information about the configuration of LayoutPredictor
"""
info = {
"onnx_file": self._onnx_fn,
"intra_op_num_threads": self._num_threads,
"torch_file": self._torch_fn,
"use_cpu_only": self._use_cpu_only,
"providers": self._providers,
"image_size": self._image_size,
"threshold": self._threshold,
}
Expand Down Expand Up @@ -147,33 +129,35 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
raise TypeError("Not supported input image format")

w, h = page_img.size
page_img = page_img.resize((self._image_size, self._image_size))
page_data = np.array(page_img, dtype=np.uint8) / np.float32(255.0)
page_data = np.expand_dims(np.transpose(page_data, axes=[2, 0, 1]), axis=0)
orig_size = torch.tensor([w, h])[None]

# Predict
labels, boxes, scores = self.sess.run(
output_names=None,
input_feed={
"images": page_data,
"orig_target_sizes": self._size,
},
transforms = T.Compose(
[
T.Resize((640, 640)),
T.ToTensor(),
]
)
img = transforms(page_img)[None]
# Predict
with torch.no_grad():
labels, boxes, scores = self.model(img, orig_size)

# Yield output
for label_idx, box, score in zip(labels[0], boxes[0], scores[0]):
# Filter out blacklisted classes
label = self._classes_map[label_idx]
label_idx = int(label_idx.item())
score = float(score.item())
label = self._classes_map[label_idx + 1]
if label in self._black_classes:
continue

# Check against threshold
if score > self._threshold:
yield {
"l": box[0] / self._image_size * w,
"t": box[1] / self._image_size * h,
"r": box[2] / self._image_size * w,
"b": box[3] / self._image_size * h,
"l": box[0],
"t": box[1],
"r": box[2],
"b": box[3],
"label": label,
"confidence": score,
}
Loading

0 comments on commit 59e2941

Please sign in to comment.