Batch-prediction across multiple GPUs and more efficient patch-prediction #48

Open · wants to merge 14 commits into base: master
Changes from 1 commit
Rewrote binarization script to always use patches, but in a much more efficient way, and added support for batch-conversion with multiple GPUs.
apacha committed Aug 30, 2022
commit 4112c6fe71b2226ee872c1d2f4eaff6a62805e62
1 change: 1 addition & 0 deletions requirements.txt
@@ -3,3 +3,4 @@ setuptools >= 41
opencv-python-headless
ocrd >= 2.22.3
tensorflow >= 2.4.0
mpire
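mpire is the new dependency: it provides the worker pool used below to spread batches of images across GPUs. A minimal sketch of its API (the function and arguments here are illustrative, not part of this PR):

from mpire import WorkerPool

def square(x):
    return x * x

with WorkerPool(n_jobs=4) as pool:
    results = pool.map(square, range(10))  # [0, 1, 4, ..., 81]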
420 changes: 158 additions & 262 deletions sbb_binarize/sbb_binarize.py
@@ -1,272 +1,168 @@
"""
Tool to load model and binarize a given image.
"""
import argparse
import sys
from os import environ, devnull
import gc
import itertools
import math
import os
from pathlib import Path
from typing import Union
from typing import Union, List, Any

import cv2
import numpy as np

environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
stderr = sys.stderr
sys.stderr = open(devnull, 'w')
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.python.keras import backend as tensorflow_backend

sys.stderr = stderr

import logging


def resize_image(img_in, input_height, input_width):
return cv2.resize(img_in, (input_width, input_height), interpolation=cv2.INTER_NEAREST)
from mpire import WorkerPool
from mpire.utils import make_single_arguments
from tensorflow.python.keras.saving.save import load_model
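# Note: tensorflow.python.keras.saving.save is a private module path; the public
# equivalent is tensorflow.keras.models.load_model, which the old code used.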


class SbbBinarizer:

def __init__(self, model_dir: Union[str, Path], logger=None):
def __init__(self) -> None:
super().__init__()
self.model: Any = None
self.model_height: int = 0
self.model_width: int = 0
self.n_classes: int = 0

def load_model(self, model_dir: Union[str, Path]):
model_dir = Path(model_dir)
self.log = logger if logger else logging.getLogger('SbbBinarizer')

self.start_new_session()

self.model_files = list([str(p.absolute()) for p in model_dir.rglob("*.h5")])
if not self.model_files:
raise ValueError(f"No models found in {str(model_dir)}")

self.models = []
for model_file in self.model_files:
self.models.append(self.load_model(model_file))

def start_new_session(self):
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True

self.session = tf.compat.v1.Session(config=config) # tf.InteractiveSession()
tensorflow_backend.set_session(self.session)

def end_session(self):
tensorflow_backend.clear_session()
self.session.close()
del self.session

def load_model(self, model_path: str):
model = load_model(model_path, compile=False)
model_height = model.layers[len(model.layers) - 1].output_shape[1]
model_width = model.layers[len(model.layers) - 1].output_shape[2]
n_classes = model.layers[len(model.layers) - 1].output_shape[3]
return model, model_height, model_width, n_classes

def predict(self, model_in, img, use_patches):
tensorflow_backend.set_session(self.session)
model, model_height, model_width, n_classes = model_in

img_org_h = img.shape[0]
img_org_w = img.shape[1]

if img.shape[0] < model_height and img.shape[1] >= model_width:
img_padded = np.zeros((model_height, img.shape[1], img.shape[2]))

index_start_h = int(abs(img.shape[0] - model_height) / 2.)
index_start_w = 0

img_padded[index_start_h: index_start_h + img.shape[0], :, :] = img[:, :, :]

elif img.shape[0] >= model_height and img.shape[1] < model_width:
img_padded = np.zeros((img.shape[0], model_width, img.shape[2]))

index_start_h = 0
index_start_w = int(abs(img.shape[1] - model_width) / 2.)

img_padded[:, index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]


elif img.shape[0] < model_height and img.shape[1] < model_width:
img_padded = np.zeros((model_height, model_width, img.shape[2]))

index_start_h = int(abs(img.shape[0] - model_height) / 2.)
index_start_w = int(abs(img.shape[1] - model_width) / 2.)

img_padded[index_start_h: index_start_h + img.shape[0], index_start_w: index_start_w + img.shape[1], :] = img[:, :, :]

else:
index_start_h = 0
index_start_w = 0
img_padded = np.copy(img)

img = np.copy(img_padded)

if use_patches:

margin = int(0.1 * model_width)

width_mid = model_width - 2 * margin
height_mid = model_height - 2 * margin

img = img / float(255.0)

img_h = img.shape[0]
img_w = img.shape[1]

prediction_true = np.zeros((img_h, img_w, 3))
mask_true = np.zeros((img_h, img_w))
nxf = img_w / float(width_mid)
nyf = img_h / float(height_mid)

if nxf > int(nxf):
nxf = int(nxf) + 1
else:
nxf = int(nxf)

if nyf > int(nyf):
nyf = int(nyf) + 1
else:
nyf = int(nyf)

for i in range(nxf):
for j in range(nyf):

if i == 0:
index_x_d = i * width_mid
index_x_u = index_x_d + model_width
elif i > 0:
index_x_d = i * width_mid
index_x_u = index_x_d + model_width

if j == 0:
index_y_d = j * height_mid
index_y_u = index_y_d + model_height
elif j > 0:
index_y_d = j * height_mid
index_y_u = index_y_d + model_height

if index_x_u > img_w:
index_x_u = img_w
index_x_d = img_w - model_width
if index_y_u > img_h:
index_y_u = img_h
index_y_d = img_h - model_height

img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]

label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]))

seg = np.argmax(label_p_pred, axis=3)[0]

seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)

if i == 0 and j == 0:
seg_color = seg_color[0:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :]
seg = seg[0:seg.shape[0] - margin, 0:seg.shape[1] - margin]

mask_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color

elif i == nxf - 1 and j == nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - 0, :]
seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - 0]

mask_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0] = seg
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - 0, :] = seg_color

elif i == 0 and j == nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - 0, 0:seg_color.shape[1] - margin, :]
seg = seg[margin:seg.shape[0] - 0, 0:seg.shape[1] - margin]

mask_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin] = seg
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + 0:index_x_u - margin, :] = seg_color

elif i == nxf - 1 and j == 0:
seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :]
seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - 0]

mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color

elif i == 0 and j != 0 and j != nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - margin, 0:seg_color.shape[1] - margin, :]
seg = seg[margin:seg.shape[0] - margin, 0:seg.shape[1] - margin]

mask_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin] = seg
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + 0:index_x_u - margin, :] = seg_color

elif i == nxf - 1 and j != 0 and j != nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - margin, margin:seg_color.shape[1] - 0, :]
seg = seg[margin:seg.shape[0] - margin, margin:seg.shape[1] - 0]

mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0] = seg
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - 0, :] = seg_color

elif i != 0 and i != nxf - 1 and j == 0:
seg_color = seg_color[0:seg_color.shape[0] - margin, margin:seg_color.shape[1] - margin, :]
seg = seg[0:seg.shape[0] - margin, margin:seg.shape[1] - margin]

mask_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
prediction_true[index_y_d + 0:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color

elif i != 0 and i != nxf - 1 and j == nyf - 1:
seg_color = seg_color[margin:seg_color.shape[0] - 0, margin:seg_color.shape[1] - margin, :]
seg = seg[margin:seg.shape[0] - 0, margin:seg.shape[1] - margin]

mask_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - margin] = seg
prediction_true[index_y_d + margin:index_y_u - 0, index_x_d + margin:index_x_u - margin, :] = seg_color

else:
seg_color = seg_color[margin:seg_color.shape[0] - margin, margin:seg_color.shape[1] - margin, :]
seg = seg[margin:seg.shape[0] - margin, margin:seg.shape[1] - margin]

mask_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin] = seg
prediction_true[index_y_d + margin:index_y_u - margin, index_x_d + margin:index_x_u - margin, :] = seg_color

prediction_true = prediction_true[index_start_h: index_start_h + img_org_h, index_start_w: index_start_w + img_org_w, :]
prediction_true = prediction_true.astype(np.uint8)

else:
img_h_page = img.shape[0]
img_w_page = img.shape[1]
img = img / float(255.0)
img = resize_image(img, model_height, model_width)

label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))

seg = np.argmax(label_p_pred, axis=3)[0]
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)
prediction_true = resize_image(seg_color, img_h_page, img_w_page)
prediction_true = prediction_true.astype(np.uint8)
return prediction_true[:, :, 0]
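# NOTE: the patch branch above issues one model.predict() call per tile and
# stitches the results with overlapping margins; the rewrite below replaces it
# with a single batched predict over all tiles, which is where the speedup comes from.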

def run(self, image=None, image_path=None, save=None, use_patches=False):
if (image is not None and image_path is not None) or (image is None and image_path is None):
raise ValueError("Must pass either a opencv2 image or an image_path")
if image_path is not None:
image = cv2.imread(image_path)
img_last = 0
for n, (model, model_file) in enumerate(zip(self.models, self.model_files)):
self.log.info(f"Predicting with model {model_file} [{n + 1}/{len(self.model_files)}]")

res = self.predict(model, image, use_patches)

img_fin = np.zeros((res.shape[0], res.shape[1], 3))
res[:, :][res[:, :] == 0] = 2
res = res - 1
res = res * 255
img_fin[:, :, 0] = res
img_fin[:, :, 1] = res
img_fin[:, :, 2] = res

img_fin = img_fin.astype(np.uint8)
img_fin = (res[:, :] == 0) * 255
img_last = img_last + img_fin

kernel = np.ones((5, 5), np.uint8)
img_last[:, :][img_last[:, :] > 0] = 255
img_last = (img_last[:, :] == 0) * 255
if save:
# Create the output directory (and, if necessary, its parents) if it doesn't exist already
Path(save).parent.mkdir(parents=True, exist_ok=True)
cv2.imwrite(save, img_last)
return img_last
self.model = load_model(str(model_dir.absolute()), compile=False)
self.model_height = self.model.layers[len(self.model.layers) - 1].output_shape[1]
self.model_width = self.model.layers[len(self.model.layers) - 1].output_shape[2]
self.n_classes = self.model.layers[len(self.model.layers) - 1].output_shape[3]
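# The final layer of these segmentation models has output_shape
# (batch, height, width, n_classes), so patch size and class count can be read
# directly from the loaded model.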

def binarize_image(self, image_path: Path, save_path: Path):
if not image_path.exists():
raise ValueError(f"Image not found: {str(image_path)}")

# Most operations expect BGR, since that is how cv2.imread returns images by default
# noinspection PyUnresolvedReferences
img = cv2.imread(str(image_path))
original_image_height, original_image_width, image_channels = img.shape

# Pad the image so its dimensions are multiples of the model's input size
padded_image_height = math.ceil(original_image_height / self.model_height) * self.model_height
padded_image_width = math.ceil(original_image_width / self.model_width) * self.model_width
padded_image = np.zeros((padded_image_height, padded_image_width, image_channels))
padded_image[0:original_image_height, 0:original_image_width, :] = img[:, :, :]

image_batch = np.expand_dims(padded_image, 0) # To create the batch information
patches = tf.image.extract_patches(
images=image_batch,
sizes=[1, self.model_height, self.model_width, 1],
strides=[1, self.model_height, self.model_width, 1],
rates=[1, 1, 1, 1],
padding='SAME'
)
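# With stride == patch size and the image already padded to exact multiples,
# extract_patches tiles the image exactly: e.g. a (1, 1024, 1024, 3) batch with
# 512x512 patches yields shape (1, 2, 2, 512 * 512 * 3), one flattened patch per tile.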

number_of_vertical_patches = patches.shape[1]
number_of_horizontal_patches = patches.shape[2]
total_number_of_patches = number_of_horizontal_patches * number_of_vertical_patches
target_shape = (total_number_of_patches, self.model_height, self.model_width, image_channels)
# Squeeze the patch grid (1, rows, cols, height * width * channels) into a single big batch (batch, height, width, channels)
image_patches = tf.reshape(patches, target_shape)
# Normalize pixel values to the range 0.0 to 1.0
image_patches = image_patches / float(255.0)

predicted_patches = self.model.predict(image_patches)
# We have to manually call garbage collection and clear_session here to avoid memory leaks.
# Taken from https://medium.com/dive-into-ml-ai/dealing-with-memory-leak-issue-in-keras-model-training-e703907a6501
gc.collect()
tf.keras.backend.clear_session()

binary_patches = np.invert(np.argmax(predicted_patches, axis=3).astype(bool)).astype(np.uint8) * 255
full_image_with_padding = self._patches_to_image(
binary_patches,
padded_image_height,
padded_image_width,
self.model_height,
self.model_width
)
full_image = full_image_with_padding[0:original_image_height, 0:original_image_width]
Path(save_path).parent.mkdir(parents=True, exist_ok=True)
# noinspection PyUnresolvedReferences
cv2.imwrite(str(save_path), full_image)
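A minimal usage sketch for this class (paths are illustrative):

binarizer = SbbBinarizer()
binarizer.load_model("model_2021_03_09")
binarizer.binarize_image(image_path=Path("pages/page_0001.png"), save_path=Path("out/page_0001.png"))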

def _patches_to_image(
self,
patches: np.ndarray,
image_height: int,
image_width: int,
patch_height: int,
patch_width: int
):
height = math.ceil(image_height / patch_height) * patch_height
width = math.ceil(image_width / patch_width) * patch_width

image_reshaped = np.reshape(
np.squeeze(patches),
[height // patch_height, width // patch_width, patch_height, patch_width]
)
image_transposed = np.transpose(a=image_reshaped, axes=[0, 2, 1, 3])
image_resized = np.reshape(image_transposed, [height, width])
return image_resized
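A quick way to convince yourself that _patches_to_image inverts the extract_patches tiling above (toy sizes, illustrative only):

img = np.arange(24, dtype=np.float32).reshape(1, 4, 6, 1)  # toy 4x6 single-channel image
p = tf.image.extract_patches(images=img, sizes=[1, 2, 3, 1], strides=[1, 2, 3, 1], rates=[1, 1, 1, 1], padding='SAME')
patches = tf.reshape(p, (4, 2, 3)).numpy()  # four 2x3 patches
restored = SbbBinarizer()._patches_to_image(patches, image_height=4, image_width=6, patch_height=2, patch_width=3)
assert (restored == img[0, :, :, 0]).all()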


def split_list_into_worker_batches(files: List[Any], number_of_workers: int) -> List[List[Any]]:
""" Splits any given list into batches for the specified number of workers and returns a list of lists. """
batches = []
batch_size = math.ceil(len(files) / number_of_workers)
batch_start = 0
for i in range(1, number_of_workers + 1):
batch_end = i * batch_size
file_batch = files[batch_start: batch_end]
batches.append(file_batch)
batch_start = batch_end
return batches
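For example, five files split across two workers (when the list does not divide evenly, the last batch is shorter and may even be empty):

>>> split_list_into_worker_batches([1, 2, 3, 4, 5], number_of_workers=2)
[[1, 2, 3], [4, 5]]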


def batch_predict(input_data):
model_dir, input_images, output_images, worker_number = input_data
print(f"Setting visible cuda devices to {str(worker_number)}")
os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_number)

binarizer = SbbBinarizer()
binarizer.load_model(model_dir)

for image_path, output_path in zip(input_images, output_images):
binarizer.binarize_image(image_path=image_path, save_path=output_path)
print(f"Binarized {image_path}")


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_dir', default="model_2021_03_09", help="Path to the directory where the TF model resides or path to an h5 file.")
parser.add_argument('-i', '--input-path', required=True)
parser.add_argument('-o', '--output-path', required=True)
args = parser.parse_args()

input_path = Path(args.input_path)
output_path = Path(args.output_path)
model_directory = args.model_dir

if input_path.is_dir():
print(f"Enumerating all PNG files in {str(input_path)}")
all_input_images = list(input_path.rglob("*.png"))
print(f"Filtering images that have already been binarized in {str(output_path)}")
input_images = [i for i in all_input_images if not (output_path / (i.relative_to(input_path))).exists()]
output_images = [output_path / (i.relative_to(input_path)) for i in input_images]

print(f"Starting binarization of {len(input_images)} images")

number_of_gpus = len(tf.config.list_physical_devices('GPU'))
number_of_workers = max(1, number_of_gpus)
image_batches = split_list_into_worker_batches(input_images, number_of_workers)
output_batches = split_list_into_worker_batches(output_images, number_of_workers)

with WorkerPool(n_jobs=number_of_workers, start_method='spawn') as pool:
model_dirs = itertools.repeat(model_directory, len(image_batches))
input_data = zip(model_dirs, image_batches, output_batches, range(number_of_workers))
contents = pool.map_unordered(
batch_predict,
make_single_arguments(input_data),
iterable_len=number_of_workers,
progress_bar=False
)
else:
binarizer = SbbBinarizer()
binarizer.load_model(model_directory)
binarizer.binarize_image(image_path=input_path, save_path=output_path)
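Invoked directly, the script can then be run like this (paths illustrative):

python sbb_binarize/sbb_binarize.py -m model_2021_03_09 -i /data/pages -o /data/binarized

Given a directory, it recursively binarizes every PNG that has not been binarized yet, spreading the work across all visible GPUs; given a single file, it binarizes that one image in-process.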