diff --git a/easyocr/easyocr.py b/easyocr/easyocr.py index c08fe0388d..b62e4b3ef7 100644 --- a/easyocr/easyocr.py +++ b/easyocr/easyocr.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from .recognition import get_recognizer, get_text +from .recognition import get_recognizer, get_text, get_text_prob from .utils import group_text_box, get_image_list, calculate_md5, get_paragraph,\ download_and_unzip, printProgressBar, diff, reformat_input,\ make_rotated_img_list, set_result_with_confidence,\ @@ -350,6 +350,93 @@ def detect(self, img, min_size = 20, text_threshold = 0.7, low_text = 0.4,\ return horizontal_list_agg, free_list_agg + def recognize_prob(self, img_cv_grey, horizontal_list=None, free_list=None,\ + decoder = 'greedy', beamWidth= 5, batch_size = 1,\ + workers = 0, allowlist = None, blocklist = None, detail = 1,\ + rotation_info = None,paragraph = False,\ + contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\ + y_ths = 0.5, x_ths = 1.0, reformat=True, output_format='standard'): + + if reformat: + img, img_cv_grey = reformat_input(img_cv_grey) + + if allowlist: + ignore_char = ''.join(set(self.character)-set(allowlist)) + elif blocklist: + ignore_char = ''.join(set(blocklist)) + else: + ignore_char = ''.join(set(self.character)-set(self.lang_char)) + + if self.model_lang in ['chinese_tra','chinese_sim']: decoder = 'greedy' + + if (horizontal_list==None) and (free_list==None): + y_max, x_max = img_cv_grey.shape + horizontal_list = [[0, x_max, 0, y_max]] + free_list = [] + + # without gpu/parallelization, it is faster to process image one by one + if ((batch_size == 1) or (self.device == 'cpu')) and not rotation_info: + result = [] + for bbox in horizontal_list: + h_list = [bbox] + f_list = [] + image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH) + result0 = get_text_prob(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\ + ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\ + workers, self.device) + result += result0 + for bbox in free_list: + h_list = [] + f_list = [bbox] + image_list, max_width = get_image_list(h_list, f_list, img_cv_grey, model_height = imgH) + result0 = get_text_prob(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\ + ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\ + workers, self.device) + result += result0 + # default mode will try to process multiple boxes at the same time + else: + image_list, max_width = get_image_list(horizontal_list, free_list, img_cv_grey, model_height = imgH) + image_len = len(image_list) + if rotation_info and image_list: + image_list = make_rotated_img_list(rotation_info, image_list) + max_width = max(max_width, imgH) + + result = get_text_prob(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\ + ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\ + workers, self.device) + + if rotation_info and (horizontal_list+free_list): + # Reshape result to be a list of lists, each row being for + # one of the rotations (first row being no rotation) + result = set_result_with_confidence( + [result[image_len*i:image_len*(i+1)] for i in range(len(rotation_info) + 1)]) + + if self.model_lang == 'arabic': + direction_mode = 'rtl' + result = [list(item) for item in result] + for item in result: + item[1] = get_display(item[1]) + else: + direction_mode = 'ltr' + + if paragraph: + result = get_paragraph(result, x_ths=x_ths, y_ths=y_ths, mode = direction_mode) + + if detail == 0: + return [item[1] for item in result] + elif output_format == 'dict': + if paragraph: + return [ {'boxes':item[0],'text':item[1]} for item in result] + return [ {'boxes':item[0],'text':item[1],'confident':item[2]} for item in result] + elif output_format == 'json': + if paragraph: + return [json.dumps({'boxes':[list(map(int, lst)) for lst in item[0]],'text':item[1]}, ensure_ascii=False) for item in result] + return [json.dumps({'boxes':[list(map(int, lst)) for lst in item[0]],'text':item[1],'confident':item[2]}, ensure_ascii=False) for item in result] + elif output_format == 'free_merge': + return merge_to_free(result, free_list) + else: + return result + def recognize(self, img_cv_grey, horizontal_list=None, free_list=None,\ decoder = 'greedy', beamWidth= 5, batch_size = 1,\ workers = 0, allowlist = None, blocklist = None, detail = 1,\ @@ -472,6 +559,42 @@ def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\ filter_ths, y_ths, x_ths, False, output_format) return result + + def readtext_prob(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\ + workers = 0, allowlist = None, blocklist = None, detail = 1,\ + rotation_info = None, paragraph = False, min_size = 20,\ + contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\ + text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\ + canvas_size = 2560, mag_ratio = 1.,\ + slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\ + width_ths = 0.5, y_ths = 0.5, x_ths = 1.0, add_margin = 0.1, + threshold = 0.2, bbox_min_score = 0.2, bbox_min_size = 3, max_candidates = 0, + output_format='standard'): + ''' + Parameters: + image: file path or numpy-array or a byte stream object + ''' + img, img_cv_grey = reformat_input(image) + + horizontal_list, free_list = self.detect(img, + min_size = min_size, text_threshold = text_threshold,\ + low_text = low_text, link_threshold = link_threshold,\ + canvas_size = canvas_size, mag_ratio = mag_ratio,\ + slope_ths = slope_ths, ycenter_ths = ycenter_ths,\ + height_ths = height_ths, width_ths= width_ths,\ + add_margin = add_margin, reformat = False,\ + threshold = threshold, bbox_min_score = bbox_min_score,\ + bbox_min_size = bbox_min_size, max_candidates = max_candidates + ) + # get the 1st result from hor & free list as self.detect returns a list of depth 3 + horizontal_list, free_list = horizontal_list[0], free_list[0] + result = self.recognize_prob(img_cv_grey, horizontal_list, free_list,\ + decoder, beamWidth, batch_size,\ + workers, allowlist, blocklist, detail, rotation_info,\ + paragraph, contrast_ths, adjust_contrast,\ + filter_ths, y_ths, x_ths, False, output_format) + + return result def readtextlang(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\ workers = 0, allowlist = None, blocklist = None, detail = 1,\ @@ -577,3 +700,18 @@ def readtext_batched(self, image, n_width=None, n_height=None,\ filter_ths, y_ths, x_ths, False, output_format)) return result_agg + + +def convert_prob_to_word(prob, converter): + """ + For use with the readtest_prob outputs. + + - prob should be 2d + - convert = reader.converter + """ + assert prob.ndim == 2 + preds_index = np.argmax(prob, axis=1) + preds_index = preds_index.flatten() + preds_size = np.array([prob.shape[0]]) + preds_str = converter.decode_greedy(preds_index, preds_size)[0] + return preds_str \ No newline at end of file diff --git a/easyocr/recognition.py b/easyocr/recognition.py index 530ef9517e..bdae0c5cdc 100644 --- a/easyocr/recognition.py +++ b/easyocr/recognition.py @@ -150,6 +150,48 @@ def recognizer_predict(model, converter, test_loader, batch_max_length,\ return result +def recognizer_predict_prob(model, converter, test_loader, batch_max_length,\ + ignore_idx, char_group_idx, decoder = 'greedy', beamWidth= 5, device = 'cpu'): + model.eval() + result = [] + with torch.no_grad(): + for image_tensors in test_loader: + batch_size = image_tensors.size(0) + image = image_tensors.to(device) + # For max length prediction + length_for_pred = torch.IntTensor([batch_max_length] * batch_size).to(device) + text_for_pred = torch.LongTensor(batch_size, batch_max_length + 1).fill_(0).to(device) + + preds = model(image, text_for_pred) + + # Select max probabilty (greedy decoding) then decode index to character + preds_size = torch.IntTensor([preds.size(1)] * batch_size) + + ######## filter ignore_char, rebalance + preds_prob = F.softmax(preds, dim=2) + preds_prob = preds_prob.cpu().detach().numpy() + preds_prob[:,:,ignore_idx] = 0. + pred_norm = preds_prob.sum(axis=2) + preds_prob = preds_prob/np.expand_dims(pred_norm, axis=-1) + preds_prob = torch.from_numpy(preds_prob).float().to(device) + preds_prob = preds_prob.cpu().detach().numpy() + + values = preds_prob.max(axis=2) + indices = preds_prob.argmax(axis=2) + preds_max_prob = [] + for v,i in zip(values, indices): + max_probs = v[i!=0] # this removes blanks + if len(max_probs)>0: + preds_max_prob.append(max_probs) + else: + preds_max_prob.append(np.array([0])) + + for pred_max_prob in preds_max_prob: + confidence_score = custom_mean(pred_max_prob) + result.append([preds_prob, confidence_score]) + + return result + def get_recognizer(recog_network, network_params, character,\ separator_list, dict_list, model_path,\ device = 'cpu', quantize = True): @@ -231,3 +273,52 @@ def get_text(character, imgH, imgW, recognizer, converter, image_list,\ result.append( (box, pred1[0], pred1[1]) ) return result + +def get_text_prob(character, imgH, imgW, recognizer, converter, image_list,\ + ignore_char = '',decoder = 'greedy', beamWidth =5, batch_size=1, contrast_ths=0.1,\ + adjust_contrast=0.5, filter_ths = 0.003, workers = 1, device = 'cpu'): + batch_max_length = int(imgW/10) + + char_group_idx = {} + ignore_idx = [] + for char in ignore_char: + try: ignore_idx.append(character.index(char)+1) + except: pass + + coord = [item[0] for item in image_list] + img_list = [item[1] for item in image_list] + AlignCollate_normal = AlignCollate(imgH=imgH, imgW=imgW, keep_ratio_with_pad=True) + test_data = ListDataset(img_list) + test_loader = torch.utils.data.DataLoader( + test_data, batch_size=batch_size, shuffle=False, + num_workers=int(workers), collate_fn=AlignCollate_normal, pin_memory=True) + + # predict first round + result1 = recognizer_predict_prob(recognizer, converter, test_loader,batch_max_length,\ + ignore_idx, char_group_idx, decoder, beamWidth, device = device) + + # predict second round + low_confident_idx = [i for i,item in enumerate(result1) if (item[1] < contrast_ths)] + if len(low_confident_idx) > 0: + img_list2 = [img_list[i] for i in low_confident_idx] + AlignCollate_contrast = AlignCollate(imgH=imgH, imgW=imgW, keep_ratio_with_pad=True, adjust_contrast=adjust_contrast) + test_data = ListDataset(img_list2) + test_loader = torch.utils.data.DataLoader( + test_data, batch_size=batch_size, shuffle=False, + num_workers=int(workers), collate_fn=AlignCollate_contrast, pin_memory=True) + result2 = recognizer_predict_prob(recognizer, converter, test_loader, batch_max_length,\ + ignore_idx, char_group_idx, decoder, beamWidth, device = device) + + result = [] + for i, zipped in enumerate(zip(coord, result1)): + box, pred1 = zipped + if i in low_confident_idx: + pred2 = result2[low_confident_idx.index(i)] + if pred1[1]>pred2[1]: + result.append( (box, pred1[0], pred1[1]) ) + else: + result.append( (box, pred2[0], pred2[1]) ) + else: + result.append( (box, pred1[0], pred1[1]) ) + + return result