Skip to content

Commit

Permalink
data_augment fix
Browse files Browse the repository at this point in the history
  • Loading branch information
RRFRRF committed Jan 3, 2025
1 parent e107c3e commit e4e95cf
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 38 deletions.
3 changes: 2 additions & 1 deletion configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ def load_config(reload_targetlist=False):

if reload_targetlist or (not os.path.exists(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME))):
LOGO_FEATS, LOGO_FILES = cache_reference_list(model=SIAMESE_MODEL,
targetlist_path=full_targetlist_folder_dir)
targetlist_path=full_targetlist_folder_dir,
data_augmentation=configs['SIAMESE_MODEL']['DATA_AUGMENTATION'])
print('Finish loading protected logo list')
np.save(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME), LOGO_FEATS)
np.save(os.path.join(os.path.dirname(__file__), LOGO_FILES_NAME), LOGO_FILES)
Expand Down
3 changes: 2 additions & 1 deletion configs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ SIAMESE_MODEL:
MATCH_THRE: 0.87 # FIXME: threshold is 0.87 in phish-discovery?
WEIGHTS_PATH: models/resnetv2_rgb_new.pth.tar
TARGETLIST_PATH: models/expand_targetlist.zip
DOMAIN_MAP_PATH: models/domain_map.pkl
DOMAIN_MAP_PATH: models/domain_map.pkl
DATA_AUGMENTATION: False # whether to use data augmentation, default is False, 1 logo generate 7 embeddings
129 changes: 93 additions & 36 deletions logo_matching.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageEnhance, ImageFilter
from torchvision import transforms
from utils import brand_converter, resolution_alignment, l2_norm
from models import KNOWN_MODELS
Expand Down Expand Up @@ -79,7 +79,7 @@ def load_model_weights(num_classes: int, weights_path: str):
return model


def cache_reference_list(model, targetlist_path: str, grayscale=False):
def cache_reference_list(model, targetlist_path: str, grayscale=False, data_augmentation=False):
'''
cache the embeddings of the reference list
:param targetlist_path: targetlist folder
Expand All @@ -88,31 +88,95 @@ def cache_reference_list(model, targetlist_path: str, grayscale=False):
:return file_name_list: targetlist paths
'''

# Prediction for targetlists
# Prediction for targetlists
logo_feat_list = []
file_name_list = []

target_list = os.listdir(targetlist_path)
for target in tqdm(target_list):
for target in tqdm(os.listdir(targetlist_path)):
if target.startswith('.'): # skip hidden files
continue
logo_list = os.listdir(os.path.join(targetlist_path, target))
for logo_path in logo_list:
# List of valid image extensions
valid_extensions = ['.png', 'PNG', '.jpeg', '.jpg', '.JPG', '.JPEG']
if any(logo_path.endswith(ext) for ext in valid_extensions):
skip_prefixes = ['loginpage', 'homepage']
if any(logo_path.startswith(prefix) for prefix in skip_prefixes): # skip homepage/loginpage
continue
try:
logo_feat_list.append(get_embedding(img=os.path.join(targetlist_path, target, logo_path),
model=model, grayscale=grayscale))
file_name_list.append(str(os.path.join(targetlist_path, target, logo_path)))
except OSError:
print(f"Error opening image: {os.path.join(targetlist_path, target, logo_path)}")
for logo_path in os.listdir(os.path.join(targetlist_path, target)):
if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith(
'.jpg') or logo_path.endswith('.PNG') \
or logo_path.endswith('.JPG') or logo_path.endswith('.JPEG'):
if logo_path.startswith('loginpage') or logo_path.startswith('homepage'): # skip homepage/loginpage
continue

full_path = os.path.join(targetlist_path, target, logo_path)

if data_augmentation:
# 对每张图片进行数据增强
augmented_images = apply_augmentations(full_path)
else:
augmented_images = [Image.open(full_path).convert('RGB')]
# 为每个增强后的图片生成embedding
for aug_img in augmented_images:
logo_feat_list.append(get_embedding(img=aug_img, model=model, grayscale=grayscale))
file_name_list.append(str(full_path)) # 使用原始文件路径,这样可以追踪到原图

return np.asarray(logo_feat_list), np.asarray(file_name_list)


def apply_augmentations(img_path):
'''
对图片进行数据增强
:param img_path: 图片路径
:return: 增强后的图片列表
'''
img = Image.open(img_path).convert('RGB')
augmented_images = [img] # 原始图片
# return augmented_images
# 1. 亮度50%
enhancer = ImageEnhance.Brightness(img)
augmented_images.append(enhancer.enhance(0.5))

# # 2. 颜色反转(深夜模式,只反转黑白颜色)
# img_array = np.array(img)
# # 计算图片的亮度
# brightness = np.mean(img_array, axis=2)
# # 创建掩码:True表示非常亮的像素(接近白色)
# white_mask = brightness > 240
# # 创建掩码:True表示非常暗的像素(接近黑色)
# black_mask = brightness < 30

# # 复制原图
# dark_mode = img_array.copy()
# # 将白色区域变成深灰色 (50, 50, 50)
# dark_mode[white_mask] = [50, 50, 50]
# # 将黑色区域变成白色 (255, 255, 255)
# dark_mode[black_mask] = [255, 255, 255]

# augmented_images.append(Image.fromarray(dark_mode))

# 3. 降噪+锐化
blurred = img.filter(ImageFilter.GaussianBlur(radius=1))
enhancer = ImageEnhance.Sharpness(blurred)
augmented_images.append(enhancer.enhance(2.0))

# 4. 抗锯齿
w, h = img.size
upscaled = img.resize((w * 2, h * 2), Image.LANCZOS)
downscaled = upscaled.resize((w, h), Image.LANCZOS)
augmented_images.append(downscaled)

# 5. 锐化2.5x
enhancer = ImageEnhance.Sharpness(img)
augmented_images.append(enhancer.enhance(2.5))

# 6. 饱和度150%
enhancer = ImageEnhance.Color(img)
augmented_images.append(enhancer.enhance(1.5))

# 7. 饱和度50%
enhancer = ImageEnhance.Color(img)
augmented_images.append(enhancer.enhance(0.5))

# # 8. 全灰度
# grayscale = ImageOps.grayscale(img)
# grayscale_rgb = Image.merge('RGB', (grayscale, grayscale, grayscale))
# augmented_images.append(grayscale_rgb)

return logo_feat_list, file_name_list
return augmented_images


@torch.no_grad()
Expand Down Expand Up @@ -140,16 +204,9 @@ def get_embedding(img, model, grayscale=False):

## Resize the image while keeping the original aspect ratio
pad_color = 255 if grayscale else (255, 255, 255)
img = ImageOps.expand(
img,
(
(max(img.size) - img.size[0]) // 2,
(max(img.size) - img.size[1]) // 2,
(max(img.size) - img.size[0]) // 2,
(max(img.size) - img.size[1]) // 2
),
fill=pad_color
)
img = ImageOps.expand(img, (
(max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2,
(max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2), fill=pad_color)

img = img.resize((img_size, img_size))

Expand Down Expand Up @@ -187,17 +244,17 @@ def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str
print('Screenshot cannot be open')
return None, None, None

# get predicted box --> crop from screenshot
## get predicted box --> crop from screenshot
cropped = img.crop((gt_bbox[0], gt_bbox[1], gt_bbox[2], gt_bbox[3]))
img_feat = get_embedding(cropped, model, grayscale=grayscale)

# get cosine similarity with every protected logo
## get cosine similarity with every protected logo
sim_list = logo_feat_list @ img_feat.T # take dot product for every pair of embeddings (Cosine Similarity)
pred_brand_list = file_name_list

assert len(sim_list) == len(pred_brand_list)

# get top 3 brands
## get top 3 brands
idx = np.argsort(sim_list)[::-1][:3]
pred_brand_list = np.array(pred_brand_list)[idx]
sim_list = np.array(sim_list)[idx]
Expand All @@ -210,17 +267,17 @@ def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str
for j in range(3):
predicted_brand, predicted_domain = None, None

# If we are trying those lower rank logo, the predicted brand of them should be the same as top1 logo, otherwise might be false positive
## If we are trying those lower rank logo, the predicted brand of them should be the same as top1 logo, otherwise might be false positive
if top3_brandlist[j] != top3_brandlist[0]:
continue

# If the largest similarity exceeds threshold
## If the largest similarity exceeds threshold
if top3_simlist[j] >= similarity_threshold:
predicted_brand = top3_brandlist[j]
predicted_domain = top3_domainlist[j]
final_sim = top3_simlist[j]

# Else if not exceed, try resolution alignment, see if can improve
## Else if not exceed, try resolution alignment, see if can improve
elif do_resolution_alignment:
orig_candidate_logo = Image.open(pred_brand_list[j])
cropped, candidate_logo = resolution_alignment(cropped, orig_candidate_logo)
Expand Down

0 comments on commit e4e95cf

Please sign in to comment.