Skip to content

Commit

Permalink
pep change
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxu60 committed Dec 1, 2024
1 parent 33cbd58 commit 827cc5a
Show file tree
Hide file tree
Showing 14 changed files with 228 additions and 149 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions .idea/Phishpedia.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 28 additions & 0 deletions .idea/deployment.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 8 additions & 4 deletions GUI/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from phishpedia import PhishpediaWrapper
import cv2


class PhishpediaFunction:
def __init__(self, ui):
self.ui = ui
Expand All @@ -12,7 +13,8 @@ def __init__(self, ui):

def upload_image(self):
options = QFileDialog.Options()
file_name, _ = QFileDialog.getOpenFileName(self.ui, "Select Screenshot", "", "Images (*.png *.jpg *.jpeg)", options=options)
file_name, _ = QFileDialog.getOpenFileName(self.ui, "Select Screenshot", "", "Images (*.png *.jpg *.jpeg)",
options=options)
if file_name:
self.ui.image_input.setText(file_name)

Expand All @@ -24,7 +26,8 @@ def detect_phishing(self):
self.ui.result_display.setText("Please enter URL and upload a screenshot.")
return

phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time = self.phishpedia_cls.test_orig_phishpedia(url, screenshot_path, None)
phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time = self.phishpedia_cls.test_orig_phishpedia(
url, screenshot_path, None)

# 根据 phish_category 改变颜色
phish_category_color = 'green' if phish_category == 0 else 'red'
Expand All @@ -49,7 +52,7 @@ def display_image(self, plotvis):
height, width, channel = plotvis_rgb.shape
bytes_per_line = 3 * width
plotvis_qimage = QImage(plotvis_rgb.data, width, height, bytes_per_line, QImage.Format_RGB888)

self.current_pixmap = QPixmap.fromImage(plotvis_qimage)
self.update_image_display()
except Exception as e:
Expand All @@ -60,7 +63,8 @@ def update_image_display(self):
available_width = self.ui.width()
available_height = self.ui.height() - self.ui.visualization_label.geometry().bottom() - 50

scaled_pixmap = self.current_pixmap.scaled(available_width, available_height, Qt.KeepAspectRatio, Qt.SmoothTransformation)
scaled_pixmap = self.current_pixmap.scaled(available_width, available_height, Qt.KeepAspectRatio,
Qt.SmoothTransformation)
self.ui.visualization_display.setPixmap(scaled_pixmap)

def on_resize(self, event):
Expand Down
4 changes: 3 additions & 1 deletion GUI/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from PyQt5.QtWidgets import QApplication
from .function import PhishpediaFunction


class PhishpediaUI(QWidget):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -117,7 +118,8 @@ def set_dynamic_font_size(self):
font = QFont()
font.setPointSizeF(font_size)

for widget in self.findChildren(QLabel) + self.findChildren(QLineEdit) + self.findChildren(QPushButton) + [self.result_display]:
for widget in self.findChildren(QLabel) + self.findChildren(QLineEdit) + self.findChildren(QPushButton) + [
self.result_display]:
widget.setFont(font)

def init_phish_test_page(self):
Expand Down
15 changes: 8 additions & 7 deletions configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import os
import numpy as np


def get_absolute_path(relative_path):
base_path = os.path.dirname(__file__)
return os.path.abspath(os.path.join(base_path, relative_path))

def load_config(reload_targetlist=False):

def load_config(reload_targetlist=False):
with open(os.path.join(os.path.dirname(__file__), 'configs.yaml')) as file:
configs = yaml.load(file, Loader=yaml.FullLoader)

Expand Down Expand Up @@ -42,8 +43,8 @@ def load_config(reload_targetlist=False):
# os.makedirs(full_targetlist_folder_dir, exist_ok=True)
# subprocess.run(f'unzip -o "{targetlist_zip_path}" -d "{full_targetlist_folder_dir}"', shell=True)

SIAMESE_MODEL = load_model_weights( num_classes=configs['SIAMESE_MODEL']['NUM_CLASSES'],
weights_path=configs['SIAMESE_MODEL']['WEIGHTS_PATH'])
SIAMESE_MODEL = load_model_weights(num_classes=configs['SIAMESE_MODEL']['NUM_CLASSES'],
weights_path=configs['SIAMESE_MODEL']['WEIGHTS_PATH'])

LOGO_FEATS_NAME = 'LOGO_FEATS.npy'
LOGO_FILES_NAME = 'LOGO_FILES.npy'
Expand All @@ -52,12 +53,12 @@ def load_config(reload_targetlist=False):
LOGO_FEATS, LOGO_FILES = cache_reference_list(model=SIAMESE_MODEL,
targetlist_path=full_targetlist_folder_dir)
print('Finish loading protected logo list')
np.save(os.path.join(os.path.dirname(__file__),LOGO_FEATS_NAME), LOGO_FEATS)
np.save(os.path.join(os.path.dirname(__file__),LOGO_FILES_NAME), LOGO_FILES)
np.save(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME), LOGO_FEATS)
np.save(os.path.join(os.path.dirname(__file__), LOGO_FILES_NAME), LOGO_FILES)

else:
LOGO_FEATS, LOGO_FILES = np.load(os.path.join(os.path.dirname(__file__),LOGO_FEATS_NAME)), \
np.load(os.path.join(os.path.dirname(__file__),LOGO_FILES_NAME))
LOGO_FEATS, LOGO_FILES = np.load(os.path.join(os.path.dirname(__file__), LOGO_FEATS_NAME)), \
np.load(os.path.join(os.path.dirname(__file__), LOGO_FILES_NAME))

DOMAIN_MAP_PATH = configs['SIAMESE_MODEL']['DOMAIN_MAP_PATH']

Expand Down
20 changes: 12 additions & 8 deletions logo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,17 @@
from tldextract import tldextract
import pickle


def check_domain_brand_inconsistency(logo_boxes,
domain_map_path: str,
model, logo_feat_list,
file_name_list, shot_path: str,
url: str, ts: float,
topk: float=3):

topk: float = 3):
# targetlist domain list
with open(domain_map_path, 'rb') as handle:
domain_map = pickle.load(handle)


print('number of logo boxes:', len(logo_boxes))
extracted_domain = tldextract.extract(url).domain + '.' + tldextract.extract(url).suffix
matched_target, matched_domain, matched_coord, this_conf = None, None, None, None
Expand All @@ -36,9 +35,10 @@ def check_domain_brand_inconsistency(logo_boxes,
min_x, min_y, max_x, max_y = coord
bbox = [float(min_x), float(min_y), float(max_x), float(max_y)]
matched_target, matched_domain, this_conf = pred_brand(model, domain_map,
logo_feat_list, file_name_list,
shot_path, bbox, t_s=ts, grayscale=False,
do_aspect_ratio_check=False, do_resolution_alignment=False)
logo_feat_list, file_name_list,
shot_path, bbox, t_s=ts, grayscale=False,
do_aspect_ratio_check=False,
do_resolution_alignment=False)
# print(target_this, domain_this, this_conf)
# domain matcher to avoid FP
if matched_target and matched_domain:
Expand All @@ -51,6 +51,7 @@ def check_domain_brand_inconsistency(logo_boxes,

return brand_converter(matched_target), matched_domain, matched_coord, this_conf


def load_model_weights(num_classes: int, weights_path: str):
# Initialize model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
Expand All @@ -72,6 +73,7 @@ def load_model_weights(num_classes: int, weights_path: str):
model.eval()
return model


def cache_reference_list(model, targetlist_path: str, grayscale=False):
'''
cache the embeddings of the reference list
Expand All @@ -92,7 +94,8 @@ def cache_reference_list(model, targetlist_path: str, grayscale=False):
if target.startswith('.'): # skip hidden files
continue
for logo_path in os.listdir(os.path.join(targetlist_path, target)):
if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith('.jpg') or logo_path.endswith('.PNG') \
if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith(
'.jpg') or logo_path.endswith('.PNG') \
or logo_path.endswith('.JPG') or logo_path.endswith('.JPEG'):
if logo_path.startswith('loginpage') or logo_path.startswith('homepage'): # skip homepage/loginpage
continue
Expand All @@ -102,6 +105,7 @@ def cache_reference_list(model, targetlist_path: str, grayscale=False):

return np.asarray(logo_feat_list), np.asarray(file_name_list)


@torch.no_grad()
def get_embedding(img, model, grayscale=False):
'''
Expand All @@ -122,7 +126,7 @@ def get_embedding(img, model, grayscale=False):
img_transforms = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
])

img = Image.open(img) if isinstance(img, str) else img
img = img.convert("L").convert("RGB") if grayscale else img.convert("RGB")
Expand Down
3 changes: 2 additions & 1 deletion logo_recog.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import torch


def pred_rcnn(im, predictor):
'''
Perform inference for RCNN
Expand Down Expand Up @@ -72,4 +73,4 @@ def vis(img_path, pred_boxes):
else:
cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (36, 255, 12), 2)

return check
return check
Loading

0 comments on commit 827cc5a

Please sign in to comment.