diff --git a/TableGeneration/GenerateTable.py b/TableGeneration/GenerateTable.py index 4ed75f8..0c3482f 100644 --- a/TableGeneration/GenerateTable.py +++ b/TableGeneration/GenerateTable.py @@ -1,40 +1,43 @@ import json import os -import sys +import platform import random import string -from PIL import Image +import sys from io import BytesIO -from tqdm import tqdm + import numpy as np +from PIL import Image from selenium.webdriver.common.by import By -from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.ui import WebDriverWait +from tqdm import tqdm from TableGeneration.Table import Table class GenerateTable: - def __init__(self, - output, - ch_dict_path, - en_dict_path, - cell_box_type='cell', - min_row=3, - max_row=20, - min_col=3, - max_col=10, - max_span_row_count=3, - max_span_col_count=3, - max_span_value=20, - min_txt_len=2, - max_txt_len=7, - color_prob=0, - cell_max_width=0, - cell_max_height=0, - brower='chrome', - brower_width=1920, - brower_height=1920): + def __init__( + self, + output, + ch_dict_path, + en_dict_path, + cell_box_type="cell", + min_row=3, + max_row=20, + min_col=3, + max_col=10, + max_span_row_count=3, + max_span_col_count=3, + max_span_value=20, + min_txt_len=2, + max_txt_len=7, + color_prob=0, + cell_max_width=0, + cell_max_height=0, + brower="chrome", + brower_width=1920, + brower_height=1920, ): self.output = output # wheter to store images separately or not self.ch_dict_path = ch_dict_path self.en_dict_path = en_dict_path @@ -55,21 +58,24 @@ def __init__(self, self.brower_height = brower_height # brower height self.brower_width = brower_width # brower width - if self.brower == 'chrome': + if self.brower == "chrome": from selenium.webdriver import Chrome as Brower from selenium.webdriver import ChromeOptions as Options else: from selenium.webdriver import Firefox as Brower from selenium.webdriver import FirefoxOptions as Options opts = Options() - opts.add_argument('--headless') - opts.add_argument('--no-sandbox') + opts.add_argument("--headless") + opts.add_argument("--no-sandbox") self.driver = Brower(options=opts) + self.is_macos = platform.system() == "Darwin" + self.ratio = 2 + def gen_table_img(self, img_count): os.makedirs(self.output, exist_ok=True) f_gt = open( - os.path.join(self.output, 'gt.txt'), encoding='utf-8', mode='w') + os.path.join(self.output, "gt.txt"), encoding="utf-8", mode="w") for i in tqdm(range(img_count)): # data_arr contains the images of generated tables and all_table_categories contains the table category of each of the table out = self.generate_table() @@ -80,30 +86,30 @@ def gen_table_img(self, img_count): im, contens = self.clip_white(im, contens) # randomly select a name of length=20 for file. - output_file_name = ''.join( + output_file_name = "".join( random.choices( string.ascii_uppercase + string.digits, k=20)) - output_file_name = '{}_{}_{}'.format(border, i, output_file_name) + output_file_name = "{}_{}_{}".format(border, i, output_file_name) # print('{}/{}, {}'.format(i, img_count, output_file_name)) # if the image and equivalent html is need to be stored - os.makedirs(os.path.join(self.output, 'html'), exist_ok=True) - os.makedirs(os.path.join(self.output, 'img'), exist_ok=True) - - html_save_path = os.path.join(self.output, 'html', - output_file_name + '.html') - img_save_path = os.path.join(self.output, 'img', - output_file_name + '.jpg') - with open(html_save_path, encoding='utf-8', mode='w') as f: + os.makedirs(os.path.join(self.output, "html"), exist_ok=True) + os.makedirs(os.path.join(self.output, "img"), exist_ok=True) + + html_save_path = os.path.join(self.output, "html", + output_file_name + ".html") + img_save_path = os.path.join(self.output, "img", + output_file_name + ".jpg") + with open(html_save_path, encoding="utf-8", mode="w") as f: f.write(html_content) im.save(img_save_path, dpi=(600, 600)) # 构造标注信息 - img_file_name = os.path.join('img', output_file_name + '.jpg') + img_file_name = os.path.join("img", output_file_name + ".jpg") label_info = self.make_ppstructure_label(structure, contens, img_file_name) - f_gt.write('{}\n'.format( + f_gt.write("{}\n".format( json.dumps( label_info, ensure_ascii=False))) # convert to PP-Structure label format @@ -116,12 +122,20 @@ def generate_table(self): rows = random.randint(self.min_row, self.max_row) try: # initialize table class - table = Table(self.ch_dict_path, self.en_dict_path, - self.cell_box_type, rows, cols, self.min_txt_len, - self.max_txt_len, self.max_span_row_count, - self.max_span_col_count, self.max_span_value, - self.color_prob, self.cell_max_width, - self.cell_max_height) + table = Table( + self.ch_dict_path, + self.en_dict_path, + self.cell_box_type, + rows, + cols, + self.min_txt_len, + self.max_txt_len, + self.max_span_row_count, + self.max_span_col_count, + self.max_span_value, + self.color_prob, + self.cell_max_width, + self.cell_max_height, ) # get table of rows and cols based on unlv distribution and get features of this table # (same row, col and cell matrices, total unique ids, html conversion of table and its category) id_count, html_content, structure, border = table.create() @@ -132,46 +146,49 @@ def generate_table(self): return im, html_content, structure, contens, border except KeyboardInterrupt: import sys + sys.exit() except: import traceback + traceback.print_exc() return None return None def make_ppstructure_label(self, structure, bboxes, img_path): d = { - 'filename': img_path, - 'html': { - 'structure': { - 'tokens': structure + "filename": img_path, + "html": { + "structure": { + "tokens": structure } } } cells = [] for bbox in bboxes: text = bbox[1] - cells.append({'tokens': list(text), 'bbox': bbox[2:]}) - d['html']['cells'] = cells - d['gt'] = self.rebuild_html_from_ppstructure_label(d) + cells.append({"tokens": list(text), "bbox": bbox[2:]}) + d["html"]["cells"] = cells + d["gt"] = self.rebuild_html_from_ppstructure_label(d) return d def rebuild_html_from_ppstructure_label(self, label_info): from html import escape - html_code = label_info['html']['structure']['tokens'].copy() + + html_code = label_info["html"]["structure"]["tokens"].copy() to_insert = [ - i for i, tag in enumerate(html_code) if tag in ('