From ceae60efe64bf2d7c2f9ed81d8a76e5851138005 Mon Sep 17 00:00:00 2001 From: honeyxu1108 Date: Sun, 29 Dec 2024 22:27:29 +0800 Subject: [PATCH] finally fix some hard-code problem --- logo_matching.py | 42 +++++++++++++++++++++++++++--------------- logo_recog.py | 15 +++++++++++---- setup.sh | 1 + 3 files changed, 39 insertions(+), 19 deletions(-) diff --git a/logo_matching.py b/logo_matching.py index 9d1bdf1..755b399 100644 --- a/logo_matching.py +++ b/logo_matching.py @@ -88,24 +88,29 @@ def cache_reference_list(model, targetlist_path: str, grayscale=False): :return file_name_list: targetlist paths ''' - # Prediction for targetlists + # Prediction for targetlists logo_feat_list = [] file_name_list = [] - for target in tqdm(os.listdir(targetlist_path)): + target_list = os.listdir(targetlist_path) + for target in tqdm(target_list): if target.startswith('.'): # skip hidden files continue - for logo_path in os.listdir(os.path.join(targetlist_path, target)): - if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith( - '.jpg') or logo_path.endswith('.PNG') \ - or logo_path.endswith('.JPG') or logo_path.endswith('.JPEG'): - if logo_path.startswith('loginpage') or logo_path.startswith('homepage'): # skip homepage/loginpage + logo_list = os.listdir(os.path.join(targetlist_path, target)) + for logo_path in logo_list: + # List of valid image extensions + valid_extensions = ['.png', '.jpeg', '.jpg', 'PNG','.JPG', '.JPEG'] + if any(logo_path.endswith(ext) for ext in valid_extensions): + skip_prefixes = ['loginpage', 'homepage'] + if any(logo_path.startswith(prefix) for prefix in skip_prefixes): # skip homepage/loginpage + continue + try: + logo_feat_list.append(get_embedding(img=os.path.join(targetlist_path, target, logo_path), + model=model, grayscale=grayscale)) + file_name_list.append(str(os.path.join(targetlist_path, target, logo_path))) + except OSError: + print(f"Error opening image: {os.path.join(targetlist_path, target, logo_path)}") continue - logo_feat_list.append(get_embedding(img=os.path.join(targetlist_path, target, logo_path), - model=model, grayscale=grayscale)) - file_name_list.append(str(os.path.join(targetlist_path, target, logo_path))) - - return np.asarray(logo_feat_list), np.asarray(file_name_list) @torch.no_grad() @@ -133,9 +138,16 @@ def get_embedding(img, model, grayscale=False): ## Resize the image while keeping the original aspect ratio pad_color = 255 if grayscale else (255, 255, 255) - img = ImageOps.expand(img, ( - (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2, - (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2), fill=pad_color) + img = ImageOps.expand( + img, + ( + (max(img.size) - img.size[0]) // 2, + (max(img.size) - img.size[1]) // 2, + (max(img.size) - img.size[0]) // 2, + (max(img.size) - img.size[1]) // 2 + ), + fill=pad_color + ) img = img.resize((img_size, img_size)) diff --git a/logo_recog.py b/logo_recog.py index 0297995..bd35fff 100644 --- a/logo_recog.py +++ b/logo_recog.py @@ -24,8 +24,8 @@ def pred_rcnn(im, predictor): outputs = predictor(im) instances = outputs['instances'] - pred_classes = instances.pred_classes # tensor - pred_boxes = instances.pred_boxes # Boxes object + pred_classes = instances.pred_classes # tensor + pred_boxes = instances.pred_boxes # Boxes object logo_boxes = pred_boxes[pred_classes == 1].tensor @@ -52,6 +52,13 @@ def config_rcnn(cfg_path, weights_path, conf_threshold): predictor = DefaultPredictor(cfg) return predictor +COLORS = { + 0: (255, 255, 0), # logo + 1: (36, 255, 12), # input + 2: (0, 255, 255), # button + 3: (0, 0, 255), # label + 4: (255, 0, 0) # block +} def vis(img_path, pred_boxes): ''' @@ -71,8 +78,8 @@ def vis(img_path, pred_boxes): # draw rectangle for j, box in enumerate(pred_boxes): if j == 0: - cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255, 255, 0), 2) + cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), COLORS['0'], 2) else: - cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (36, 255, 12), 2) + cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), COLORS['1'], 2) return check diff --git a/setup.sh b/setup.sh index 1957d72..0fdd141 100755 --- a/setup.sh +++ b/setup.sh @@ -6,6 +6,7 @@ set -e # Function to display error messages and exit error_exit() { echo "$1" >&2 + echo "$(date): $1" >> error.log # Log error to a file for debugging exit 1 }