diff --git a/configs.py b/configs.py
index 25980b1..8ed7a64 100644
--- a/configs.py
+++ b/configs.py
@@ -51,7 +51,8 @@ def load_config(reload_targetlist=False):
 
     if reload_targetlist or (not os.path.exists(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME))):
         LOGO_FEATS, LOGO_FILES = cache_reference_list(model=SIAMESE_MODEL,
-                                                       targetlist_path=full_targetlist_folder_dir)
+                                                       targetlist_path=full_targetlist_folder_dir,
+                                                       data_augmentation=configs['SIAMESE_MODEL']['DATA_AUGMENTATION'])
         print('Finish loading protected logo list')
         np.save(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME), LOGO_FEATS)
         np.save(os.path.join(os.path.dirname(__file__), LOGO_FILES_NAME), LOGO_FILES)
diff --git a/configs.yaml b/configs.yaml
index 6b318ae..a8a90b5 100644
--- a/configs.yaml
+++ b/configs.yaml
@@ -8,4 +8,5 @@ SIAMESE_MODEL:
   MATCH_THRE: 0.87 # FIXME: threshold is 0.87 in phish-discovery?
   WEIGHTS_PATH: models/resnetv2_rgb_new.pth.tar
   TARGETLIST_PATH: models/expand_targetlist.zip
-  DOMAIN_MAP_PATH: models/domain_map.pkl
\ No newline at end of file
+  DOMAIN_MAP_PATH: models/domain_map.pkl
+  DATA_AUGMENTATION: False # whether to use data augmentation; default is False. When enabled, each logo generates 7 embeddings
\ No newline at end of file
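
Note that the new flag only takes effect when the embedding cache is rebuilt: as the hunk above shows, load_config() calls cache_reference_list() only when reload_targetlist is True or the cached LOGO_FEATS .npy file is missing. A minimal sketch of picking up the change after setting DATA_AUGMENTATION: True in configs.yaml (assuming configs.py is importable as a top-level configs module):

    # Assumption: configs.py sits at the repo root and can be imported directly.
    from configs import load_config

    # Force cache_reference_list() to run again so the augmented embeddings
    # overwrite the previously saved LOGO_FEATS / LOGO_FILES .npy caches.
    load_config(reload_targetlist=True)
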
diff --git a/logo_matching.py b/logo_matching.py
index 77590f7..a882474 100644
--- a/logo_matching.py
+++ b/logo_matching.py
@@ -1,4 +1,4 @@
-from PIL import Image, ImageOps
+from PIL import Image, ImageOps, ImageEnhance, ImageFilter
 from torchvision import transforms
 from utils import brand_converter, resolution_alignment, l2_norm
 from models import KNOWN_MODELS
@@ -79,7 +79,7 @@ def load_model_weights(num_classes: int, weights_path: str):
     return model
 
 
-def cache_reference_list(model, targetlist_path: str, grayscale=False):
+def cache_reference_list(model, targetlist_path: str, grayscale=False, data_augmentation=False):
     '''
     cache the embeddings of the reference list
     :param targetlist_path: targetlist folder
@@ -88,31 +88,95 @@ def cache_reference_list(model, targetlist_path: str, grayscale=False):
     :return file_name_list: targetlist paths
     '''
 
-    # Prediction for targetlists
+    # Prediction for targetlists
     logo_feat_list = []
     file_name_list = []
 
-    target_list = os.listdir(targetlist_path)
-    for target in tqdm(target_list):
+    for target in tqdm(os.listdir(targetlist_path)):
         if target.startswith('.'):  # skip hidden files
             continue
-        logo_list = os.listdir(os.path.join(targetlist_path, target))
-        for logo_path in logo_list:
-            # List of valid image extensions
-            valid_extensions = ['.png', 'PNG', '.jpeg', '.jpg', '.JPG', '.JPEG']
-            if any(logo_path.endswith(ext) for ext in valid_extensions):
-                skip_prefixes = ['loginpage', 'homepage']
-                if any(logo_path.startswith(prefix) for prefix in skip_prefixes):  # skip homepage/loginpage
-                    continue
-                try:
-                    logo_feat_list.append(get_embedding(img=os.path.join(targetlist_path, target, logo_path),
-                                                        model=model, grayscale=grayscale))
-                    file_name_list.append(str(os.path.join(targetlist_path, target, logo_path)))
-                except OSError:
-                    print(f"Error opening image: {os.path.join(targetlist_path, target, logo_path)}")
+        for logo_path in os.listdir(os.path.join(targetlist_path, target)):
+            if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith(
+                    '.jpg') or logo_path.endswith('.PNG') \
+                    or logo_path.endswith('.JPG') or logo_path.endswith('.JPEG'):
+                if logo_path.startswith('loginpage') or logo_path.startswith('homepage'):  # skip homepage/loginpage
+                    continue
+
+                full_path = os.path.join(targetlist_path, target, logo_path)
+
+                if data_augmentation:
+                    # apply data augmentation to each image
+                    augmented_images = apply_augmentations(full_path)
+                else:
+                    augmented_images = [Image.open(full_path).convert('RGB')]
+                # generate one embedding per augmented image
+                for aug_img in augmented_images:
+                    logo_feat_list.append(get_embedding(img=aug_img, model=model, grayscale=grayscale))
+                    file_name_list.append(str(full_path))  # keep the original file path so each embedding traces back to the source logo
+
+    return np.asarray(logo_feat_list), np.asarray(file_name_list)
+
+
+def apply_augmentations(img_path):
+    '''
+    apply data augmentation to an image
+    :param img_path: image path
+    :return: list of augmented images
+    '''
+    img = Image.open(img_path).convert('RGB')
+    augmented_images = [img]  # original image
+    # return augmented_images
+    # 1. brightness 50%
+    enhancer = ImageEnhance.Brightness(img)
+    augmented_images.append(enhancer.enhance(0.5))
+
+    # # 2. color inversion (dark mode: only invert black and white colors)
+    # img_array = np.array(img)
+    # # compute the brightness of the image
+    # brightness = np.mean(img_array, axis=2)
+    # # mask: True for very bright pixels (close to white)
+    # white_mask = brightness > 240
+    # # mask: True for very dark pixels (close to black)
+    # black_mask = brightness < 30
+
+    # # copy the original image
+    # dark_mode = img_array.copy()
+    # # turn white regions into dark grey (50, 50, 50)
+    # dark_mode[white_mask] = [50, 50, 50]
+    # # turn black regions into white (255, 255, 255)
+    # dark_mode[black_mask] = [255, 255, 255]
+
+    # augmented_images.append(Image.fromarray(dark_mode))
+
+    # 3. denoise + sharpen
+    blurred = img.filter(ImageFilter.GaussianBlur(radius=1))
+    enhancer = ImageEnhance.Sharpness(blurred)
+    augmented_images.append(enhancer.enhance(2.0))
+
+    # 4. anti-aliasing
+    w, h = img.size
+    upscaled = img.resize((w * 2, h * 2), Image.LANCZOS)
+    downscaled = upscaled.resize((w, h), Image.LANCZOS)
+    augmented_images.append(downscaled)
+
+    # 5. sharpness 2.5x
+    enhancer = ImageEnhance.Sharpness(img)
+    augmented_images.append(enhancer.enhance(2.5))
+
+    # 6. saturation 150%
+    enhancer = ImageEnhance.Color(img)
+    augmented_images.append(enhancer.enhance(1.5))
+
+    # 7. saturation 50%
+    enhancer = ImageEnhance.Color(img)
+    augmented_images.append(enhancer.enhance(0.5))
+
+    # # 8. full grayscale
+    # grayscale = ImageOps.grayscale(img)
+    # grayscale_rgb = Image.merge('RGB', (grayscale, grayscale, grayscale))
+    # augmented_images.append(grayscale_rgb)
 
-    return logo_feat_list, file_name_list
+    return augmented_images
 
 
 @torch.no_grad()
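
With DATA_AUGMENTATION enabled, apply_augmentations() returns the original logo plus six variants (50% brightness, denoise + sharpen, anti-aliasing, 2.5x sharpening, 150% saturation, 50% saturation), so each reference logo contributes 7 embeddings that all map back to the same entry in file_name_list. A quick sanity check, as a sketch (the sample path is only a hypothetical placeholder):

    from logo_matching import apply_augmentations

    # Any RGB logo image on disk works here; this path is illustrative, not from the repo.
    variants = apply_augmentations('path/to/some_logo.png')
    assert len(variants) == 7  # original image + 6 augmented copies
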
@@ -140,16 +204,9 @@ def get_embedding(img, model, grayscale=False):
 
     ## Resize the image while keeping the original aspect ratio
     pad_color = 255 if grayscale else (255, 255, 255)
-    img = ImageOps.expand(
-        img,
-        (
-            (max(img.size) - img.size[0]) // 2,
-            (max(img.size) - img.size[1]) // 2,
-            (max(img.size) - img.size[0]) // 2,
-            (max(img.size) - img.size[1]) // 2
-        ),
-        fill=pad_color
-    )
+    img = ImageOps.expand(img, (
+        (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2,
+        (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2), fill=pad_color)
 
     img = img.resize((img_size, img_size))
 
@@ -187,17 +244,17 @@ def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str
         print('Screenshot cannot be open')
         return None, None, None
 
-    # get predicted box --> crop from screenshot
+    ## get predicted box --> crop from screenshot
     cropped = img.crop((gt_bbox[0], gt_bbox[1], gt_bbox[2], gt_bbox[3]))
     img_feat = get_embedding(cropped, model, grayscale=grayscale)
 
-    # get cosine similarity with every protected logo
+    ## get cosine similarity with every protected logo
     sim_list = logo_feat_list @ img_feat.T  # take dot product for every pair of embeddings (Cosine Similarity)
     pred_brand_list = file_name_list
 
     assert len(sim_list) == len(pred_brand_list)
 
-    # get top 3 brands
+    ## get top 3 brands
     idx = np.argsort(sim_list)[::-1][:3]
     pred_brand_list = np.array(pred_brand_list)[idx]
     sim_list = np.array(sim_list)[idx]
@@ -210,17 +267,17 @@ def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str
     for j in range(3):
         predicted_brand, predicted_domain = None, None
 
-        # If we are trying those lower rank logo, the predicted brand of them should be the same as top1 logo, otherwise might be false positive
+        ## If we are trying those lower rank logo, the predicted brand of them should be the same as top1 logo, otherwise might be false positive
         if top3_brandlist[j] != top3_brandlist[0]:
             continue
 
-        # If the largest similarity exceeds threshold
+        ## If the largest similarity exceeds threshold
         if top3_simlist[j] >= similarity_threshold:
            predicted_brand = top3_brandlist[j]
            predicted_domain = top3_domainlist[j]
            final_sim = top3_simlist[j]
 
-        # Else if not exceed, try resolution alignment, see if can improve
+        ## Else if not exceed, try resolution alignment, see if can improve
         elif do_resolution_alignment:
            orig_candidate_logo = Image.open(pred_brand_list[j])
            cropped, candidate_logo = resolution_alignment(cropped, orig_candidate_logo)
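
The matching step in pred_brand() above is a plain nearest-neighbour lookup over the cached reference embeddings: the dot product logo_feat_list @ img_feat.T is read as cosine similarity, which assumes the embeddings are L2-normalised (presumably via the l2_norm helper imported at the top of the file), and argsort picks the three most similar reference logos. A self-contained sketch of that retrieval step with dummy data (embedding dimension and shapes are illustrative, not taken from the repo):

    import numpy as np

    rng = np.random.default_rng(0)
    feats = rng.normal(size=(100, 2048))                     # stand-in for the cached reference embeddings
    feats /= np.linalg.norm(feats, axis=1, keepdims=True)    # L2-normalise each row
    query = feats[42] + 0.01 * rng.normal(size=2048)         # a near-duplicate of reference 42
    query /= np.linalg.norm(query)

    sim_list = feats @ query.T                               # cosine similarity against every reference logo
    top3 = np.argsort(sim_list)[::-1][:3]                    # indices of the 3 closest references
    assert top3[0] == 42
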