From 827cc5ac3c6b1467e549b5c49b62dca3f11c1dd2 Mon Sep 17 00:00:00 2001 From: wangxu60 Date: Sun, 1 Dec 2024 12:45:01 +0800 Subject: [PATCH] pep change --- .idea/.gitignore | 8 ++ .idea/Phishpedia.iml | 7 ++ .idea/deployment.xml | 28 +++++ .../inspectionProfiles/profiles_settings.xml | 6 + .idea/vcs.xml | 6 + GUI/function.py | 12 +- GUI/ui.py | 4 +- configs.py | 15 +-- logo_matching.py | 20 ++-- logo_recog.py | 3 +- models.py | 104 ++++++++-------- phishpedia.py | 48 ++++---- text_recog.py | 5 +- utils.py | 111 +++++++++--------- 14 files changed, 228 insertions(+), 149 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 .idea/Phishpedia.iml create mode 100644 .idea/deployment.xml create mode 100644 .idea/inspectionProfiles/profiles_settings.xml create mode 100644 .idea/vcs.xml diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/Phishpedia.iml b/.idea/Phishpedia.iml new file mode 100644 index 0000000..ec63674 --- /dev/null +++ b/.idea/Phishpedia.iml @@ -0,0 +1,7 @@ + + + + + \ No newline at end of file diff --git a/.idea/deployment.xml b/.idea/deployment.xml new file mode 100644 index 0000000..3e51933 --- /dev/null +++ b/.idea/deployment.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/GUI/function.py b/GUI/function.py index b1940fe..182c835 100644 --- a/GUI/function.py +++ b/GUI/function.py @@ -4,6 +4,7 @@ from phishpedia import PhishpediaWrapper import cv2 + class PhishpediaFunction: def __init__(self, ui): self.ui = ui @@ -12,7 +13,8 @@ def __init__(self, ui): def upload_image(self): options = QFileDialog.Options() - file_name, _ = QFileDialog.getOpenFileName(self.ui, "Select Screenshot", "", "Images (*.png *.jpg *.jpeg)", options=options) + file_name, _ = QFileDialog.getOpenFileName(self.ui, "Select Screenshot", "", "Images (*.png *.jpg *.jpeg)", + options=options) if file_name: self.ui.image_input.setText(file_name) @@ -24,7 +26,8 @@ def detect_phishing(self): self.ui.result_display.setText("Please enter URL and upload a screenshot.") return - phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time = self.phishpedia_cls.test_orig_phishpedia(url, screenshot_path, None) + phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time = self.phishpedia_cls.test_orig_phishpedia( + url, screenshot_path, None) # 根据 phish_category 改变颜色 phish_category_color = 'green' if phish_category == 0 else 'red' @@ -49,7 +52,7 @@ def display_image(self, plotvis): height, width, channel = plotvis_rgb.shape bytes_per_line = 3 * width plotvis_qimage = QImage(plotvis_rgb.data, width, height, bytes_per_line, QImage.Format_RGB888) - + self.current_pixmap = QPixmap.fromImage(plotvis_qimage) self.update_image_display() except Exception as e: @@ -60,7 +63,8 @@ def update_image_display(self): available_width = self.ui.width() available_height = self.ui.height() - self.ui.visualization_label.geometry().bottom() - 50 - scaled_pixmap = self.current_pixmap.scaled(available_width, available_height, Qt.KeepAspectRatio, Qt.SmoothTransformation) + scaled_pixmap = self.current_pixmap.scaled(available_width, available_height, Qt.KeepAspectRatio, + Qt.SmoothTransformation) self.ui.visualization_display.setPixmap(scaled_pixmap) def on_resize(self, event): diff --git a/GUI/ui.py b/GUI/ui.py index bde47fe..5a281f0 100644 --- a/GUI/ui.py +++ b/GUI/ui.py @@ -4,6 +4,7 @@ from PyQt5.QtWidgets import QApplication from .function import PhishpediaFunction + class PhishpediaUI(QWidget): def __init__(self): super().__init__() @@ -117,7 +118,8 @@ def set_dynamic_font_size(self): font = QFont() font.setPointSizeF(font_size) - for widget in self.findChildren(QLabel) + self.findChildren(QLineEdit) + self.findChildren(QPushButton) + [self.result_display]: + for widget in self.findChildren(QLabel) + self.findChildren(QLineEdit) + self.findChildren(QPushButton) + [ + self.result_display]: widget.setFont(font) def init_phish_test_page(self): diff --git a/configs.py b/configs.py index d5b2c2e..25980b1 100644 --- a/configs.py +++ b/configs.py @@ -5,12 +5,13 @@ import os import numpy as np + def get_absolute_path(relative_path): base_path = os.path.dirname(__file__) return os.path.abspath(os.path.join(base_path, relative_path)) -def load_config(reload_targetlist=False): +def load_config(reload_targetlist=False): with open(os.path.join(os.path.dirname(__file__), 'configs.yaml')) as file: configs = yaml.load(file, Loader=yaml.FullLoader) @@ -42,8 +43,8 @@ def load_config(reload_targetlist=False): # os.makedirs(full_targetlist_folder_dir, exist_ok=True) # subprocess.run(f'unzip -o "{targetlist_zip_path}" -d "{full_targetlist_folder_dir}"', shell=True) - SIAMESE_MODEL = load_model_weights( num_classes=configs['SIAMESE_MODEL']['NUM_CLASSES'], - weights_path=configs['SIAMESE_MODEL']['WEIGHTS_PATH']) + SIAMESE_MODEL = load_model_weights(num_classes=configs['SIAMESE_MODEL']['NUM_CLASSES'], + weights_path=configs['SIAMESE_MODEL']['WEIGHTS_PATH']) LOGO_FEATS_NAME = 'LOGO_FEATS.npy' LOGO_FILES_NAME = 'LOGO_FILES.npy' @@ -52,12 +53,12 @@ def load_config(reload_targetlist=False): LOGO_FEATS, LOGO_FILES = cache_reference_list(model=SIAMESE_MODEL, targetlist_path=full_targetlist_folder_dir) print('Finish loading protected logo list') - np.save(os.path.join(os.path.dirname(__file__),LOGO_FEATS_NAME), LOGO_FEATS) - np.save(os.path.join(os.path.dirname(__file__),LOGO_FILES_NAME), LOGO_FILES) + np.save(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME), LOGO_FEATS) + np.save(os.path.join(os.path.dirname(__file__), LOGO_FILES_NAME), LOGO_FILES) else: - LOGO_FEATS, LOGO_FILES = np.load(os.path.join(os.path.dirname(__file__),LOGO_FEATS_NAME)), \ - np.load(os.path.join(os.path.dirname(__file__),LOGO_FILES_NAME)) + LOGO_FEATS, LOGO_FILES = np.load(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME)), \ + np.load(os.path.join(os.path.dirname(__file__), LOGO_FILES_NAME)) DOMAIN_MAP_PATH = configs['SIAMESE_MODEL']['DOMAIN_MAP_PATH'] diff --git a/logo_matching.py b/logo_matching.py index 224b562..943f104 100644 --- a/logo_matching.py +++ b/logo_matching.py @@ -10,18 +10,17 @@ from tldextract import tldextract import pickle + def check_domain_brand_inconsistency(logo_boxes, domain_map_path: str, model, logo_feat_list, file_name_list, shot_path: str, url: str, ts: float, - topk: float=3): - + topk: float = 3): # targetlist domain list with open(domain_map_path, 'rb') as handle: domain_map = pickle.load(handle) - print('number of logo boxes:', len(logo_boxes)) extracted_domain = tldextract.extract(url).domain + '.' + tldextract.extract(url).suffix matched_target, matched_domain, matched_coord, this_conf = None, None, None, None @@ -36,9 +35,10 @@ def check_domain_brand_inconsistency(logo_boxes, min_x, min_y, max_x, max_y = coord bbox = [float(min_x), float(min_y), float(max_x), float(max_y)] matched_target, matched_domain, this_conf = pred_brand(model, domain_map, - logo_feat_list, file_name_list, - shot_path, bbox, t_s=ts, grayscale=False, - do_aspect_ratio_check=False, do_resolution_alignment=False) + logo_feat_list, file_name_list, + shot_path, bbox, t_s=ts, grayscale=False, + do_aspect_ratio_check=False, + do_resolution_alignment=False) # print(target_this, domain_this, this_conf) # domain matcher to avoid FP if matched_target and matched_domain: @@ -51,6 +51,7 @@ def check_domain_brand_inconsistency(logo_boxes, return brand_converter(matched_target), matched_domain, matched_coord, this_conf + def load_model_weights(num_classes: int, weights_path: str): # Initialize model device = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -72,6 +73,7 @@ def load_model_weights(num_classes: int, weights_path: str): model.eval() return model + def cache_reference_list(model, targetlist_path: str, grayscale=False): ''' cache the embeddings of the reference list @@ -92,7 +94,8 @@ def cache_reference_list(model, targetlist_path: str, grayscale=False): if target.startswith('.'): # skip hidden files continue for logo_path in os.listdir(os.path.join(targetlist_path, target)): - if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith('.jpg') or logo_path.endswith('.PNG') \ + if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith( + '.jpg') or logo_path.endswith('.PNG') \ or logo_path.endswith('.JPG') or logo_path.endswith('.JPEG'): if logo_path.startswith('loginpage') or logo_path.startswith('homepage'): # skip homepage/loginpage continue @@ -102,6 +105,7 @@ def cache_reference_list(model, targetlist_path: str, grayscale=False): return np.asarray(logo_feat_list), np.asarray(file_name_list) + @torch.no_grad() def get_embedding(img, model, grayscale=False): ''' @@ -122,7 +126,7 @@ def get_embedding(img, model, grayscale=False): img_transforms = transforms.Compose( [transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), - ]) + ]) img = Image.open(img) if isinstance(img, str) else img img = img.convert("L").convert("RGB") if grayscale else img.convert("RGB") diff --git a/logo_recog.py b/logo_recog.py index 21640b3..a6434f3 100644 --- a/logo_recog.py +++ b/logo_recog.py @@ -4,6 +4,7 @@ import numpy as np import torch + def pred_rcnn(im, predictor): ''' Perform inference for RCNN @@ -72,4 +73,4 @@ def vis(img_path, pred_boxes): else: cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (36, 255, 12), 2) - return check \ No newline at end of file + return check diff --git a/models.py b/models.py index 34d3992..4912c94 100644 --- a/models.py +++ b/models.py @@ -15,7 +15,7 @@ # Lint as: python3 """Bottleneck ResNet v2 with GroupNorm and Weight Standardization.""" -from collections import OrderedDict # pylint: disable=g-importing-member +from collections import OrderedDict # pylint: disable=g-importing-member import torch import torch.nn as nn @@ -29,17 +29,17 @@ def forward(self, x): v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) w = (w - m) / torch.sqrt(v + 1e-10) return F.conv2d(x, w, self.bias, self.stride, self.padding, - self.dilation, self.groups) + self.dilation, self.groups) def conv3x3(cin, cout, stride=1, groups=1, bias=False): return StdConv2d(cin, cout, kernel_size=3, stride=stride, - padding=1, bias=bias, groups=groups) + padding=1, bias=bias, groups=groups) def conv1x1(cin, cout, stride=1, bias=False): return StdConv2d(cin, cout, kernel_size=1, stride=stride, - padding=0, bias=bias) + padding=0, bias=bias) def tf2th(conv_weights): @@ -61,12 +61,12 @@ class PreActBottleneck(nn.Module): def __init__(self, cin, cout=None, cmid=None, stride=1): super().__init__() cout = cout or cin - cmid = cmid or cout//4 + cmid = cmid or cout // 4 self.gn1 = nn.GroupNorm(32, cin) self.conv1 = conv1x1(cin, cmid) self.gn2 = nn.GroupNorm(32, cmid) - self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!! + self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!! self.gn3 = nn.GroupNorm(32, cmid) self.conv3 = conv1x1(cmid, cout) self.relu = nn.ReLU(inplace=True) @@ -112,44 +112,48 @@ class ResNetV2(nn.Module): def __init__(self, block_units, width_factor, head_size=21843, zero_head=False): super().__init__() - wf = width_factor # shortcut 'cause we'll use it a lot. + wf = width_factor # shortcut 'cause we'll use it a lot. # The following will be unreadable if we split lines. # pylint: disable=line-too-long self.root = nn.Sequential(OrderedDict([ - ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)), - ('pad', nn.ConstantPad2d(1, 0)), - ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), - # The following is subtly not the same! - # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ('conv', StdConv2d(3, 64 * wf, kernel_size=7, stride=2, padding=3, bias=False)), + ('pad', nn.ConstantPad2d(1, 0)), + ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), + # The following is subtly not the same! + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), ])) self.body = nn.Sequential(OrderedDict([ - ('block1', nn.Sequential(OrderedDict( - [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] + - [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)], - ))), - ('block2', nn.Sequential(OrderedDict( - [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] + - [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)], - ))), - ('block3', nn.Sequential(OrderedDict( - [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] + - [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)], - ))), - ('block4', nn.Sequential(OrderedDict( - [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] + - [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)], - ))), + ('block1', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=64 * wf, cout=256 * wf, cmid=64 * wf))] + + [(f'unit{i:02d}', PreActBottleneck(cin=256 * wf, cout=256 * wf, cmid=64 * wf)) for i in + range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=256 * wf, cout=512 * wf, cmid=128 * wf, stride=2))] + + [(f'unit{i:02d}', PreActBottleneck(cin=512 * wf, cout=512 * wf, cmid=128 * wf)) for i in + range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=512 * wf, cout=1024 * wf, cmid=256 * wf, stride=2))] + + [(f'unit{i:02d}', PreActBottleneck(cin=1024 * wf, cout=1024 * wf, cmid=256 * wf)) for i in + range(2, block_units[2] + 1)], + ))), + ('block4', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=1024 * wf, cout=2048 * wf, cmid=512 * wf, stride=2))] + + [(f'unit{i:02d}', PreActBottleneck(cin=2048 * wf, cout=2048 * wf, cmid=512 * wf)) for i in + range(2, block_units[3] + 1)], + ))), ])) # pylint: enable=line-too-long self.zero_head = zero_head self.head = nn.Sequential(OrderedDict([ - ('gn', nn.GroupNorm(32, 2048*wf)), - ('relu', nn.ReLU(inplace=True)), - ('avg', nn.AdaptiveAvgPool2d(output_size=1)), - ('conv', nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)), + ('gn', nn.GroupNorm(32, 2048 * wf)), + ('relu', nn.ReLU(inplace=True)), + ('avg', nn.AdaptiveAvgPool2d(output_size=1)), + ('conv', nn.Conv2d(2048 * wf, head_size, kernel_size=1, bias=True)), ])) def features(self, x): @@ -159,19 +163,21 @@ def features(self, x): def forward(self, x): x = self.head(self.body(self.root(x))) - assert x.shape[-2:] == (1, 1) # We should have no spatial shape left. - return x[...,0,0] + assert x.shape[-2:] == (1, 1) # We should have no spatial shape left. + return x[..., 0, 0] def load_from(self, weights, prefix='resnet/'): with torch.no_grad(): - self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) # pylint: disable=line-too-long + self.root.conv.weight.copy_( + tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) # pylint: disable=line-too-long self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma'])) self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta'])) if self.zero_head: nn.init.zeros_(self.head.conv.weight) nn.init.zeros_(self.head.conv.bias) else: - self.head.conv.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel'])) # pylint: disable=line-too-long + self.head.conv.weight.copy_( + tf2th(weights[f'{prefix}head/conv2d/kernel'])) # pylint: disable=line-too-long self.head.conv.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias'])) for bname, block in self.body.named_children(): @@ -180,16 +186,16 @@ def load_from(self, weights, prefix='resnet/'): KNOWN_MODELS = OrderedDict([ - ('BiT-M-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), - ('BiT-M-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), - ('BiT-M-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), - ('BiT-M-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), - ('BiT-M-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), - ('BiT-M-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), - ('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), - ('BiT-S-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), - ('BiT-S-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), - ('BiT-S-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), - ('BiT-S-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), - ('BiT-S-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), -]) \ No newline at end of file + ('BiT-M-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), + ('BiT-M-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), + ('BiT-M-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), + ('BiT-M-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), + ('BiT-M-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), + ('BiT-M-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), + ('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), + ('BiT-S-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), + ('BiT-S-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), + ('BiT-S-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), + ('BiT-S-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), + ('BiT-S-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), +]) diff --git a/phishpedia.py b/phishpedia.py index 595c2ec..f704123 100644 --- a/phishpedia.py +++ b/phishpedia.py @@ -1,4 +1,3 @@ - import time from datetime import datetime import argparse @@ -13,9 +12,12 @@ from tqdm import tqdm import re -os.environ['KMP_DUPLICATE_LIB_OK']='True' -def result_file_write(f,folder,url,phish_category,pred_target,matched_domain,siamese_conf,logo_recog_time,logo_match_time): +os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' + + +def result_file_write(f, folder, url, phish_category, pred_target, matched_domain, siamese_conf, logo_recog_time, + logo_match_time): f.write(folder + "\t") f.write(url + "\t") f.write(str(phish_category) + "\t") @@ -25,6 +27,7 @@ def result_file_write(f,folder,url,phish_category,pred_target,matched_domain,sia f.write(str(round(logo_recog_time, 4)) + "\t") f.write(str(round(logo_match_time, 4)) + "\n") + class PhishpediaWrapper: _caller_prefix = "PhishpediaWrapper" _DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' @@ -66,6 +69,7 @@ def _to_device(self): # return False '''Phishpedia''' + # @profile def test_orig_phishpedia(self, url, screenshot_path, html_path): # 0 for benign, 1 for phish, default is benign @@ -93,15 +97,16 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path): ######################## Step2: Siamese (Logo matcher) ######################################## start_time = time.time() - pred_target, matched_domain, matched_coord, siamese_conf = check_domain_brand_inconsistency(logo_boxes=pred_boxes, - domain_map_path=self.DOMAIN_MAP_PATH, - model=self.SIAMESE_MODEL, - logo_feat_list=self.LOGO_FEATS, - file_name_list=self.LOGO_FILES, - url=url, - shot_path=screenshot_path, - ts=self.SIAMESE_THRE, - topk=1) + pred_target, matched_domain, matched_coord, siamese_conf = check_domain_brand_inconsistency( + logo_boxes=pred_boxes, + domain_map_path=self.DOMAIN_MAP_PATH, + model=self.SIAMESE_MODEL, + logo_feat_list=self.LOGO_FEATS, + file_name_list=self.LOGO_FILES, + url=url, + shot_path=screenshot_path, + ts=self.SIAMESE_THRE, + topk=1) logo_match_time = time.time() - start_time if pred_target is None: @@ -111,8 +116,8 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path): ######################## Step3: Simple input box check ############### # has_input_box = self.simple_input_box_regex(html_path=html_path) # if not has_input_box: - # print('No input box') - # return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time + # print('No input box') + # return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time # else: print('Match to Target: {} with confidence {:.4f}'.format(pred_target, siamese_conf)) phish_category = 1 @@ -121,7 +126,6 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path): (int(matched_coord[0] + 20), int(matched_coord[1] + 20)), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2) - return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time @@ -151,7 +155,6 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path): os.makedirs(request_dir, exist_ok=True) - for folder in tqdm(os.listdir(request_dir)): html_path = os.path.join(request_dir, folder, "html.txt") screenshot_path = os.path.join(request_dir, folder, "shot.png") @@ -175,18 +178,17 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path): continue phish_category, pred_target, matched_domain, \ - plotvis, siamese_conf, pred_boxes, \ - logo_recog_time, logo_match_time = phishpedia_cls.test_orig_phishpedia(url, screenshot_path, html_path) - + plotvis, siamese_conf, pred_boxes, \ + logo_recog_time, logo_match_time = phishpedia_cls.test_orig_phishpedia(url, screenshot_path, html_path) try: with open(result_txt, "a+", encoding='ISO-8859-1') as f: - result_file_write(f,folder,url,phish_category,pred_target,matched_domain,siamese_conf,logo_recog_time,logo_match_time) + result_file_write(f, folder, url, phish_category, pred_target, matched_domain, siamese_conf, + logo_recog_time, logo_match_time) except UnicodeError: with open(result_txt, "a+", encoding='utf-8') as f: - result_file_write(f,folder,url,phish_category,pred_target,matched_domain,siamese_conf,logo_recog_time,logo_match_time) + result_file_write(f, folder, url, phish_category, pred_target, matched_domain, siamese_conf, + logo_recog_time, logo_match_time) if phish_category: os.makedirs(os.path.join(request_dir, folder), exist_ok=True) cv2.imwrite(os.path.join(request_dir, folder, "predict.png"), plotvis) - - diff --git a/text_recog.py b/text_recog.py index 241ee36..9d77fe9 100644 --- a/text_recog.py +++ b/text_recog.py @@ -1,7 +1,7 @@ import re -def pred_text_in_image(ocr_model, shot_path): +def pred_text_in_image(ocr_model, shot_path): result = ocr_model.ocr(shot_path, cls=True) if result is None or result[0] is None: return '' @@ -12,12 +12,14 @@ def pred_text_in_image(ocr_model, shot_path): return detected_text + def check_email_credential_taking(ocr_model, shot_path): detected_text = pred_text_in_image(ocr_model, shot_path) if len(detected_text) > 0: return rule_matching(detected_text) return False, None + def rule_matching(detected_text): email_login_pattern = r'邮箱.*登录|邮箱.*登陆|邮件.*登录|邮件.*登陆' specified_email_pattern = r'@[\w.-]+\.\w+' @@ -28,4 +30,3 @@ def rule_matching(detected_text): return True, find_email[0] return False, None - diff --git a/utils.py b/utils.py index 2fecfa2..5ab87e2 100644 --- a/utils.py +++ b/utils.py @@ -1,6 +1,7 @@ import torch.nn.functional as F import math + def resolution_alignment(img1, img2): ''' Resize two images according to the minimum resolution between the two @@ -14,72 +15,74 @@ def resolution_alignment(img1, img2): if w_min == 0 or h_min == 0: ## something wrong, stop resizing return img1, img2 if w_min < h_min: - img1_resize = img1.resize((int(w_min), math.ceil(h1 * (w_min/w1)))) # ceiling to prevent rounding to 0 - img2_resize = img2.resize((int(w_min), math.ceil(h2 * (w_min/w2)))) + img1_resize = img1.resize((int(w_min), math.ceil(h1 * (w_min / w1)))) # ceiling to prevent rounding to 0 + img2_resize = img2.resize((int(w_min), math.ceil(h2 * (w_min / w2)))) else: - img1_resize = img1.resize((math.ceil(w1 * (h_min/h1)), int(h_min))) - img2_resize = img2.resize((math.ceil(w2 * (h_min/h2)), int(h_min))) + img1_resize = img1.resize((math.ceil(w1 * (h_min / h1)), int(h_min))) + img2_resize = img2.resize((math.ceil(w2 * (h_min / h2)), int(h_min))) return img1_resize, img2_resize + def brand_converter(brand_name): ''' Helper function to deal with inconsistency in brand naming ''' - brand_tran_dict={'Adobe Inc.':'Adobe','Adobe Inc':'Adobe', - 'ADP, LLC':'ADP','ADP, LLC.':'ADP', - 'Amazon.com Inc.':'Amazon','Amazon.com Inc':'Amazon', - 'Americanas.com S,A Comercio Electrnico':'Americanas.com S', - 'AOL Inc.':'AOL','AOL Inc':'AOL', - 'Apple Inc.':'Apple','Apple Inc':'Apple', - 'AT&T Inc.':'AT&T','AT&T Inc':'AT&T', - 'Banco do Brasil S.A.':'Banco do Brasil S.A', - 'Credit Agricole S.A.':'Credit Agricole S.A', - 'DGI (French Tax Authority)':'DGI French Tax Authority', - 'DHL Airways, Inc.':'DHL Airways','DHL Airways, Inc':'DHL Airways','DHL':'DHL Airways', - 'Dropbox, Inc.':'Dropbox','Dropbox, Inc':'Dropbox', - 'eBay Inc.':'eBay','eBay Inc':'eBay', - 'Facebook, Inc.':'Facebook','Facebook, Inc':'Facebook', - 'Free (ISP)':'Free ISP', - 'Google Inc.':'Google','Google Inc':'Google', - 'Mastercard International Incorporated':'Mastercard International', - 'Netflix Inc.':'Netflix','Netflix Inc':'Netflix', - 'PayPal Inc.':'PayPal','PayPal Inc':'PayPal', - 'Royal KPN N.V.':'Royal KPN N.V', - 'SF Express Co.':'SF Express Co', - 'SNS Bank N.V.':'SNS Bank N.V', - 'Square, Inc.':'Square','Square, Inc':'Square', - 'Webmail Providers':'Webmail Provider', - 'Yahoo! Inc':'Yahoo!','Yahoo! Inc.':'Yahoo!', - 'Microsoft OneDrive':'Microsoft','Office365':'Microsoft','Outlook':'Microsoft', - 'Global Sources (HK)':'Global Sources HK', - 'T-Online':'Deutsche Telekom', - 'Airbnb, Inc':'Airbnb, Inc.', - 'azul':'Azul', - 'Raiffeisen Bank S.A':'Raiffeisen Bank S.A.', - 'Twitter, Inc':'Twitter, Inc.','Twitter':'Twitter, Inc.', - 'capital_one':'Capital One Financial Corporation', - 'la_banque_postale':'La Banque postale', - 'db':'Deutsche Bank AG', - 'Swiss Post':'PostFinance','PostFinance':'PostFinance', - 'grupo_bancolombia':'Bancolombia', - 'barclays': 'Barclays Bank Plc', - 'gov_uk':'Government of the United Kingdom', - 'Aruba S.p.A':'Aruba S.p.A.', - 'TSB Bank Plc':'TSB Bank Limited', - 'strato':'Strato AG', - 'cogeco':'Cogeco', - 'Canada Revenue Agency':'Government of Canada', - 'UniCredit Bulbank':'UniCredit Bank Aktiengesellschaft', - 'ameli_fr':'French Health Insurance', - 'Banco de Credito del Peru':'bcp' - } + brand_tran_dict = {'Adobe Inc.': 'Adobe', 'Adobe Inc': 'Adobe', + 'ADP, LLC': 'ADP', 'ADP, LLC.': 'ADP', + 'Amazon.com Inc.': 'Amazon', 'Amazon.com Inc': 'Amazon', + 'Americanas.com S,A Comercio Electrnico': 'Americanas.com S', + 'AOL Inc.': 'AOL', 'AOL Inc': 'AOL', + 'Apple Inc.': 'Apple', 'Apple Inc': 'Apple', + 'AT&T Inc.': 'AT&T', 'AT&T Inc': 'AT&T', + 'Banco do Brasil S.A.': 'Banco do Brasil S.A', + 'Credit Agricole S.A.': 'Credit Agricole S.A', + 'DGI (French Tax Authority)': 'DGI French Tax Authority', + 'DHL Airways, Inc.': 'DHL Airways', 'DHL Airways, Inc': 'DHL Airways', 'DHL': 'DHL Airways', + 'Dropbox, Inc.': 'Dropbox', 'Dropbox, Inc': 'Dropbox', + 'eBay Inc.': 'eBay', 'eBay Inc': 'eBay', + 'Facebook, Inc.': 'Facebook', 'Facebook, Inc': 'Facebook', + 'Free (ISP)': 'Free ISP', + 'Google Inc.': 'Google', 'Google Inc': 'Google', + 'Mastercard International Incorporated': 'Mastercard International', + 'Netflix Inc.': 'Netflix', 'Netflix Inc': 'Netflix', + 'PayPal Inc.': 'PayPal', 'PayPal Inc': 'PayPal', + 'Royal KPN N.V.': 'Royal KPN N.V', + 'SF Express Co.': 'SF Express Co', + 'SNS Bank N.V.': 'SNS Bank N.V', + 'Square, Inc.': 'Square', 'Square, Inc': 'Square', + 'Webmail Providers': 'Webmail Provider', + 'Yahoo! Inc': 'Yahoo!', 'Yahoo! Inc.': 'Yahoo!', + 'Microsoft OneDrive': 'Microsoft', 'Office365': 'Microsoft', 'Outlook': 'Microsoft', + 'Global Sources (HK)': 'Global Sources HK', + 'T-Online': 'Deutsche Telekom', + 'Airbnb, Inc': 'Airbnb, Inc.', + 'azul': 'Azul', + 'Raiffeisen Bank S.A': 'Raiffeisen Bank S.A.', + 'Twitter, Inc': 'Twitter, Inc.', 'Twitter': 'Twitter, Inc.', + 'capital_one': 'Capital One Financial Corporation', + 'la_banque_postale': 'La Banque postale', + 'db': 'Deutsche Bank AG', + 'Swiss Post': 'PostFinance', 'PostFinance': 'PostFinance', + 'grupo_bancolombia': 'Bancolombia', + 'barclays': 'Barclays Bank Plc', + 'gov_uk': 'Government of the United Kingdom', + 'Aruba S.p.A': 'Aruba S.p.A.', + 'TSB Bank Plc': 'TSB Bank Limited', + 'strato': 'Strato AG', + 'cogeco': 'Cogeco', + 'Canada Revenue Agency': 'Government of Canada', + 'UniCredit Bulbank': 'UniCredit Bank Aktiengesellschaft', + 'ameli_fr': 'French Health Insurance', + 'Banco de Credito del Peru': 'bcp' + } # find the value in the dict else return the origin brand name - tran_brand_name=brand_tran_dict.get(brand_name,None) + tran_brand_name = brand_tran_dict.get(brand_name, None) if tran_brand_name: return tran_brand_name else: return brand_name + def l2_norm(x): """ l2 normalization @@ -88,4 +91,4 @@ def l2_norm(x): """ if len(x.shape): x = x.reshape((x.shape[0], -1)) - return F.normalize(x, p=2, dim=1) \ No newline at end of file + return F.normalize(x, p=2, dim=1)