diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..35410ca
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/Phishpedia.iml b/.idea/Phishpedia.iml
new file mode 100644
index 0000000..ec63674
--- /dev/null
+++ b/.idea/Phishpedia.iml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/deployment.xml b/.idea/deployment.xml
new file mode 100644
index 0000000..3e51933
--- /dev/null
+++ b/.idea/deployment.xml
@@ -0,0 +1,28 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/GUI/function.py b/GUI/function.py
index b1940fe..182c835 100644
--- a/GUI/function.py
+++ b/GUI/function.py
@@ -4,6 +4,7 @@
from phishpedia import PhishpediaWrapper
import cv2
+
class PhishpediaFunction:
def __init__(self, ui):
self.ui = ui
@@ -12,7 +13,8 @@ def __init__(self, ui):
def upload_image(self):
options = QFileDialog.Options()
- file_name, _ = QFileDialog.getOpenFileName(self.ui, "Select Screenshot", "", "Images (*.png *.jpg *.jpeg)", options=options)
+ file_name, _ = QFileDialog.getOpenFileName(self.ui, "Select Screenshot", "", "Images (*.png *.jpg *.jpeg)",
+ options=options)
if file_name:
self.ui.image_input.setText(file_name)
@@ -24,7 +26,8 @@ def detect_phishing(self):
self.ui.result_display.setText("Please enter URL and upload a screenshot.")
return
- phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time = self.phishpedia_cls.test_orig_phishpedia(url, screenshot_path, None)
+ phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time = self.phishpedia_cls.test_orig_phishpedia(
+ url, screenshot_path, None)
# 根据 phish_category 改变颜色
phish_category_color = 'green' if phish_category == 0 else 'red'
@@ -49,7 +52,7 @@ def display_image(self, plotvis):
height, width, channel = plotvis_rgb.shape
bytes_per_line = 3 * width
plotvis_qimage = QImage(plotvis_rgb.data, width, height, bytes_per_line, QImage.Format_RGB888)
-
+
self.current_pixmap = QPixmap.fromImage(plotvis_qimage)
self.update_image_display()
except Exception as e:
@@ -60,7 +63,8 @@ def update_image_display(self):
available_width = self.ui.width()
available_height = self.ui.height() - self.ui.visualization_label.geometry().bottom() - 50
- scaled_pixmap = self.current_pixmap.scaled(available_width, available_height, Qt.KeepAspectRatio, Qt.SmoothTransformation)
+ scaled_pixmap = self.current_pixmap.scaled(available_width, available_height, Qt.KeepAspectRatio,
+ Qt.SmoothTransformation)
self.ui.visualization_display.setPixmap(scaled_pixmap)
def on_resize(self, event):
diff --git a/GUI/ui.py b/GUI/ui.py
index bde47fe..5a281f0 100644
--- a/GUI/ui.py
+++ b/GUI/ui.py
@@ -4,6 +4,7 @@
from PyQt5.QtWidgets import QApplication
from .function import PhishpediaFunction
+
class PhishpediaUI(QWidget):
def __init__(self):
super().__init__()
@@ -117,7 +118,8 @@ def set_dynamic_font_size(self):
font = QFont()
font.setPointSizeF(font_size)
- for widget in self.findChildren(QLabel) + self.findChildren(QLineEdit) + self.findChildren(QPushButton) + [self.result_display]:
+ for widget in self.findChildren(QLabel) + self.findChildren(QLineEdit) + self.findChildren(QPushButton) + [
+ self.result_display]:
widget.setFont(font)
def init_phish_test_page(self):
diff --git a/configs.py b/configs.py
index d5b2c2e..25980b1 100644
--- a/configs.py
+++ b/configs.py
@@ -5,12 +5,13 @@
import os
import numpy as np
+
def get_absolute_path(relative_path):
base_path = os.path.dirname(__file__)
return os.path.abspath(os.path.join(base_path, relative_path))
-def load_config(reload_targetlist=False):
+def load_config(reload_targetlist=False):
with open(os.path.join(os.path.dirname(__file__), 'configs.yaml')) as file:
configs = yaml.load(file, Loader=yaml.FullLoader)
@@ -42,8 +43,8 @@ def load_config(reload_targetlist=False):
# os.makedirs(full_targetlist_folder_dir, exist_ok=True)
# subprocess.run(f'unzip -o "{targetlist_zip_path}" -d "{full_targetlist_folder_dir}"', shell=True)
- SIAMESE_MODEL = load_model_weights( num_classes=configs['SIAMESE_MODEL']['NUM_CLASSES'],
- weights_path=configs['SIAMESE_MODEL']['WEIGHTS_PATH'])
+ SIAMESE_MODEL = load_model_weights(num_classes=configs['SIAMESE_MODEL']['NUM_CLASSES'],
+ weights_path=configs['SIAMESE_MODEL']['WEIGHTS_PATH'])
LOGO_FEATS_NAME = 'LOGO_FEATS.npy'
LOGO_FILES_NAME = 'LOGO_FILES.npy'
@@ -52,12 +53,12 @@ def load_config(reload_targetlist=False):
LOGO_FEATS, LOGO_FILES = cache_reference_list(model=SIAMESE_MODEL,
targetlist_path=full_targetlist_folder_dir)
print('Finish loading protected logo list')
- np.save(os.path.join(os.path.dirname(__file__),LOGO_FEATS_NAME), LOGO_FEATS)
- np.save(os.path.join(os.path.dirname(__file__),LOGO_FILES_NAME), LOGO_FILES)
+ np.save(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME), LOGO_FEATS)
+ np.save(os.path.join(os.path.dirname(__file__), LOGO_FILES_NAME), LOGO_FILES)
else:
- LOGO_FEATS, LOGO_FILES = np.load(os.path.join(os.path.dirname(__file__),LOGO_FEATS_NAME)), \
- np.load(os.path.join(os.path.dirname(__file__),LOGO_FILES_NAME))
+ LOGO_FEATS, LOGO_FILES = np.load(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME)), \
+ np.load(os.path.join(os.path.dirname(__file__), LOGO_FILES_NAME))
DOMAIN_MAP_PATH = configs['SIAMESE_MODEL']['DOMAIN_MAP_PATH']
diff --git a/logo_matching.py b/logo_matching.py
index 224b562..943f104 100644
--- a/logo_matching.py
+++ b/logo_matching.py
@@ -10,18 +10,17 @@
from tldextract import tldextract
import pickle
+
def check_domain_brand_inconsistency(logo_boxes,
domain_map_path: str,
model, logo_feat_list,
file_name_list, shot_path: str,
url: str, ts: float,
- topk: float=3):
-
+ topk: float = 3):
# targetlist domain list
with open(domain_map_path, 'rb') as handle:
domain_map = pickle.load(handle)
-
print('number of logo boxes:', len(logo_boxes))
extracted_domain = tldextract.extract(url).domain + '.' + tldextract.extract(url).suffix
matched_target, matched_domain, matched_coord, this_conf = None, None, None, None
@@ -36,9 +35,10 @@ def check_domain_brand_inconsistency(logo_boxes,
min_x, min_y, max_x, max_y = coord
bbox = [float(min_x), float(min_y), float(max_x), float(max_y)]
matched_target, matched_domain, this_conf = pred_brand(model, domain_map,
- logo_feat_list, file_name_list,
- shot_path, bbox, t_s=ts, grayscale=False,
- do_aspect_ratio_check=False, do_resolution_alignment=False)
+ logo_feat_list, file_name_list,
+ shot_path, bbox, t_s=ts, grayscale=False,
+ do_aspect_ratio_check=False,
+ do_resolution_alignment=False)
# print(target_this, domain_this, this_conf)
# domain matcher to avoid FP
if matched_target and matched_domain:
@@ -51,6 +51,7 @@ def check_domain_brand_inconsistency(logo_boxes,
return brand_converter(matched_target), matched_domain, matched_coord, this_conf
+
def load_model_weights(num_classes: int, weights_path: str):
# Initialize model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -72,6 +73,7 @@ def load_model_weights(num_classes: int, weights_path: str):
model.eval()
return model
+
def cache_reference_list(model, targetlist_path: str, grayscale=False):
'''
cache the embeddings of the reference list
@@ -92,7 +94,8 @@ def cache_reference_list(model, targetlist_path: str, grayscale=False):
if target.startswith('.'): # skip hidden files
continue
for logo_path in os.listdir(os.path.join(targetlist_path, target)):
- if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith('.jpg') or logo_path.endswith('.PNG') \
+ if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith(
+ '.jpg') or logo_path.endswith('.PNG') \
or logo_path.endswith('.JPG') or logo_path.endswith('.JPEG'):
if logo_path.startswith('loginpage') or logo_path.startswith('homepage'): # skip homepage/loginpage
continue
@@ -102,6 +105,7 @@ def cache_reference_list(model, targetlist_path: str, grayscale=False):
return np.asarray(logo_feat_list), np.asarray(file_name_list)
+
@torch.no_grad()
def get_embedding(img, model, grayscale=False):
'''
@@ -122,7 +126,7 @@ def get_embedding(img, model, grayscale=False):
img_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
- ])
+ ])
img = Image.open(img) if isinstance(img, str) else img
img = img.convert("L").convert("RGB") if grayscale else img.convert("RGB")
diff --git a/logo_recog.py b/logo_recog.py
index 21640b3..a6434f3 100644
--- a/logo_recog.py
+++ b/logo_recog.py
@@ -4,6 +4,7 @@
import numpy as np
import torch
+
def pred_rcnn(im, predictor):
'''
Perform inference for RCNN
@@ -72,4 +73,4 @@ def vis(img_path, pred_boxes):
else:
cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (36, 255, 12), 2)
- return check
\ No newline at end of file
+ return check
diff --git a/models.py b/models.py
index 34d3992..4912c94 100644
--- a/models.py
+++ b/models.py
@@ -15,7 +15,7 @@
# Lint as: python3
"""Bottleneck ResNet v2 with GroupNorm and Weight Standardization."""
-from collections import OrderedDict # pylint: disable=g-importing-member
+from collections import OrderedDict # pylint: disable=g-importing-member
import torch
import torch.nn as nn
@@ -29,17 +29,17 @@ def forward(self, x):
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w = (w - m) / torch.sqrt(v + 1e-10)
return F.conv2d(x, w, self.bias, self.stride, self.padding,
- self.dilation, self.groups)
+ self.dilation, self.groups)
def conv3x3(cin, cout, stride=1, groups=1, bias=False):
return StdConv2d(cin, cout, kernel_size=3, stride=stride,
- padding=1, bias=bias, groups=groups)
+ padding=1, bias=bias, groups=groups)
def conv1x1(cin, cout, stride=1, bias=False):
return StdConv2d(cin, cout, kernel_size=1, stride=stride,
- padding=0, bias=bias)
+ padding=0, bias=bias)
def tf2th(conv_weights):
@@ -61,12 +61,12 @@ class PreActBottleneck(nn.Module):
def __init__(self, cin, cout=None, cmid=None, stride=1):
super().__init__()
cout = cout or cin
- cmid = cmid or cout//4
+ cmid = cmid or cout // 4
self.gn1 = nn.GroupNorm(32, cin)
self.conv1 = conv1x1(cin, cmid)
self.gn2 = nn.GroupNorm(32, cmid)
- self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!!
+ self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!!
self.gn3 = nn.GroupNorm(32, cmid)
self.conv3 = conv1x1(cmid, cout)
self.relu = nn.ReLU(inplace=True)
@@ -112,44 +112,48 @@ class ResNetV2(nn.Module):
def __init__(self, block_units, width_factor, head_size=21843, zero_head=False):
super().__init__()
- wf = width_factor # shortcut 'cause we'll use it a lot.
+ wf = width_factor # shortcut 'cause we'll use it a lot.
# The following will be unreadable if we split lines.
# pylint: disable=line-too-long
self.root = nn.Sequential(OrderedDict([
- ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)),
- ('pad', nn.ConstantPad2d(1, 0)),
- ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
- # The following is subtly not the same!
- # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
+ ('conv', StdConv2d(3, 64 * wf, kernel_size=7, stride=2, padding=3, bias=False)),
+ ('pad', nn.ConstantPad2d(1, 0)),
+ ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
+ # The following is subtly not the same!
+ # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
]))
self.body = nn.Sequential(OrderedDict([
- ('block1', nn.Sequential(OrderedDict(
- [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] +
- [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)],
- ))),
- ('block2', nn.Sequential(OrderedDict(
- [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] +
- [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)],
- ))),
- ('block3', nn.Sequential(OrderedDict(
- [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] +
- [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)],
- ))),
- ('block4', nn.Sequential(OrderedDict(
- [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] +
- [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)],
- ))),
+ ('block1', nn.Sequential(OrderedDict(
+ [('unit01', PreActBottleneck(cin=64 * wf, cout=256 * wf, cmid=64 * wf))] +
+ [(f'unit{i:02d}', PreActBottleneck(cin=256 * wf, cout=256 * wf, cmid=64 * wf)) for i in
+ range(2, block_units[0] + 1)],
+ ))),
+ ('block2', nn.Sequential(OrderedDict(
+ [('unit01', PreActBottleneck(cin=256 * wf, cout=512 * wf, cmid=128 * wf, stride=2))] +
+ [(f'unit{i:02d}', PreActBottleneck(cin=512 * wf, cout=512 * wf, cmid=128 * wf)) for i in
+ range(2, block_units[1] + 1)],
+ ))),
+ ('block3', nn.Sequential(OrderedDict(
+ [('unit01', PreActBottleneck(cin=512 * wf, cout=1024 * wf, cmid=256 * wf, stride=2))] +
+ [(f'unit{i:02d}', PreActBottleneck(cin=1024 * wf, cout=1024 * wf, cmid=256 * wf)) for i in
+ range(2, block_units[2] + 1)],
+ ))),
+ ('block4', nn.Sequential(OrderedDict(
+ [('unit01', PreActBottleneck(cin=1024 * wf, cout=2048 * wf, cmid=512 * wf, stride=2))] +
+ [(f'unit{i:02d}', PreActBottleneck(cin=2048 * wf, cout=2048 * wf, cmid=512 * wf)) for i in
+ range(2, block_units[3] + 1)],
+ ))),
]))
# pylint: enable=line-too-long
self.zero_head = zero_head
self.head = nn.Sequential(OrderedDict([
- ('gn', nn.GroupNorm(32, 2048*wf)),
- ('relu', nn.ReLU(inplace=True)),
- ('avg', nn.AdaptiveAvgPool2d(output_size=1)),
- ('conv', nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)),
+ ('gn', nn.GroupNorm(32, 2048 * wf)),
+ ('relu', nn.ReLU(inplace=True)),
+ ('avg', nn.AdaptiveAvgPool2d(output_size=1)),
+ ('conv', nn.Conv2d(2048 * wf, head_size, kernel_size=1, bias=True)),
]))
def features(self, x):
@@ -159,19 +163,21 @@ def features(self, x):
def forward(self, x):
x = self.head(self.body(self.root(x)))
- assert x.shape[-2:] == (1, 1) # We should have no spatial shape left.
- return x[...,0,0]
+ assert x.shape[-2:] == (1, 1) # We should have no spatial shape left.
+ return x[..., 0, 0]
def load_from(self, weights, prefix='resnet/'):
with torch.no_grad():
- self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) # pylint: disable=line-too-long
+ self.root.conv.weight.copy_(
+ tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) # pylint: disable=line-too-long
self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
if self.zero_head:
nn.init.zeros_(self.head.conv.weight)
nn.init.zeros_(self.head.conv.bias)
else:
- self.head.conv.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel'])) # pylint: disable=line-too-long
+ self.head.conv.weight.copy_(
+ tf2th(weights[f'{prefix}head/conv2d/kernel'])) # pylint: disable=line-too-long
self.head.conv.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
for bname, block in self.body.named_children():
@@ -180,16 +186,16 @@ def load_from(self, weights, prefix='resnet/'):
KNOWN_MODELS = OrderedDict([
- ('BiT-M-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
- ('BiT-M-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
- ('BiT-M-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
- ('BiT-M-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
- ('BiT-M-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
- ('BiT-M-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),
- ('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
- ('BiT-S-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
- ('BiT-S-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
- ('BiT-S-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
- ('BiT-S-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
- ('BiT-S-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),
-])
\ No newline at end of file
+ ('BiT-M-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
+ ('BiT-M-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
+ ('BiT-M-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
+ ('BiT-M-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
+ ('BiT-M-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
+ ('BiT-M-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),
+ ('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
+ ('BiT-S-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
+ ('BiT-S-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
+ ('BiT-S-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
+ ('BiT-S-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
+ ('BiT-S-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),
+])
diff --git a/phishpedia.py b/phishpedia.py
index 595c2ec..f704123 100644
--- a/phishpedia.py
+++ b/phishpedia.py
@@ -1,4 +1,3 @@
-
import time
from datetime import datetime
import argparse
@@ -13,9 +12,12 @@
from tqdm import tqdm
import re
-os.environ['KMP_DUPLICATE_LIB_OK']='True'
-def result_file_write(f,folder,url,phish_category,pred_target,matched_domain,siamese_conf,logo_recog_time,logo_match_time):
+os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
+
+
+def result_file_write(f, folder, url, phish_category, pred_target, matched_domain, siamese_conf, logo_recog_time,
+ logo_match_time):
f.write(folder + "\t")
f.write(url + "\t")
f.write(str(phish_category) + "\t")
@@ -25,6 +27,7 @@ def result_file_write(f,folder,url,phish_category,pred_target,matched_domain,sia
f.write(str(round(logo_recog_time, 4)) + "\t")
f.write(str(round(logo_match_time, 4)) + "\n")
+
class PhishpediaWrapper:
_caller_prefix = "PhishpediaWrapper"
_DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -66,6 +69,7 @@ def _to_device(self):
# return False
'''Phishpedia'''
+
# @profile
def test_orig_phishpedia(self, url, screenshot_path, html_path):
# 0 for benign, 1 for phish, default is benign
@@ -93,15 +97,16 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path):
######################## Step2: Siamese (Logo matcher) ########################################
start_time = time.time()
- pred_target, matched_domain, matched_coord, siamese_conf = check_domain_brand_inconsistency(logo_boxes=pred_boxes,
- domain_map_path=self.DOMAIN_MAP_PATH,
- model=self.SIAMESE_MODEL,
- logo_feat_list=self.LOGO_FEATS,
- file_name_list=self.LOGO_FILES,
- url=url,
- shot_path=screenshot_path,
- ts=self.SIAMESE_THRE,
- topk=1)
+ pred_target, matched_domain, matched_coord, siamese_conf = check_domain_brand_inconsistency(
+ logo_boxes=pred_boxes,
+ domain_map_path=self.DOMAIN_MAP_PATH,
+ model=self.SIAMESE_MODEL,
+ logo_feat_list=self.LOGO_FEATS,
+ file_name_list=self.LOGO_FILES,
+ url=url,
+ shot_path=screenshot_path,
+ ts=self.SIAMESE_THRE,
+ topk=1)
logo_match_time = time.time() - start_time
if pred_target is None:
@@ -111,8 +116,8 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path):
######################## Step3: Simple input box check ###############
# has_input_box = self.simple_input_box_regex(html_path=html_path)
# if not has_input_box:
- # print('No input box')
- # return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time
+ # print('No input box')
+ # return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time
# else:
print('Match to Target: {} with confidence {:.4f}'.format(pred_target, siamese_conf))
phish_category = 1
@@ -121,7 +126,6 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path):
(int(matched_coord[0] + 20), int(matched_coord[1] + 20)),
cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
-
return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time
@@ -151,7 +155,6 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path):
os.makedirs(request_dir, exist_ok=True)
-
for folder in tqdm(os.listdir(request_dir)):
html_path = os.path.join(request_dir, folder, "html.txt")
screenshot_path = os.path.join(request_dir, folder, "shot.png")
@@ -175,18 +178,17 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path):
continue
phish_category, pred_target, matched_domain, \
- plotvis, siamese_conf, pred_boxes, \
- logo_recog_time, logo_match_time = phishpedia_cls.test_orig_phishpedia(url, screenshot_path, html_path)
-
+ plotvis, siamese_conf, pred_boxes, \
+ logo_recog_time, logo_match_time = phishpedia_cls.test_orig_phishpedia(url, screenshot_path, html_path)
try:
with open(result_txt, "a+", encoding='ISO-8859-1') as f:
- result_file_write(f,folder,url,phish_category,pred_target,matched_domain,siamese_conf,logo_recog_time,logo_match_time)
+ result_file_write(f, folder, url, phish_category, pred_target, matched_domain, siamese_conf,
+ logo_recog_time, logo_match_time)
except UnicodeError:
with open(result_txt, "a+", encoding='utf-8') as f:
- result_file_write(f,folder,url,phish_category,pred_target,matched_domain,siamese_conf,logo_recog_time,logo_match_time)
+ result_file_write(f, folder, url, phish_category, pred_target, matched_domain, siamese_conf,
+ logo_recog_time, logo_match_time)
if phish_category:
os.makedirs(os.path.join(request_dir, folder), exist_ok=True)
cv2.imwrite(os.path.join(request_dir, folder, "predict.png"), plotvis)
-
-
diff --git a/text_recog.py b/text_recog.py
index 241ee36..9d77fe9 100644
--- a/text_recog.py
+++ b/text_recog.py
@@ -1,7 +1,7 @@
import re
-def pred_text_in_image(ocr_model, shot_path):
+def pred_text_in_image(ocr_model, shot_path):
result = ocr_model.ocr(shot_path, cls=True)
if result is None or result[0] is None:
return ''
@@ -12,12 +12,14 @@ def pred_text_in_image(ocr_model, shot_path):
return detected_text
+
def check_email_credential_taking(ocr_model, shot_path):
detected_text = pred_text_in_image(ocr_model, shot_path)
if len(detected_text) > 0:
return rule_matching(detected_text)
return False, None
+
def rule_matching(detected_text):
email_login_pattern = r'邮箱.*登录|邮箱.*登陆|邮件.*登录|邮件.*登陆'
specified_email_pattern = r'@[\w.-]+\.\w+'
@@ -28,4 +30,3 @@ def rule_matching(detected_text):
return True, find_email[0]
return False, None
-
diff --git a/utils.py b/utils.py
index 2fecfa2..5ab87e2 100644
--- a/utils.py
+++ b/utils.py
@@ -1,6 +1,7 @@
import torch.nn.functional as F
import math
+
def resolution_alignment(img1, img2):
'''
Resize two images according to the minimum resolution between the two
@@ -14,72 +15,74 @@ def resolution_alignment(img1, img2):
if w_min == 0 or h_min == 0: ## something wrong, stop resizing
return img1, img2
if w_min < h_min:
- img1_resize = img1.resize((int(w_min), math.ceil(h1 * (w_min/w1)))) # ceiling to prevent rounding to 0
- img2_resize = img2.resize((int(w_min), math.ceil(h2 * (w_min/w2))))
+ img1_resize = img1.resize((int(w_min), math.ceil(h1 * (w_min / w1)))) # ceiling to prevent rounding to 0
+ img2_resize = img2.resize((int(w_min), math.ceil(h2 * (w_min / w2))))
else:
- img1_resize = img1.resize((math.ceil(w1 * (h_min/h1)), int(h_min)))
- img2_resize = img2.resize((math.ceil(w2 * (h_min/h2)), int(h_min)))
+ img1_resize = img1.resize((math.ceil(w1 * (h_min / h1)), int(h_min)))
+ img2_resize = img2.resize((math.ceil(w2 * (h_min / h2)), int(h_min)))
return img1_resize, img2_resize
+
def brand_converter(brand_name):
'''
Helper function to deal with inconsistency in brand naming
'''
- brand_tran_dict={'Adobe Inc.':'Adobe','Adobe Inc':'Adobe',
- 'ADP, LLC':'ADP','ADP, LLC.':'ADP',
- 'Amazon.com Inc.':'Amazon','Amazon.com Inc':'Amazon',
- 'Americanas.com S,A Comercio Electrnico':'Americanas.com S',
- 'AOL Inc.':'AOL','AOL Inc':'AOL',
- 'Apple Inc.':'Apple','Apple Inc':'Apple',
- 'AT&T Inc.':'AT&T','AT&T Inc':'AT&T',
- 'Banco do Brasil S.A.':'Banco do Brasil S.A',
- 'Credit Agricole S.A.':'Credit Agricole S.A',
- 'DGI (French Tax Authority)':'DGI French Tax Authority',
- 'DHL Airways, Inc.':'DHL Airways','DHL Airways, Inc':'DHL Airways','DHL':'DHL Airways',
- 'Dropbox, Inc.':'Dropbox','Dropbox, Inc':'Dropbox',
- 'eBay Inc.':'eBay','eBay Inc':'eBay',
- 'Facebook, Inc.':'Facebook','Facebook, Inc':'Facebook',
- 'Free (ISP)':'Free ISP',
- 'Google Inc.':'Google','Google Inc':'Google',
- 'Mastercard International Incorporated':'Mastercard International',
- 'Netflix Inc.':'Netflix','Netflix Inc':'Netflix',
- 'PayPal Inc.':'PayPal','PayPal Inc':'PayPal',
- 'Royal KPN N.V.':'Royal KPN N.V',
- 'SF Express Co.':'SF Express Co',
- 'SNS Bank N.V.':'SNS Bank N.V',
- 'Square, Inc.':'Square','Square, Inc':'Square',
- 'Webmail Providers':'Webmail Provider',
- 'Yahoo! Inc':'Yahoo!','Yahoo! Inc.':'Yahoo!',
- 'Microsoft OneDrive':'Microsoft','Office365':'Microsoft','Outlook':'Microsoft',
- 'Global Sources (HK)':'Global Sources HK',
- 'T-Online':'Deutsche Telekom',
- 'Airbnb, Inc':'Airbnb, Inc.',
- 'azul':'Azul',
- 'Raiffeisen Bank S.A':'Raiffeisen Bank S.A.',
- 'Twitter, Inc':'Twitter, Inc.','Twitter':'Twitter, Inc.',
- 'capital_one':'Capital One Financial Corporation',
- 'la_banque_postale':'La Banque postale',
- 'db':'Deutsche Bank AG',
- 'Swiss Post':'PostFinance','PostFinance':'PostFinance',
- 'grupo_bancolombia':'Bancolombia',
- 'barclays': 'Barclays Bank Plc',
- 'gov_uk':'Government of the United Kingdom',
- 'Aruba S.p.A':'Aruba S.p.A.',
- 'TSB Bank Plc':'TSB Bank Limited',
- 'strato':'Strato AG',
- 'cogeco':'Cogeco',
- 'Canada Revenue Agency':'Government of Canada',
- 'UniCredit Bulbank':'UniCredit Bank Aktiengesellschaft',
- 'ameli_fr':'French Health Insurance',
- 'Banco de Credito del Peru':'bcp'
- }
+ brand_tran_dict = {'Adobe Inc.': 'Adobe', 'Adobe Inc': 'Adobe',
+ 'ADP, LLC': 'ADP', 'ADP, LLC.': 'ADP',
+ 'Amazon.com Inc.': 'Amazon', 'Amazon.com Inc': 'Amazon',
+ 'Americanas.com S,A Comercio Electrnico': 'Americanas.com S',
+ 'AOL Inc.': 'AOL', 'AOL Inc': 'AOL',
+ 'Apple Inc.': 'Apple', 'Apple Inc': 'Apple',
+ 'AT&T Inc.': 'AT&T', 'AT&T Inc': 'AT&T',
+ 'Banco do Brasil S.A.': 'Banco do Brasil S.A',
+ 'Credit Agricole S.A.': 'Credit Agricole S.A',
+ 'DGI (French Tax Authority)': 'DGI French Tax Authority',
+ 'DHL Airways, Inc.': 'DHL Airways', 'DHL Airways, Inc': 'DHL Airways', 'DHL': 'DHL Airways',
+ 'Dropbox, Inc.': 'Dropbox', 'Dropbox, Inc': 'Dropbox',
+ 'eBay Inc.': 'eBay', 'eBay Inc': 'eBay',
+ 'Facebook, Inc.': 'Facebook', 'Facebook, Inc': 'Facebook',
+ 'Free (ISP)': 'Free ISP',
+ 'Google Inc.': 'Google', 'Google Inc': 'Google',
+ 'Mastercard International Incorporated': 'Mastercard International',
+ 'Netflix Inc.': 'Netflix', 'Netflix Inc': 'Netflix',
+ 'PayPal Inc.': 'PayPal', 'PayPal Inc': 'PayPal',
+ 'Royal KPN N.V.': 'Royal KPN N.V',
+ 'SF Express Co.': 'SF Express Co',
+ 'SNS Bank N.V.': 'SNS Bank N.V',
+ 'Square, Inc.': 'Square', 'Square, Inc': 'Square',
+ 'Webmail Providers': 'Webmail Provider',
+ 'Yahoo! Inc': 'Yahoo!', 'Yahoo! Inc.': 'Yahoo!',
+ 'Microsoft OneDrive': 'Microsoft', 'Office365': 'Microsoft', 'Outlook': 'Microsoft',
+ 'Global Sources (HK)': 'Global Sources HK',
+ 'T-Online': 'Deutsche Telekom',
+ 'Airbnb, Inc': 'Airbnb, Inc.',
+ 'azul': 'Azul',
+ 'Raiffeisen Bank S.A': 'Raiffeisen Bank S.A.',
+ 'Twitter, Inc': 'Twitter, Inc.', 'Twitter': 'Twitter, Inc.',
+ 'capital_one': 'Capital One Financial Corporation',
+ 'la_banque_postale': 'La Banque postale',
+ 'db': 'Deutsche Bank AG',
+ 'Swiss Post': 'PostFinance', 'PostFinance': 'PostFinance',
+ 'grupo_bancolombia': 'Bancolombia',
+ 'barclays': 'Barclays Bank Plc',
+ 'gov_uk': 'Government of the United Kingdom',
+ 'Aruba S.p.A': 'Aruba S.p.A.',
+ 'TSB Bank Plc': 'TSB Bank Limited',
+ 'strato': 'Strato AG',
+ 'cogeco': 'Cogeco',
+ 'Canada Revenue Agency': 'Government of Canada',
+ 'UniCredit Bulbank': 'UniCredit Bank Aktiengesellschaft',
+ 'ameli_fr': 'French Health Insurance',
+ 'Banco de Credito del Peru': 'bcp'
+ }
# find the value in the dict else return the origin brand name
- tran_brand_name=brand_tran_dict.get(brand_name,None)
+ tran_brand_name = brand_tran_dict.get(brand_name, None)
if tran_brand_name:
return tran_brand_name
else:
return brand_name
+
def l2_norm(x):
"""
l2 normalization
@@ -88,4 +91,4 @@ def l2_norm(x):
"""
if len(x.shape):
x = x.reshape((x.shape[0], -1))
- return F.normalize(x, p=2, dim=1)
\ No newline at end of file
+ return F.normalize(x, p=2, dim=1)