Skip to content

Commit

Permalink
Inference with batch size greater than 1
Browse files (browse the repository at this point in the history)
  • Loading branch information
vahidrezanezhad committed Aug 23, 2024
1 parent 4f8210d commit c10a525
Showing 1 changed file with 100 additions and 72 deletions.
172 changes: 100 additions & 72 deletions qurator/eynollah/eynollah.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,11 +548,11 @@ def resize_and_enhance_image_with_column_classifier(self,light_version):
if self.input_binary:
img = self.imread()
if self.dir_in:
prediction_bin = self.do_prediction(True, img, self.model_bin)
prediction_bin = self.do_prediction(True, img, self.model_bin, n_batch_inference=5)
else:

model_bin, session_bin = self.start_new_session_and_model(self.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img, model_bin)
prediction_bin = self.do_prediction(True, img, model_bin, n_batch_inference=5)

prediction_bin=prediction_bin[:,:,0]
prediction_bin = (prediction_bin[:,:]==0)*1
Expand Down Expand Up @@ -703,7 +703,7 @@ def start_new_session_and_model(self, model_dir):

return model, None

def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1):
def do_prediction(self, patches, img, model, n_batch_inference=1, marginal_of_patch_percent=0.1):
self.logger.debug("enter do_prediction")

img_height_model = model.layers[len(model.layers) - 1].output_shape[1]
Expand Down Expand Up @@ -745,7 +745,17 @@ def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1):
nyf = img_h / float(height_mid)
nxf = int(nxf) + 1 if nxf > int(nxf) else int(nxf)
nyf = int(nyf) + 1 if nyf > int(nyf) else int(nyf)


list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []

batch_indexer = 0

img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3))
for i in range(nxf):
for j in range(nyf):
if i == 0:
Expand All @@ -766,59 +776,77 @@ def do_prediction(self, patches, img, model, marginal_of_patch_percent=0.1):
if index_y_u > img_h:
index_y_u = img_h
index_y_d = img_h - img_height_model

list_i_s.append(i)
list_j_s.append(j)
list_x_u.append(index_x_u)
list_x_d.append(index_x_d)
list_y_d.append(index_y_d)
list_y_u.append(index_y_u)


img_patch = img[index_y_d:index_y_u, index_x_d:index_x_u, :]
label_p_pred = model.predict(img_patch.reshape(1, img_patch.shape[0], img_patch.shape[1], img_patch.shape[2]),
verbose=0)
seg = np.argmax(label_p_pred, axis=3)[0]
seg_color = np.repeat(seg[:, :, np.newaxis], 3, axis=2)

if i == 0 and j == 0:
seg_color = seg_color[0 : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
#seg = seg[0 : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
#mask_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin] = seg
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg_color
elif i == nxf - 1 and j == nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - 0, :]
#seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - 0]
#mask_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0] = seg
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - 0, :] = seg_color
elif i == 0 and j == nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - 0, 0 : seg_color.shape[1] - margin, :]
#seg = seg[margin : seg.shape[0] - 0, 0 : seg.shape[1] - margin]
#mask_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin] = seg
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + 0 : index_x_u - margin, :] = seg_color
elif i == nxf - 1 and j == 0:
seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
#seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - 0]
#mask_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0] = seg
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg_color
elif i == 0 and j != 0 and j != nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
#seg = seg[margin : seg.shape[0] - margin, 0 : seg.shape[1] - margin]
#mask_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin] = seg
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + 0 : index_x_u - margin, :] = seg_color
elif i == nxf - 1 and j != 0 and j != nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
#seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - 0]
#mask_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0] = seg
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - 0, :] = seg_color
elif i != 0 and i != nxf - 1 and j == 0:
seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
#seg = seg[0 : seg.shape[0] - margin, margin : seg.shape[1] - margin]
#mask_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin] = seg
prediction_true[index_y_d + 0 : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg_color
elif i != 0 and i != nxf - 1 and j == nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - margin, :]
#seg = seg[margin : seg.shape[0] - 0, margin : seg.shape[1] - margin]
#mask_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin] = seg
prediction_true[index_y_d + margin : index_y_u - 0, index_x_d + margin : index_x_u - margin, :] = seg_color
else:
seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
#seg = seg[margin : seg.shape[0] - margin, margin : seg.shape[1] - margin]
#mask_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin] = seg
prediction_true[index_y_d + margin : index_y_u - margin, index_x_d + margin : index_x_u - margin, :] = seg_color

img_patch[batch_indexer,:,:,:] = img[index_y_d:index_y_u, index_x_d:index_x_u, :]

batch_indexer = batch_indexer + 1

if batch_indexer == n_batch_inference:

label_p_pred = model.predict(img_patch,verbose=0)

seg = np.argmax(label_p_pred, axis=3)

indexer_inside_batch = 0
for i_batch, j_batch in zip(list_i_s, list_j_s):
seg_in = seg[indexer_inside_batch,:,:]
seg_color = np.repeat(seg_in[:, :, np.newaxis], 3, axis=2)

index_y_u_in = list_y_u[indexer_inside_batch]
index_y_d_in = list_y_d[indexer_inside_batch]

index_x_u_in = list_x_u[indexer_inside_batch]
index_x_d_in = list_x_d[indexer_inside_batch]

if i_batch == 0 and j_batch == 0:
seg_color = seg_color[0 : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
elif i_batch == nxf - 1 and j_batch == nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - 0, :]
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
elif i_batch == 0 and j_batch == nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - 0, 0 : seg_color.shape[1] - margin, :]
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
elif i_batch == nxf - 1 and j_batch == 0:
seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
elif i_batch == 0 and j_batch != 0 and j_batch != nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - margin, 0 : seg_color.shape[1] - margin, :]
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + 0 : index_x_u_in - margin, :] = seg_color
elif i_batch == nxf - 1 and j_batch != 0 and j_batch != nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - 0, :]
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - 0, :] = seg_color
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == 0:
seg_color = seg_color[0 : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
prediction_true[index_y_d_in + 0 : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
elif i_batch != 0 and i_batch != nxf - 1 and j_batch == nyf - 1:
seg_color = seg_color[margin : seg_color.shape[0] - 0, margin : seg_color.shape[1] - margin, :]
prediction_true[index_y_d_in + margin : index_y_u_in - 0, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color
else:
seg_color = seg_color[margin : seg_color.shape[0] - margin, margin : seg_color.shape[1] - margin, :]
prediction_true[index_y_d_in + margin : index_y_u_in - margin, index_x_d_in + margin : index_x_u_in - margin, :] = seg_color

indexer_inside_batch = indexer_inside_batch +1


list_i_s = []
list_j_s = []
list_x_u = []
list_x_d = []
list_y_u = []
list_y_d = []

batch_indexer = 0

img_patch = np.zeros((n_batch_inference, img_height_model, img_width_model, 3))
prediction_true = prediction_true.astype(np.uint8)
#del model
#gc.collect()
Expand All @@ -835,7 +863,7 @@ def do_prediction_new_concept(self, patches, img, model, marginal_of_patch_perce
img = img / float(255.0)
img = resize_image(img, img_height_model, img_width_model)

label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]))
label_p_pred = model.predict(img.reshape(1, img.shape[0], img.shape[1], img.shape[2]), verbose=0)


seg = np.argmax(label_p_pred, axis=3)[0]
Expand Down Expand Up @@ -1147,7 +1175,7 @@ def extract_text_regions_new(self, img, patches, cols):

marginal_of_patch_percent = 0.1

prediction_regions = self.do_prediction(patches, img, model_region, marginal_of_patch_percent)
prediction_regions = self.do_prediction(patches, img, model_region, marginal_of_patch_percent=marginal_of_patch_percent)

prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h)
self.logger.debug("exit extract_text_regions")
Expand All @@ -1173,23 +1201,23 @@ def extract_text_regions(self, img, patches, cols):
img2 = img2.astype(np.uint8)
img2 = resize_image(img2, int(img_height_h * 0.7), int(img_width_h * 0.7))
marginal_of_patch_percent = 0.1
prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent)
prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent=marginal_of_patch_percent)
prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h)

if cols == 2:
img2 = otsu_copy_binary(img)
img2 = img2.astype(np.uint8)
img2 = resize_image(img2, int(img_height_h * 0.4), int(img_width_h * 0.4))
marginal_of_patch_percent = 0.1
prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent)
prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent=marginal_of_patch_percent)
prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h)

elif cols > 2:
img2 = otsu_copy_binary(img)
img2 = img2.astype(np.uint8)
img2 = resize_image(img2, int(img_height_h * 0.3), int(img_width_h * 0.3))
marginal_of_patch_percent = 0.1
prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent)
prediction_regions2 = self.do_prediction(patches, img2, model_region, marginal_of_patch_percent=marginal_of_patch_percent)
prediction_regions2 = resize_image(prediction_regions2, img_height_h, img_width_h)

if cols == 2:
Expand Down Expand Up @@ -1245,7 +1273,7 @@ def extract_text_regions(self, img, patches, cols):
img= resize_image(img, int(img_height_h * 0.9), int(img_width_h * 0.9))

marginal_of_patch_percent = 0.1
prediction_regions = self.do_prediction(patches, img, model_region, marginal_of_patch_percent)
prediction_regions = self.do_prediction(patches, img, model_region, marginal_of_patch_percent=marginal_of_patch_percent)
prediction_regions = resize_image(prediction_regions, img_height_h, img_width_h)
self.logger.debug("exit extract_text_regions")
return prediction_regions, prediction_regions2
Expand Down Expand Up @@ -1634,9 +1662,9 @@ def textline_contours(self, img, patches, scaler_h, scaler_w):
img = resize_image(img_org, int(img_org.shape[0] * scaler_h), int(img_org.shape[1] * scaler_w))
#print(img.shape,'bin shape')
if not self.dir_in:
prediction_textline = self.do_prediction(patches, img, model_textline)
prediction_textline = self.do_prediction(patches, img, model_textline, n_batch_inference=4)
else:
prediction_textline = self.do_prediction(patches, img, self.model_textline)
prediction_textline = self.do_prediction(patches, img, self.model_textline, n_batch_inference=4)
prediction_textline = resize_image(prediction_textline, img_h, img_w)
if not self.dir_in:
prediction_textline_longshot = self.do_prediction(False, img, model_textline)
Expand Down Expand Up @@ -1721,9 +1749,9 @@ def get_regions_light_v(self,img,is_image_enhanced, num_col_classifier):

if not self.dir_in:
model_bin, session_bin = self.start_new_session_and_model(self.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_resized, model_bin)
prediction_bin = self.do_prediction(True, img_resized, model_bin, n_batch_inference=5)
else:
prediction_bin = self.do_prediction(True, img_resized, self.model_bin)
prediction_bin = self.do_prediction(True, img_resized, self.model_bin, n_batch_inference=5)
prediction_bin=prediction_bin[:,:,0]
prediction_bin = (prediction_bin[:,:]==0)*1
prediction_bin = prediction_bin*255
Expand Down Expand Up @@ -1870,9 +1898,9 @@ def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):
img = resize_image(img_org, int(img_org.shape[0]), int(img_org.shape[1]))

if self.dir_in:
prediction_regions_org2 = self.do_prediction(True, img, self.model_region_p2, 0.2)
prediction_regions_org2 = self.do_prediction(True, img, self.model_region_p2, marginal_of_patch_percent=0.2)
else:
prediction_regions_org2 = self.do_prediction(True, img, model_region, 0.2)
prediction_regions_org2 = self.do_prediction(True, img, model_region, marginal_of_patch_percent=0.2)
prediction_regions_org2=resize_image(prediction_regions_org2, img_height_h, img_width_h )


Expand Down Expand Up @@ -1905,9 +1933,9 @@ def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):
else:
if not self.dir_in:
model_bin, session_bin = self.start_new_session_and_model(self.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_org, model_bin)
prediction_bin = self.do_prediction(True, img_org, model_bin, n_batch_inference=5)
else:
prediction_bin = self.do_prediction(True, img_org, self.model_bin)
prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )

prediction_bin=prediction_bin[:,:,0]
Expand Down Expand Up @@ -1958,9 +1986,9 @@ def get_regions_from_xy_2models(self,img,is_image_enhanced, num_col_classifier):

if not self.dir_in:
model_bin, session_bin = self.start_new_session_and_model(self.model_dir_of_binarization)
prediction_bin = self.do_prediction(True, img_org, model_bin)
prediction_bin = self.do_prediction(True, img_org, model_bin, n_batch_inference=5)
else:
prediction_bin = self.do_prediction(True, img_org, self.model_bin)
prediction_bin = self.do_prediction(True, img_org, self.model_bin, n_batch_inference=5)
prediction_bin = resize_image(prediction_bin, img_height_h, img_width_h )
prediction_bin=prediction_bin[:,:,0]

Expand Down

0 comments on commit c10a525

Please sign in to comment.