diff --git a/README.md b/README.md
index 857d249..ed38594 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@
## Framework
-
+
```Input```: A URL and its screenshot ```Output```: Phish/Benign, Phishing target
- Step 1: Enter Deep Object Detection Model, get predicted logos and inputs (inputs are not used for later prediction, just for explanation)
@@ -40,88 +40,54 @@
## Project structure
```
-- src
- - adv_attack: adversarial attacking scripts
- - detectron2_pedia: training script for object detector
- |_ output
- |_ rcnn_2
- |_ rcnn_bet365.pth
- - siamese_pedia: inference script for siamese
- |_ siamese_retrain: training script for siamese
- |_ expand_targetlist
- |_ 1&1 Ionos
- |_ ...
- |_ domain_map.pkl
- |_ resnetv2_rgb_new.pth.tar
- - siamese.py: main script for siamese
- - pipeline_eval.py: evaluation script for general experiment
-
-- tele: telegram scripts to vote for phishing
-- phishpedia_config.py: config script for phish-discovery experiment
-- phishpedia_main.py: main script for phish-discovery experiment
+- logo_recog.py: Deep Object Detection Model
+- logo_matching.py: Deep Siamese Model
+- configs.yaml: Configuration file
+- phishpedia.py: Main script
```
## Instructions
Requirements:
-- CUDA 11
- Anaconda installed, please refer to the official installation guide: https://docs.anaconda.com/free/anaconda/install/index.html
1. Create a local clone of Phishpedia
-```
+```bash
git clone https://github.com/lindsey98/Phishpedia.git
```
2. Setup
-```
-cd Phishpedia/
+```bash
chmod +x ./setup.sh
./setup.sh
```
-If you encounter any problem in downloading the models, you can manually download them from here https://huggingface.co/Kelsey98/Phishpedia. And put them into the corresponding conda environment.
3. Activate the conda environment
```
-conda activate myenv
+conda activate phishpedia
```
-Run in Python to test a single website
-```python
-from phishpedia.phishpedia_main import test
-import matplotlib.pyplot as plt
-from phishpedia.phishpedia_config import load_config
-
-url = open("phishpedia/datasets/test_sites/accounts.g.cdcde.com/info.txt").read().strip()
-screenshot_path = "phishpedia/datasets/test_sites/accounts.g.cdcde.com/shot.png"
-ELE_MODEL, SIAMESE_THRE, SIAMESE_MODEL, LOGO_FEATS, LOGO_FILES, DOMAIN_MAP_PATH = load_config(None)
-
-phish_category, pred_target, plotvis, siamese_conf, pred_boxes = test(url=url, screenshot_path=screenshot_path,
- ELE_MODEL=ELE_MODEL,
- SIAMESE_THRE=SIAMESE_THRE,
- SIAMESE_MODEL=SIAMESE_MODEL,
- LOGO_FEATS=LOGO_FEATS,
- LOGO_FILES=LOGO_FILES,
- DOMAIN_MAP_PATH=DOMAIN_MAP_PATH
- )
-
-print('Phishing (1) or Benign (0) ?', phish_category)
-print('What is its targeted brand if it is a phishing ?', pred_target)
-print('What is the siamese matching confidence ?', siamese_conf)
-print('Where is the predicted logo (in [x_min, y_min, x_max, y_max])?', pred_boxes)
-plt.imshow(plotvis[:, :, ::-1])
-plt.title("Predicted screenshot with annotations")
-plt.show()
+4. Run in bash
+```bash
+python phishpedia.py --folder
```
-Or run in bash
+The testing folder should be in the structure of:
+
```
-python run.py --folder --results
+test_site_1
+|__ info.txt (Write the URL)
+|__ shot.png (Save the screenshot)
+test_site_2
+|__ info.txt (Write the URL)
+|__ shot.png (Save the screenshot)
+......
```
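
As a rough illustration of the layout above, the following sketch walks a testing folder and yields each site's URL and screenshot path. The helper name `iter_test_sites` is an assumption made for this example and is not part of the repository; it relies only on the `info.txt` and `shot.png` file names shown above.

```python
from pathlib import Path
from typing import Iterator, Tuple


def iter_test_sites(root: str) -> Iterator[Tuple[str, str, Path]]:
    """Yield (site_name, url, screenshot_path) for every complete site folder.

    Hypothetical helper: entries missing info.txt or shot.png are skipped.
    """
    for site_dir in sorted(Path(root).iterdir()):
        info_path = site_dir / "info.txt"
        shot_path = site_dir / "shot.png"
        if not (site_dir.is_dir() and info_path.is_file() and shot_path.is_file()):
            continue
        # info.txt is expected to hold the URL, possibly with trailing whitespace
        url = info_path.read_text(encoding="utf-8").strip()
        yield site_dir.name, url, shot_path
```

Skipping incomplete folders keeps a batch run from failing on a single missing screenshot.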
## Miscellaneous
- In our paper, we also implement several phishing detection and identification baselines; see [here](https://github.com/lindsey98/PhishingBaseline).
- The logo targetlist described in our paper includes 181 brands. In this code repository, we have expanded it to 277 brands.
- For the phish discovery experiment, we obtain the feed from [Certstream phish_catcher](https://github.com/x0rz/phishing_catcher) and lower the score threshold to 40 so that more suspicious websites are processed. Refer to their repo for details.
-- We use Scrapy for website crawling [Repo here](https://github.com/lindsey98/MyScrapy.git)
+- We use Scrapy for website crawling
## Citation
If you find our work useful in your research, please consider citing our paper by:
diff --git a/configs.py b/configs.py
new file mode 100644
index 0000000..50be87c
--- /dev/null
+++ b/configs.py
@@ -0,0 +1,63 @@
+# Global configuration
+import subprocess
+from typing import Union
+import yaml
+from logo_matching import cache_reference_list, load_model_weights
+from logo_recog import config_rcnn
+import os
+import numpy as np
+
+def get_absolute_path(relative_path):
+ base_path = os.path.dirname(__file__)
+ return os.path.abspath(os.path.join(base_path, relative_path))
+
+def load_config(reload_targetlist=False):
+
+ with open(os.path.join(os.path.dirname(__file__), 'configs.yaml')) as file:
+ configs = yaml.load(file, Loader=yaml.FullLoader)
+
+ # Iterate through the configuration and update paths
+ for section, settings in configs.items():
+ for key, value in settings.items():
+ if 'PATH' in key and isinstance(value, str): # Check if the key indicates a path
+ absolute_path = get_absolute_path(value)
+ configs[section][key] = absolute_path
+
+ ELE_CFG_PATH = configs['ELE_MODEL']['CFG_PATH']
+ ELE_WEIGHTS_PATH = configs['ELE_MODEL']['WEIGHTS_PATH']
+ ELE_CONFIG_THRE = configs['ELE_MODEL']['DETECT_THRE']
+ ELE_MODEL = config_rcnn(ELE_CFG_PATH,
+ ELE_WEIGHTS_PATH,
+ conf_threshold=ELE_CONFIG_THRE)
+
+ # siamese model
+ SIAMESE_THRE = configs['SIAMESE_MODEL']['MATCH_THRE']
+
+ print('Load protected logo list')
+ targetlist_zip_path = configs['SIAMESE_MODEL']['TARGETLIST_PATH']
+ targetlist_dir = os.path.dirname(targetlist_zip_path)
+ zip_file_name = os.path.basename(targetlist_zip_path)
+ targetlist_folder = zip_file_name.split('.zip')[0]
+ full_targetlist_folder_dir = os.path.join(targetlist_dir, targetlist_folder)
+
+ if reload_targetlist or (targetlist_zip_path.endswith('.zip') and not os.path.isdir(full_targetlist_folder_dir)):
+ os.makedirs(full_targetlist_folder_dir, exist_ok=True)
+ subprocess.run(f'unzip -o "{targetlist_zip_path}" -d "{full_targetlist_folder_dir}"', shell=True)
+
+ SIAMESE_MODEL = load_model_weights( num_classes=configs['SIAMESE_MODEL']['NUM_CLASSES'],
+ weights_path=configs['SIAMESE_MODEL']['WEIGHTS_PATH'])
+
+ if reload_targetlist or (not os.path.exists(os.path.join(os.path.dirname(__file__), 'LOGO_FEATS.npy'))):
+ LOGO_FEATS, LOGO_FILES = cache_reference_list(model=SIAMESE_MODEL,
+ targetlist_path=full_targetlist_folder_dir)
+ print('Finish loading protected logo list')
+ np.save(os.path.join(os.path.dirname(__file__),'LOGO_FEATS.npy'), LOGO_FEATS)
+ np.save(os.path.join(os.path.dirname(__file__),'LOGO_FILES.npy'), LOGO_FILES)
+
+ else:
+ LOGO_FEATS, LOGO_FILES = np.load(os.path.join(os.path.dirname(__file__),'LOGO_FEATS.npy')), \
+ np.load(os.path.join(os.path.dirname(__file__),'LOGO_FILES.npy'))
+
+ DOMAIN_MAP_PATH = configs['SIAMESE_MODEL']['DOMAIN_MAP_PATH']
+
+ return ELE_MODEL, SIAMESE_THRE, SIAMESE_MODEL, LOGO_FEATS, LOGO_FILES, DOMAIN_MAP_PATH
\ No newline at end of file
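
The path handling in `load_config` above can be sketched in isolation. This is a minimal sketch, not the repository's code: `resolve_paths` is a hypothetical name, and unlike the original loop it returns a new dictionary instead of updating `configs` in place. The sample values are taken from `configs.yaml`.

```python
import os
from typing import Any, Dict


def resolve_paths(configs: Dict[str, Dict[str, Any]], base_dir: str) -> Dict[str, Dict[str, Any]]:
    """Return a copy of `configs` whose *PATH* entries are absolute.

    Only string values whose key contains 'PATH' are rewritten;
    thresholds and counts are passed through unchanged.
    """
    resolved: Dict[str, Dict[str, Any]] = {}
    for section, settings in configs.items():
        resolved[section] = {}
        for key, value in settings.items():
            if "PATH" in key and isinstance(value, str):
                value = os.path.abspath(os.path.join(base_dir, value))
            resolved[section][key] = value
    return resolved


example = {
    "ELE_MODEL": {"CFG_PATH": "models/faster_rcnn.yaml", "DETECT_THRE": 0.05},
    "SIAMESE_MODEL": {"WEIGHTS_PATH": "models/resnetv2_rgb_new.pth.tar", "MATCH_THRE": 0.87},
}
resolved = resolve_paths(example, base_dir=os.getcwd())
```

Resolving against the directory of the config module, as `get_absolute_path` does, lets the script be started from any working directory.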
diff --git a/configs.yaml b/configs.yaml
new file mode 100644
index 0000000..6b318ae
--- /dev/null
+++ b/configs.yaml
@@ -0,0 +1,11 @@
+ELE_MODEL: # element recognition model -- logo only
+ CFG_PATH: models/faster_rcnn.yaml # relative path, resolved against the directory of configs.py
+ WEIGHTS_PATH: models/rcnn_bet365.pth
+ DETECT_THRE: 0.05
+
+SIAMESE_MODEL:
+ NUM_CLASSES: 277 # number of brands; users don't need to modify this even if the targetlist is expanded
+ MATCH_THRE: 0.87 # FIXME: threshold is 0.87 in phish-discovery?
+ WEIGHTS_PATH: models/resnetv2_rgb_new.pth.tar
+ TARGETLIST_PATH: models/expand_targetlist.zip
+ DOMAIN_MAP_PATH: models/domain_map.pkl
\ No newline at end of file
diff --git a/datasets/.DS_Store b/datasets/.DS_Store
new file mode 100644
index 0000000..787bd30
Binary files /dev/null and b/datasets/.DS_Store differ
diff --git a/datasets/overview.png b/datasets/overview.png
new file mode 100644
index 0000000..74ed439
Binary files /dev/null and b/datasets/overview.png differ
diff --git a/datasets/test_sites/.DS_Store b/datasets/test_sites/.DS_Store
new file mode 100644
index 0000000..ccf2878
Binary files /dev/null and b/datasets/test_sites/.DS_Store differ
diff --git a/datasets/test_sites/accounts.g.cdcde.com/.DS_Store b/datasets/test_sites/accounts.g.cdcde.com/.DS_Store
new file mode 100644
index 0000000..9c34498
Binary files /dev/null and b/datasets/test_sites/accounts.g.cdcde.com/.DS_Store differ
diff --git a/datasets/test_sites/accounts.g.cdcde.com/html.txt b/datasets/test_sites/accounts.g.cdcde.com/html.txt
new file mode 100644
index 0000000..b8c0e99
--- /dev/null
+++ b/datasets/test_sites/accounts.g.cdcde.com/html.txt
@@ -0,0 +1,4106 @@
+Sign in - Google Accounts
Sign in
Use your Google Account
Not your computer? Use Private Browsing windows to sign in. Learn more
\ No newline at end of file
diff --git a/datasets/test_sites/accounts.g.cdcde.com/info.txt b/datasets/test_sites/accounts.g.cdcde.com/info.txt
new file mode 100644
index 0000000..e797e4f
--- /dev/null
+++ b/datasets/test_sites/accounts.g.cdcde.com/info.txt
@@ -0,0 +1 @@
+https://accounts.g.cdcde.com/ServiceLogin?passive=1209600&osid=1&continue=https://plus.g.cdcde.com/&followup=https://plus.g.cdcde.com/
\ No newline at end of file
diff --git a/datasets/test_sites/accounts.g.cdcde.com/shot.png b/datasets/test_sites/accounts.g.cdcde.com/shot.png
new file mode 100644
index 0000000..a6660c3
Binary files /dev/null and b/datasets/test_sites/accounts.g.cdcde.com/shot.png differ
diff --git a/logo_matching.py b/logo_matching.py
new file mode 100644
index 0000000..ba65bdf
--- /dev/null
+++ b/logo_matching.py
@@ -0,0 +1,218 @@
+from PIL import Image, ImageOps
+from torchvision import transforms
+from utils import brand_converter, resolution_alignment, l2_norm
+from models import KNOWN_MODELS
+import torch
+import os
+import numpy as np
+from collections import OrderedDict
+from tqdm import tqdm
+from tldextract import tldextract
+import pickle
+
+def check_domain_brand_inconsistency(logo_boxes,
+ domain_map_path: str,
+ model, logo_feat_list,
+ file_name_list, shot_path: str,
+ url: str, ts: float,
+ topk: int = 3):
+
+ # targetlist domain list
+ with open(domain_map_path, 'rb') as handle:
+ domain_map = pickle.load(handle)
+
+ print('number of logo boxes:', len(logo_boxes))
+ matched_target, matched_domain, matched_coord, this_conf = None, None, None, None
+
+ if len(logo_boxes) > 0:
+ # siamese prediction for logo box
+ for i, coord in enumerate(logo_boxes):
+
+ if i == topk:
+ break
+
+ min_x, min_y, max_x, max_y = coord
+ bbox = [float(min_x), float(min_y), float(max_x), float(max_y)]
+ matched_target, matched_domain, this_conf = pred_brand(model, domain_map,
+ logo_feat_list, file_name_list,
+ shot_path, bbox, t_s=ts, grayscale=False)
+ # print(target_this, domain_this, this_conf)
+ # domain matcher to avoid FP
+ if matched_target is not None:
+ matched_coord = coord
+ if tldextract.extract(url).domain in matched_domain: # domain and brand are consistent
+ matched_target = None
+ else:
+ break # break if target is matched
+
+ return brand_converter(matched_target), matched_domain, matched_coord, this_conf
+
+def load_model_weights(num_classes: int, weights_path: str):
+ # Initialize model
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = KNOWN_MODELS["BiT-M-R50x1"](head_size=num_classes, zero_head=True)
+
+ # Load weights
+ weights = torch.load(weights_path, map_location='cpu')
+ weights = weights['model'] if 'model' in weights.keys() else weights
+ new_state_dict = OrderedDict()
+ for k, v in weights.items():
+ name = k.split('module.', 1)[1] if k.startswith('module.') else k # strip a leading 'module.' prefix only when it is present
+ new_state_dict[name] = v
+
+ model.load_state_dict(new_state_dict)
+ model.to(device)
+ model.eval()
+ return model
+
+def cache_reference_list(model, targetlist_path: str, grayscale=False):
+ '''
+ cache the embeddings of the reference list
+ :param model: siamese model used to compute the embeddings
+ :param targetlist_path: targetlist folder
+ :param grayscale: convert logo to grayscale or not, default is RGB
+ :return logo_feat_list: targetlist embeddings
+ :return file_name_list: targetlist paths
+ '''
+
+ # Prediction for targetlists
+ logo_feat_list = []
+ file_name_list = []
+
+ for target in tqdm(os.listdir(targetlist_path)):
+ if target.startswith('.'): # skip hidden files
+ continue
+ for logo_path in os.listdir(os.path.join(targetlist_path, target)):
+ if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith('.jpg') or logo_path.endswith('.PNG') \
+ or logo_path.endswith('.JPG') or logo_path.endswith('.JPEG'):
+ if logo_path.startswith('loginpage') or logo_path.startswith('homepage'): # skip homepage/loginpage
+ continue
+ logo_feat_list.append(get_embedding(img=os.path.join(targetlist_path, target, logo_path),
+ model=model, grayscale=grayscale))
+ file_name_list.append(str(os.path.join(targetlist_path, target, logo_path)))
+
+ return np.asarray(logo_feat_list), np.asarray(file_name_list)
+
+@torch.no_grad()
+def get_embedding(img, model, grayscale=False):
+ '''
+ Inference for a single image
+ :param img: image path in str or image in PIL.Image
+ :param model: model to make inference
+ :param grayscale: convert image to grayscale or not
+ :return feature embedding of shape (2048,)
+ '''
+ # img_size = 224
+ img_size = 128
+ mean = [0.5, 0.5, 0.5]
+ std = [0.5, 0.5, 0.5]
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ img_transforms = transforms.Compose(
+ [transforms.ToTensor(),
+ transforms.Normalize(mean=mean, std=std),
+ ])
+
+ img = Image.open(img) if isinstance(img, str) else img
+ img = img.convert("L").convert("RGB") if grayscale else img.convert("RGB")
+
+ ## Resize the image while keeping the original aspect ratio
+ pad_color = 255 if grayscale else (255, 255, 255)
+ img = ImageOps.expand(img, (
+ (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2,
+ (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2), fill=pad_color)
+
+ img = img.resize((img_size, img_size))
+
+ # Predict the embedding
+ with torch.no_grad():
+ img = img_transforms(img)
+ img = img[None, ...].to(device)
+ logo_feat = model.features(img)
+ logo_feat = l2_norm(logo_feat).squeeze(0).cpu().numpy() # L2-normalization final shape is (2048,)
+
+ return logo_feat
+
+
+def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str, gt_bbox, t_s, grayscale=False):
+ '''
+ Return predicted brand for one cropped image
+ :param model: model to use
+ :param domain_map: brand-domain dictionary
+ :param logo_feat_list: reference logo feature embeddings
+ :param file_name_list: reference logo paths
+ :param shot_path: path to the screenshot
+ :param gt_bbox: 1x4 np.ndarray/list/tensor bounding box coords
+ :param t_s: similarity threshold for siamese
+ :param grayscale: convert image(cropped) to grayscale or not
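+ :return: predicted target, predicted target's domain, and the matching similarity (target and domain are None if nothing matches; all three are None if the screenshot cannot be opened)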
+ '''
+
+ try:
+ img = Image.open(shot_path)
+ except OSError: # if the image cannot be identified, return nothing
+ print('Screenshot cannot be opened')
+ return None, None, None
+
+ ## get predicted box --> crop from screenshot
+ cropped = img.crop((gt_bbox[0], gt_bbox[1], gt_bbox[2], gt_bbox[3]))
+ img_feat = get_embedding(cropped, model, grayscale=grayscale)
+
+ ## get cosine similarity with every protected logo
+ sim_list = logo_feat_list @ img_feat.T # take dot product for every pair of embeddings (Cosine Similarity)
+ pred_brand_list = file_name_list
+
+ assert len(sim_list) == len(pred_brand_list)
+
+ ## get top 3 brands
+ idx = np.argsort(sim_list)[::-1][:3]
+ pred_brand_list = np.array(pred_brand_list)[idx]
+ sim_list = np.array(sim_list)[idx]
+
+ # top1,2,3 candidate logos
+ top3_logolist = [Image.open(x) for x in pred_brand_list]
+ top3_brandlist = [brand_converter(os.path.basename(os.path.dirname(x))) for x in pred_brand_list]
+ top3_domainlist = [domain_map[x] for x in top3_brandlist]
+ top3_simlist = sim_list
+
+ for j in range(3):
+ predicted_brand, predicted_domain = None, None
+
+ ## A lower-ranked logo is only considered if it belongs to the same brand as the top-1 logo; otherwise the match might be a false positive
+ if top3_brandlist[j] != top3_brandlist[0]:
+ continue
+
+ ## If the largest similarity exceeds threshold
+ if top3_simlist[j] >= t_s:
+ predicted_brand = top3_brandlist[j]
+ predicted_domain = top3_domainlist[j]
+ final_sim = top3_simlist[j]
+
+ ## Otherwise, try resolution alignment and check whether the similarity improves
+ else:
+ cropped, candidate_logo = resolution_alignment(cropped, top3_logolist[j])
+ img_feat = get_embedding(cropped, model, grayscale=grayscale)
+ logo_feat = get_embedding(candidate_logo, model, grayscale=grayscale)
+ final_sim = logo_feat.dot(img_feat)
+ if final_sim >= t_s:
+ predicted_brand = top3_brandlist[j]
+ predicted_domain = top3_domainlist[j]
+ else:
+ break # no hope, do not try other lower rank logos
+
+ ## If there is a prediction, do aspect ratio check
+ if predicted_brand is not None:
+ ratio_crop = cropped.size[0] / cropped.size[1]
+ ratio_logo = top3_logolist[j].size[0] / top3_logolist[j].size[1]
+ # aspect ratios of matched pair must not deviate by more than factor of 2.5
+ if max(ratio_crop, ratio_logo) / min(ratio_crop, ratio_logo) > 2.5:
+ continue # did not pass aspect ratio check, try other
+ # If pass aspect ratio check, report a match
+ else:
+ return predicted_brand, predicted_domain, final_sim
+
+ return None, None, top3_simlist[0]
\ No newline at end of file
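
The ranking and filtering logic in `pred_brand` above can be illustrated without the model. The sketch below is a pure-Python stand-in under stated assumptions: `l2_normalize`, `top_k_matches` and `passes_aspect_ratio` are hypothetical names, plain lists replace the NumPy arrays, and the embeddings are toy vectors rather than 2048-dimensional features. The factor of 2.5 is the one used in the aspect ratio check above.

```python
import math
from typing import List, Sequence, Tuple


def l2_normalize(vec: Sequence[float]) -> List[float]:
    """Scale a vector to unit length (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def top_k_matches(query: Sequence[float],
                  references: Sequence[Sequence[float]],
                  k: int = 3) -> List[Tuple[int, float]]:
    """Rank references by cosine similarity to the query.

    Both sides are L2-normalized first, so the dot product equals
    the cosine similarity.
    """
    q = l2_normalize(query)
    sims = [sum(a * b for a, b in zip(l2_normalize(ref), q)) for ref in references]
    order = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)
    return [(i, sims[i]) for i in order[:k]]


def passes_aspect_ratio(crop_size: Tuple[int, int],
                        logo_size: Tuple[int, int],
                        max_factor: float = 2.5) -> bool:
    """Accept a match only if the width/height ratios differ by at most max_factor."""
    ratio_crop = crop_size[0] / crop_size[1]
    ratio_logo = logo_size[0] / logo_size[1]
    return max(ratio_crop, ratio_logo) / min(ratio_crop, ratio_logo) <= max_factor


# Toy reference "embeddings" and a query that is closest to the first one
references = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
ranked = top_k_matches([2.0, 0.1], references, k=3)
```

A candidate would then be reported only if its similarity reaches the matching threshold and `passes_aspect_ratio` returns `True` for the cropped region and the reference logo.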
diff --git a/logo_recog.py b/logo_recog.py
new file mode 100644
index 0000000..10eb8a8
--- /dev/null
+++ b/logo_recog.py
@@ -0,0 +1,80 @@
+from detectron2.config import get_cfg
+from detectron2.engine import DefaultPredictor
+import cv2
+import numpy as np
+import torch
+
+def pred_rcnn(im, predictor):
+ '''
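+ Perform inference for RCNN
+ :param im: path to the screenshot image
+ :param predictor: detectron2 DefaultPredictor, e.g. as returned by config_rcnn
+ :return: logo boxes, logo scores, input boxes, input scores (all None if the image cannot be read)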
+ '''
+ im = cv2.imread(im)
+
+ if im is not None:
+ if im.shape[-1] == 4:
+ im = cv2.cvtColor(im, cv2.COLOR_BGRA2BGR)
+ else:
+ return None, None, None, None
+
+ outputs = predictor(im)
+
+ instances = outputs['instances']
+ pred_classes = instances.pred_classes # tensor
+ pred_boxes = instances.pred_boxes # Boxes object
+
+ logo_boxes = pred_boxes[pred_classes == 1].tensor
+ input_boxes = pred_boxes[pred_classes == 0].tensor
+
+ scores = instances.scores # tensor
+ logo_scores = scores[pred_classes == 1]
+ input_scores = scores[pred_classes == 0]
+
+ return logo_boxes, logo_scores, input_boxes, input_scores
+
+
+def config_rcnn(cfg_path, weights_path, conf_threshold):
+ '''
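+ Configure weights and confidence threshold
+ :param cfg_path: path to the detectron2 config file
+ :param weights_path: path to the model weights
+ :param conf_threshold: score threshold applied at test time
+ :return: detectron2 DefaultPredictor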
+ '''
+ cfg = get_cfg()
+ cfg.merge_from_file(cfg_path)
+ cfg.MODEL.WEIGHTS = weights_path
+ cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = conf_threshold
+ # fall back to CPU when CUDA is not available
+ if not torch.cuda.is_available():
+ cfg.MODEL.DEVICE = 'cpu'
+
+ # Initialize model
+ predictor = DefaultPredictor(cfg)
+ return predictor
+
+
+def vis(img_path, pred_boxes):
+ '''
+ Visualize rcnn predictions
+ :param img_path: str
+ :param pred_boxes: torch.Tensor or np.ndarray of shape Nx4, bounding box coordinates in (x1, y1, x2, y2)
+ :return: the image (as read by cv2) with the predicted boxes drawn on it
+ '''
+
+ check = cv2.imread(img_path)
+ if pred_boxes is None or len(pred_boxes) == 0:
+ return check
+ pred_boxes = pred_boxes.numpy() if not isinstance(pred_boxes, np.ndarray) else pred_boxes
+
+ # draw rectangle
+ for j, box in enumerate(pred_boxes):
+ if j == 0:
+ cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255, 255, 0), 2)
+ else:
+ cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (36, 255, 12), 2)
+
+ return check
\ No newline at end of file
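
`pred_rcnn` above separates detections by predicted class, with class 1 for logos and class 0 for inputs. The following sketch shows the same split on plain Python lists; `split_detections` is a hypothetical helper, and tuples stand in for the detectron2 `Boxes` and tensor objects.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)

LOGO_CLASS = 1   # class ids as used in pred_rcnn above
INPUT_CLASS = 0


def split_detections(boxes: Sequence[Box],
                     classes: Sequence[int],
                     scores: Sequence[float]
                     ) -> Tuple[List[Tuple[Box, float]], List[Tuple[Box, float]]]:
    """Group (box, score) pairs into logo detections and input detections."""
    logos = [(b, s) for b, c, s in zip(boxes, classes, scores) if c == LOGO_CLASS]
    inputs = [(b, s) for b, c, s in zip(boxes, classes, scores) if c == INPUT_CLASS]
    return logos, inputs
```

Only the logo boxes feed the matching step. As the README notes, input boxes are kept for explanation rather than prediction.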
diff --git a/models.py b/models.py
new file mode 100644
index 0000000..34d3992
--- /dev/null
+++ b/models.py
@@ -0,0 +1,195 @@
+# Copyright 2020 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Bottleneck ResNet v2 with GroupNorm and Weight Standardization."""
+
+from collections import OrderedDict # pylint: disable=g-importing-member
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class StdConv2d(nn.Conv2d):
+
+ def forward(self, x):
+ w = self.weight
+ v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
+ w = (w - m) / torch.sqrt(v + 1e-10)
+ return F.conv2d(x, w, self.bias, self.stride, self.padding,
+ self.dilation, self.groups)
+
+
+def conv3x3(cin, cout, stride=1, groups=1, bias=False):
+ return StdConv2d(cin, cout, kernel_size=3, stride=stride,
+ padding=1, bias=bias, groups=groups)
+
+
+def conv1x1(cin, cout, stride=1, bias=False):
+ return StdConv2d(cin, cout, kernel_size=1, stride=stride,
+ padding=0, bias=bias)
+
+
+def tf2th(conv_weights):
+ """Possibly convert HWIO to OIHW."""
+ if conv_weights.ndim == 4:
+ conv_weights = conv_weights.transpose([3, 2, 0, 1])
+ return torch.from_numpy(conv_weights)
+
+
+class PreActBottleneck(nn.Module):
+ """Pre-activation (v2) bottleneck block.
+
+ Follows the implementation of "Identity Mappings in Deep Residual Networks":
+ https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua
+
+ Except it puts the stride on 3x3 conv when available.
+ """
+
+ def __init__(self, cin, cout=None, cmid=None, stride=1):
+ super().__init__()
+ cout = cout or cin
+ cmid = cmid or cout//4
+
+ self.gn1 = nn.GroupNorm(32, cin)
+ self.conv1 = conv1x1(cin, cmid)
+ self.gn2 = nn.GroupNorm(32, cmid)
+ self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!!
+ self.gn3 = nn.GroupNorm(32, cmid)
+ self.conv3 = conv1x1(cmid, cout)
+ self.relu = nn.ReLU(inplace=True)
+
+ if (stride != 1 or cin != cout):
+ # Projection also with pre-activation according to paper.
+ self.downsample = conv1x1(cin, cout, stride)
+
+ def forward(self, x):
+ out = self.relu(self.gn1(x))
+
+ # Residual branch
+ residual = x
+ if hasattr(self, 'downsample'):
+ residual = self.downsample(out)
+
+ # Unit's branch
+ out = self.conv1(out)
+ out = self.conv2(self.relu(self.gn2(out)))
+ out = self.conv3(self.relu(self.gn3(out)))
+
+ return out + residual
+
+ def load_from(self, weights, prefix=''):
+ convname = 'standardized_conv2d'
+ with torch.no_grad():
+ self.conv1.weight.copy_(tf2th(weights[f'{prefix}a/{convname}/kernel']))
+ self.conv2.weight.copy_(tf2th(weights[f'{prefix}b/{convname}/kernel']))
+ self.conv3.weight.copy_(tf2th(weights[f'{prefix}c/{convname}/kernel']))
+ self.gn1.weight.copy_(tf2th(weights[f'{prefix}a/group_norm/gamma']))
+ self.gn2.weight.copy_(tf2th(weights[f'{prefix}b/group_norm/gamma']))
+ self.gn3.weight.copy_(tf2th(weights[f'{prefix}c/group_norm/gamma']))
+ self.gn1.bias.copy_(tf2th(weights[f'{prefix}a/group_norm/beta']))
+ self.gn2.bias.copy_(tf2th(weights[f'{prefix}b/group_norm/beta']))
+ self.gn3.bias.copy_(tf2th(weights[f'{prefix}c/group_norm/beta']))
+ if hasattr(self, 'downsample'):
+ w = weights[f'{prefix}a/proj/{convname}/kernel']
+ self.downsample.weight.copy_(tf2th(w))
+
+
+class ResNetV2(nn.Module):
+ """Implementation of Pre-activation (v2) ResNet mode."""
+
+ def __init__(self, block_units, width_factor, head_size=21843, zero_head=False):
+ super().__init__()
+ wf = width_factor # shortcut 'cause we'll use it a lot.
+
+ # The following will be unreadable if we split lines.
+ # pylint: disable=line-too-long
+ self.root = nn.Sequential(OrderedDict([
+ ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)),
+ ('pad', nn.ConstantPad2d(1, 0)),
+ ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
+ # The following is subtly not the same!
+ # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
+ ]))
+
+ self.body = nn.Sequential(OrderedDict([
+ ('block1', nn.Sequential(OrderedDict(
+ [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] +
+ [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)],
+ ))),
+ ('block2', nn.Sequential(OrderedDict(
+ [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] +
+ [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)],
+ ))),
+ ('block3', nn.Sequential(OrderedDict(
+ [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] +
+ [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)],
+ ))),
+ ('block4', nn.Sequential(OrderedDict(
+ [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] +
+ [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)],
+ ))),
+ ]))
+ # pylint: enable=line-too-long
+
+ self.zero_head = zero_head
+ self.head = nn.Sequential(OrderedDict([
+ ('gn', nn.GroupNorm(32, 2048*wf)),
+ ('relu', nn.ReLU(inplace=True)),
+ ('avg', nn.AdaptiveAvgPool2d(output_size=1)),
+ ('conv', nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)),
+ ]))
+
+ def features(self, x):
+ x = self.head[:-1](self.body(self.root(x)))
+
+ return x.squeeze(-1).squeeze(-1)
+
+ def forward(self, x):
+ x = self.head(self.body(self.root(x)))
+ assert x.shape[-2:] == (1, 1) # We should have no spatial shape left.
+ return x[...,0,0]
+
+ def load_from(self, weights, prefix='resnet/'):
+ with torch.no_grad():
+ self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) # pylint: disable=line-too-long
+ self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
+ self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
+ if self.zero_head:
+ nn.init.zeros_(self.head.conv.weight)
+ nn.init.zeros_(self.head.conv.bias)
+ else:
+ self.head.conv.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel'])) # pylint: disable=line-too-long
+ self.head.conv.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
+
+ for bname, block in self.body.named_children():
+ for uname, unit in block.named_children():
+ unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/')
+
+
+KNOWN_MODELS = OrderedDict([
+ ('BiT-M-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
+ ('BiT-M-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
+ ('BiT-M-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
+ ('BiT-M-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
+ ('BiT-M-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
+ ('BiT-M-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),
+ ('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
+ ('BiT-S-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
+ ('BiT-S-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
+ ('BiT-S-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
+ ('BiT-S-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
+ ('BiT-S-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),
+])
\ No newline at end of file
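
The weight standardization performed in `StdConv2d.forward` above shifts each output filter to zero mean and scales it to unit variance, using the biased variance and an epsilon of 1e-10. A minimal pure-Python sketch for one flattened filter follows; `standardize_filter` is a hypothetical name and no PyTorch is involved.

```python
import math
from typing import List, Sequence


def standardize_filter(weights: Sequence[float], eps: float = 1e-10) -> List[float]:
    """Standardize one flattened convolution filter.

    Computes (w - mean) / sqrt(var + eps), where var is the biased
    (population) variance, matching unbiased=False in torch.var_mean.
    """
    n = len(weights)
    mean = sum(weights) / n
    var = sum((w - mean) ** 2 for w in weights) / n
    return [(w - mean) / math.sqrt(var + eps) for w in weights]


standardized = standardize_filter([1.0, 2.0, 3.0, 4.0])
```

In the module itself the same statistics are taken over dimensions `[1, 2, 3]`, that is, separately for every output channel.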
diff --git a/parse_results.py b/parse_results.py
new file mode 100644
index 0000000..45d197f
--- /dev/null
+++ b/parse_results.py
@@ -0,0 +1,40 @@
+import os.path
+import shutil
+from datetime import datetime
+
+def get_pos_site(result_txt):
+
+ df = [x.strip().split('\t') for x in open(result_txt, encoding='ISO-8859-1').readlines()]
+ df_pos = [x for x in df if (x[2] == '1')]
+ df_pos = [x for x in df_pos if x[3] not in ['Google', 'Webmail Provider',
+ 'WhatsApp', 'Luno']]
+ return df_pos
+
+
+
+if __name__ == '__main__':
+ # today = datetime.now().strftime('%Y%m%d')
+ today = '20231223'
+ result_txt = f'/home/ruofan/git_space/PhishEmail/datasets/{today}_results.txt'
+ df_pos = get_pos_site(result_txt)
+ print(len(df_pos))
+
+ pos_result_txt = f'/home/ruofan/git_space/PhishEmail/datasets/{today}_pos.txt'
+ pos_result_dir = f'/home/ruofan/git_space/PhishEmail/datasets/sjtu_phish_pos/{today}'
+ os.makedirs(pos_result_dir, exist_ok=True)
+
+ for x in df_pos:
+ url = x[1]
+ if os.path.exists(pos_result_txt) and url in open(pos_result_txt).read():
+ pass
+ else:
+ with open(pos_result_txt, 'a+') as f:
+ f.write(url+'\n')
+
+ # try:
+ shutil.copytree(os.path.join(f'/home/ruofan/git_space/PhishEmail/datasets/sjtu_phish/{today}', x[0]),
+ os.path.join(pos_result_dir, x[0]))
+ # except FileExistsError:
+ # pass
+
+ print(df_pos)
\ No newline at end of file
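
The filtering in `get_pos_site` above can be tried on in-memory lines without touching the hard-coded paths. This is a sketch under the assumption, inferred from the indexing in the script, that each tab-separated row holds the site folder, the URL, the phishing flag and the predicted brand in that order; `filter_positive` is a hypothetical name.

```python
from typing import Iterable, List, Set

# Brands excluded from the positive list in get_pos_site above
EXCLUDED_BRANDS: Set[str] = {"Google", "Webmail Provider", "WhatsApp", "Luno"}


def filter_positive(lines: Iterable[str],
                    excluded: Set[str] = EXCLUDED_BRANDS) -> List[List[str]]:
    """Keep rows flagged as phishing ('1') whose brand is not excluded.

    Rows with fewer than four columns are dropped instead of raising IndexError.
    """
    rows = [line.strip().split("\t") for line in lines if line.strip()]
    return [r for r in rows if len(r) > 3 and r[2] == "1" and r[3] not in excluded]


sample = [
    "site_a\thttps://a.example\t1\tPayPal\n",
    "site_b\thttps://b.example\t0\tNone\n",
    "site_c\thttps://c.example\t1\tGoogle\n",
]
positives = filter_positive(sample)
```

Guarding on the row length is a deviation from the original list comprehension, which assumes every line is complete.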
diff --git a/phishpedia.py b/phishpedia.py
new file mode 100644
index 0000000..99ed0a2
--- /dev/null
+++ b/phishpedia.py
@@ -0,0 +1,184 @@
+
+import time
+import sys
+from datetime import datetime
+import argparse
+import os
+import torch
+from lib.phishpedia.configs import load_config
+from tldextract import tldextract
+import cv2
+from logo_recog import pred_rcnn, vis
+from logo_matching import check_domain_brand_inconsistency
+from text_recog import check_email_credential_taking
+import pickle
+from tqdm import tqdm
+from paddleocr import PaddleOCR
+import re
+
+os.environ['KMP_DUPLICATE_LIB_OK']='True'
+
+class PhishpediaWrapper:
+ _caller_prefix = "PhishpediaWrapper"
+ _DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ def __init__(self):
+ self._load_config()
+ self._to_device()
+
+ def _load_config(self):
+ self.ELE_MODEL, self.SIAMESE_THRE, self.SIAMESE_MODEL, \
+ self.LOGO_FEATS, self.LOGO_FILES, \
+ self.DOMAIN_MAP_PATH = load_config()
+
+ self.OCR_MODEL = PaddleOCR(use_angle_cls=True,
+ lang='ch',
+ show_log=False,
+ use_gpu=torch.cuda.is_available()) # need to run only once to download and load model into memory
+ print(f'Length of reference list = {len(self.LOGO_FEATS)}')
+
+ def _to_device(self):
+ self.SIAMESE_MODEL.to(self._DEVICE)
+
+
+ '''Phishpedia'''
+ def test_orig_phishpedia(self, url, screenshot_path):
+ # 0 for benign, 1 for phish, default is benign
+ phish_category = 0
+ pred_target = None
+ matched_domain = None
+ siamese_conf = None
+ logo_recog_time = 0
+ logo_match_time = 0
+ print("Entering phishpedia")
+
+ ####################### Step1: Logo detector ##############################################
+ start_time = time.time()
+ pred_boxes, _, _, _ = pred_rcnn(im=screenshot_path, predictor=self.ELE_MODEL)
+ logo_recog_time = time.time() - start_time
+
+ if pred_boxes is not None:
+ pred_boxes = pred_boxes.detach().cpu().numpy()
+ plotvis = vis(screenshot_path, pred_boxes)
+
+ # If no element is reported
+ if pred_boxes is None or len(pred_boxes) == 0:
+ print('No logo is detected')
+ return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time
+
+ print('Entering siamese')
+
+ ######################## Step2: Siamese (Logo matcher) ########################################
+ start_time = time.time()
+ pred_target, matched_domain, matched_coord, siamese_conf = check_domain_brand_inconsistency(logo_boxes=pred_boxes,
+ domain_map_path=self.DOMAIN_MAP_PATH,
+ model=self.SIAMESE_MODEL,
+ logo_feat_list=self.LOGO_FEATS,
+ file_name_list=self.LOGO_FILES,
+ url=url,
+ shot_path=screenshot_path,
+ ts=self.SIAMESE_THRE)
+ logo_match_time = time.time() - start_time
+
+ if pred_target is None:
+ ### ask for email
+ ask_for_emails, matched_email = check_email_credential_taking(self.OCR_MODEL, screenshot_path)
+ if ask_for_emails and (tldextract.extract(url).domain not in matched_email.replace('@', '')): # URL domain and requested email domain are inconsistent
+ matched_target = matched_email.replace('@', '')
+ cv2.putText(plotvis, "No logo but asks for email credentials",
+ (20, 20),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
+ print('No logo but asks for specific email credentials')
+ return 1, matched_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time
+
+ print('Did not match to any brand, report as benign')
+ return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time
+
+ else:
+ phish_category = 1
+ # Visualize, add annotations
+ cv2.putText(plotvis, "Target: {} with confidence {:.4f}".format(pred_target, siamese_conf),
+ (int(matched_coord[0] + 20), int(matched_coord[1] + 20)),
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2)
+
+
+ return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time
+
+
+if __name__ == '__main__':
+    import ast  # parses info.json as a Python literal instead of calling eval()
+
+    '''update domain map'''
+    # with open('./lib/phishpedia/models/domain_map.pkl', "rb") as handle:
+    #     domain_map = pickle.load(handle)
+    #
+    # domain_map['weibo'] = ['sina', 'weibo']
+    #
+    # with open('./lib/phishpedia/models/domain_map.pkl', "wb") as handle:
+    #     pickle.dump(domain_map, handle)
+    # exit()
+
+    '''run'''
+    today = datetime.now().strftime('%Y%m%d')
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--folder", required=True, type=str)
+    parser.add_argument("--output_txt", default=f'{today}_results.txt', help="Output txt path")
+    args = parser.parse_args()
+
+    request_dir = args.folder
+    phishpedia_cls = PhishpediaWrapper()
+    result_txt = args.output_txt
+
+    os.makedirs(request_dir, exist_ok=True)
+
+    for folder in tqdm(os.listdir(request_dir)):
+        html_path = os.path.join(request_dir, folder, "html.txt")
+        screenshot_path = os.path.join(request_dir, folder, "shot.png")
+        info_path = os.path.join(request_dir, folder, 'info.json')
+
+        if not os.path.exists(screenshot_path) or not os.path.exists(info_path):
+            continue
+
+        with open(info_path) as handle:
+            url = ast.literal_eval(handle.read())['url']
+
+        if os.path.exists(result_txt) and url in open(result_txt, encoding='ISO-8859-1').read():
+            continue
+
+        _forbidden_suffixes = r"\.(mp3|wav|wma|ogg|mkv|zip|tar|xz|rar|z|deb|bin|iso|csv|tsv|dat|txt|css|log|sql|xml|sql|mdb|apk|bat|bin|exe|jar|wsf|fnt|fon|otf|ttf|ai|bmp|gif|ico|jp(e)?g|png|ps|psd|svg|tif|tiff|cer|rss|key|odp|pps|ppt|pptx|c|class|cpp|cs|h|java|sh|swift|vb|odf|xlr|xls|xlsx|bak|cab|cfg|cpl|cur|dll|dmp|drv|icns|ini|lnk|msi|sys|tmp|3g2|3gp|avi|flv|h264|m4v|mov|mp4|mp(e)?g|rm|swf|vob|wmv|doc(x)?|odt|rtf|tex|txt|wks|wps|wpd)$"
+        if re.search(_forbidden_suffixes, url, re.IGNORECASE):
+            continue
+
+        phish_category, pred_target, matched_domain, \
+            plotvis, siamese_conf, pred_boxes, \
+            logo_recog_time, logo_match_time = phishpedia_cls.test_orig_phishpedia(url, screenshot_path)
+
+        # Build the record once so that a failed encode cannot leave a partial row behind
+        record = "\t".join([
+            folder,
+            url,
+            str(phish_category),
+            str(pred_target),  # write top1 prediction only
+            str(matched_domain),
+            str(siamese_conf),
+            str(round(logo_recog_time, 4)),
+            str(round(logo_match_time, 4)),
+        ]) + "\n"
+
+        try:
+            with open(result_txt, "a+", encoding='ISO-8859-1') as f:
+                f.write(record)
+        except UnicodeError:
+            # Fall back to UTF-8 when the record contains characters outside ISO-8859-1
+            with open(result_txt, "a+", encoding='utf-8') as f:
+                f.write(record)
+
+        if phish_category:
+            os.makedirs(os.path.join(request_dir, folder), exist_ok=True)
+            cv2.imwrite(os.path.join(request_dir, folder, "predict.png"), plotvis)
+
+
+    # import matplotlib.pyplot as plt
+    # plt.imshow(cropped)
+    # plt.savefig('./debug.png')
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index d9982f8..12d5922 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,22 +1,14 @@
-torchsummary
scipy
tldextract
opencv-python
pandas
numpy
tqdm
-Pillow
+Pillow==8.4.0
pathlib
fvcore
pycocotools
scikit-learn
-advertorch
-python-telegram-bot
-gspread
-oauth2client
lxml
-
-
-
-
-
+gdown
+paddleocr>=2.0.1
\ No newline at end of file
diff --git a/setup.sh b/setup.sh
old mode 100755
new mode 100644
index 6269f44..dd0000f
--- a/setup.sh
+++ b/setup.sh
@@ -1,21 +1,39 @@
#!/bin/bash
+retry_count=3 # Number of retries
+
+download_with_retry() {
+ local file_id=$1
+ local file_name=$2
+ local count=0
+
+ until [ $count -ge $retry_count ]
+ do
+ gdown --id "$file_id" -O "$file_name" && break # attempt to download and break if successful
+ count=$((count+1))
+    echo "Attempt $count of $retry_count failed."
+    sleep 1 # wait 1 second before retrying
+ done
+
+ if [ $count -ge $retry_count ]; then
+ echo "Failed to download $file_name after $retry_count attempts."
+ fi
+}
FILEDIR=$(pwd)
CONDA_BASE=$(conda info --base)
source "$CONDA_BASE/etc/profile.d/conda.sh"
-conda info --envs | grep -w "myenv" > /dev/null
+conda info --envs | grep -w "phishpedia" > /dev/null
if [ $? -eq 0 ]; then
- echo "Activating Conda environment myenv"
- conda activate myenv
+ echo "Activating Conda environment phishpedia"
+ conda activate phishpedia
else
- echo "Creating and activating new Conda environment $ENV_NAME with Python 3.8"
- conda create -n myenv python=3.8
- conda activate myenv
+ echo "Creating and activating new Conda environment phishpedia with Python 3.8"
+ conda create -n phishpedia python=3.8
+ conda activate phishpedia
fi
-pip install -r requirements.txt
OS=$(uname -s)
@@ -23,39 +41,66 @@ if [[ "$OS" == "Darwin" ]]; then
echo "Installing PyTorch and torchvision for macOS."
pip install torch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0
python -m pip install detectron2 -f "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.9/index.html"
+ python -m pip install paddlepaddle==2.5.1 -i https://mirror.baidu.com/pypi/simple
else
# Check if NVIDIA GPU is available for Linux and Windows
- if command -v nvcc &> /dev/null; then
+  if command -v nvcc &> /dev/null || command -v nvidia-smi &> /dev/null; then
echo "CUDA is detected, installing GPU-supported PyTorch and torchvision."
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f "https://download.pytorch.org/whl/torch_stable.html"
python -m pip install detectron2 -f "https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html"
+ python -m pip install paddlepaddle-gpu==2.5.1 -i https://mirror.baidu.com/pypi/simple
else
echo "No CUDA detected, installing CPU-only PyTorch and torchvision."
pip install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f "https://download.pytorch.org/whl/torch_stable.html"
python -m pip install detectron2 -f "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.9/index.html"
+ python -m pip install paddlepaddle==2.5.1 -i https://mirror.baidu.com/pypi/simple
fi
fi
+pip install -r requirements.txt
+
## Download models
-pip install -v .
-package_location=$(pip show phishpedia | grep Location | awk '{print $2}')
+echo "Downloading model files into the models/ directory."
+mkdir -p models/
+cd models/ || exit 1
+
-if [ -z "$package_location" ]; then
- echo "Package Phishpedia not found in the Conda environment myenv."
- exit 1
+# RCNN model weights
+if [ -f "rcnn_bet365.pth" ]; then
+  echo "RCNN model weights exist... skip"
else
- echo "Going to the directory of package Phishpedia in Conda environment myenv."
- cd "$package_location/phishpedia/src/detectron2_pedia/output/rcnn_2" || exit
- pip install gdown
- gdown --id 1tE2Mu5WC8uqCxei3XqAd7AWaP5JTmVWH
- cd "$package_location/phishpedia/src/siamese_pedia/" || exit
- gdown --id 1H0Q_DbdKPLFcZee8I14K62qV7TTy7xvS
- gdown --id 1fr5ZxBKyDiNZ_1B6rRAfZbAHBBoUjZ7I
- gdown --id 1qSdkSSoCYUkZMKs44Rup_1DPBxHnEKl1
+ download_with_retry 1tE2Mu5WC8uqCxei3XqAd7AWaP5JTmVWH rcnn_bet365.pth
fi
-# Replace the placeholder in the YAML template
+# Faster RCNN config
+if [ -f "faster_rcnn.yaml" ]; then
+ echo "RCNN model config exists... skip"
+else
+ download_with_retry 1Q6lqjpl4exW7q_dPbComcj0udBMDl8CW faster_rcnn.yaml
+fi
-sed "s|CONDA_ENV_PATH_PLACEHOLDER|$package_location/phishpedia|g" "$FILEDIR/phishpedia/configs_template.yaml" > "$FILEDIR/phishpedia/configs.yaml"
+# Siamese model weights
+if [ -f "resnetv2_rgb_new.pth.tar" ]; then
+  echo "Siamese model weights exist... skip"
+else
+ download_with_retry 1H0Q_DbdKPLFcZee8I14K62qV7TTy7xvS resnetv2_rgb_new.pth.tar
+fi
+
+# Reference list
+if [ -f "expand_targetlist.zip" ]; then
+ echo "Reference list exists... skip"
+else
+ download_with_retry 1fr5ZxBKyDiNZ_1B6rRAfZbAHBBoUjZ7I expand_targetlist.zip
+fi
-echo "All packages installed successfully!"
+# Domain map
+if [ -f "domain_map.pkl" ]; then
+ echo "Domain map exists... skip"
+else
+ download_with_retry 1qSdkSSoCYUkZMKs44Rup_1DPBxHnEKl1 domain_map.pkl
+fi
+
+
+
+# Report completion (failed downloads above are only reported, they do not abort the script)
+echo "All packages installed successfully!"
\ No newline at end of file
diff --git a/text_recog.py b/text_recog.py
new file mode 100644
index 0000000..eb610ba
--- /dev/null
+++ b/text_recog.py
@@ -0,0 +1,31 @@
+import re
+
+def pred_text_in_image(OCR_MODEL, shot_path):
+    '''Run OCR on the screenshot and return all recognized text joined by spaces.'''
+    result = OCR_MODEL.ocr(shot_path, cls=True)
+    if not result or result[0] is None:
+        return ''
+
+    most_fit_results = result[0]
+    ocr_text = [line[1][0] for line in most_fit_results]
+    detected_text = ' '.join(ocr_text)
+
+    return detected_text
+
+def check_email_credential_taking(OCR_MODEL, shot_path):
+    detected_text = pred_text_in_image(OCR_MODEL, shot_path)
+    if len(detected_text) > 0:
+        return rule_matching(detected_text)
+    return False, None
+
+def rule_matching(detected_text):
+    email_login_pattern = r'邮箱.*登录|邮箱.*登陆|邮件.*登录|邮件.*登陆'  # Chinese for "mailbox ... log in" / "mail ... log in"
+    specified_email_pattern = r'@[\w.-]+\.\w+'
+
+    if re.findall(email_login_pattern, detected_text):
+        find_email = re.findall(specified_email_pattern, detected_text)
+        if find_email:
+            return True, find_email[0]
+
+    return False, None
+
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..bfc41d8
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,134 @@
+import torch.nn.functional as F
+from PIL import Image
+import math
+
+def resolution_alignment(img1, img2):
+    '''
+    Resize two images according to the minimum resolution between the two
+    :param img1: first image in PIL.Image
+    :param img2: second image in PIL.Image
+    :return: resized img1 in PIL.Image, resized img2 in PIL.Image
+    '''
+    w1, h1 = img1.size
+    w2, h2 = img2.size
+    w_min, h_min = min(w1, w2), min(h1, h2)
+    if w_min == 0 or h_min == 0:  ## something wrong, stop resizing
+        return img1, img2
+    if w_min < h_min:
+        img1_resize = img1.resize((int(w_min), math.ceil(h1 * (w_min/w1))))  # ceiling to prevent rounding to 0
+        img2_resize = img2.resize((int(w_min), math.ceil(h2 * (w_min/w2))))
+    else:
+        img1_resize = img1.resize((math.ceil(w1 * (h_min/h1)), int(h_min)))
+        img2_resize = img2.resize((math.ceil(w2 * (h_min/h2)), int(h_min)))
+    return img1_resize, img2_resize
+
+def brand_converter(brand_name):
+    '''
+    Helper function to deal with inconsistency in brand naming
+    '''
+    if brand_name == 'Adobe Inc.' or brand_name == 'Adobe Inc':
+        return 'Adobe'
+    elif brand_name == 'ADP, LLC' or brand_name == 'ADP, LLC.':
+        return 'ADP'
+    elif brand_name == 'Amazon.com Inc.' or brand_name == 'Amazon.com Inc':
+        return 'Amazon'
+    elif brand_name == 'Americanas.com S,A Comercio Electrnico':
+        return 'Americanas.com S'
+    elif brand_name == 'AOL Inc.' or brand_name == 'AOL Inc':
+        return 'AOL'
+    elif brand_name == 'Apple Inc.' or brand_name == 'Apple Inc':
+        return 'Apple'
+    elif brand_name == 'AT&T Inc.' or brand_name == 'AT&T Inc':
+        return 'AT&T'
+    elif brand_name == 'Banco do Brasil S.A.':
+        return 'Banco do Brasil S.A'
+    elif brand_name == 'Credit Agricole S.A.':
+        return 'Credit Agricole S.A'
+    elif brand_name == 'DGI (French Tax Authority)':
+        return 'DGI French Tax Authority'
+    elif brand_name == 'DHL Airways, Inc.' or brand_name == 'DHL Airways, Inc' or brand_name == 'DHL':
+        return 'DHL Airways'
+    elif brand_name == 'Dropbox, Inc.' or brand_name == 'Dropbox, Inc':
+        return 'Dropbox'
+    elif brand_name == 'eBay Inc.' or brand_name == 'eBay Inc':
+        return 'eBay'
+    elif brand_name == 'Facebook, Inc.' or brand_name == 'Facebook, Inc':
+        return 'Facebook'
+    elif brand_name == 'Free (ISP)':
+        return 'Free ISP'
+    elif brand_name == 'Google Inc.' or brand_name == 'Google Inc':
+        return 'Google'
+    elif brand_name == 'Mastercard International Incorporated':
+        return 'Mastercard International'
+    elif brand_name == 'Netflix Inc.' or brand_name == 'Netflix Inc':
+        return 'Netflix'
+    elif brand_name == 'PayPal Inc.' or brand_name == 'PayPal Inc':
+        return 'PayPal'
+    elif brand_name == 'Royal KPN N.V.':
+        return 'Royal KPN N.V'
+    elif brand_name == 'SF Express Co.':
+        return 'SF Express Co'
+    elif brand_name == 'SNS Bank N.V.':
+        return 'SNS Bank N.V'
+    elif brand_name == 'Square, Inc.' or brand_name == 'Square, Inc':
+        return 'Square'
+    elif brand_name == 'Webmail Providers':
+        return 'Webmail Provider'
+    elif brand_name == 'Yahoo! Inc' or brand_name == 'Yahoo! Inc.':
+        return 'Yahoo!'
+    elif brand_name == 'Microsoft OneDrive' or brand_name == 'Office365' or brand_name == 'Outlook':
+        return 'Microsoft'
+    elif brand_name == 'Global Sources (HK)':
+        return 'Global Sources HK'
+    elif brand_name == 'T-Online':
+        return 'Deutsche Telekom'
+    elif brand_name == 'Airbnb, Inc':
+        return 'Airbnb, Inc.'
+    elif brand_name == 'azul':
+        return 'Azul'
+    elif brand_name == 'Raiffeisen Bank S.A':
+        return 'Raiffeisen Bank S.A.'
+    elif brand_name == 'Twitter, Inc' or brand_name == 'Twitter':
+        return 'Twitter, Inc.'
+    elif brand_name == 'capital_one':
+        return 'Capital One Financial Corporation'
+    elif brand_name == 'la_banque_postale':
+        return 'La Banque postale'
+    elif brand_name == 'db':
+        return 'Deutsche Bank AG'
+    elif brand_name == 'Swiss Post' or brand_name == 'PostFinance':
+        return 'PostFinance'
+    elif brand_name == 'grupo_bancolombia':
+        return 'Bancolombia'
+    elif brand_name == 'barclays':
+        return 'Barclays Bank Plc'
+    elif brand_name == 'gov_uk':
+        return 'Government of the United Kingdom'
+    elif brand_name == 'Aruba S.p.A':
+        return 'Aruba S.p.A.'
+    elif brand_name == 'TSB Bank Plc':
+        return 'TSB Bank Limited'
+    elif brand_name == 'strato':
+        return 'Strato AG'
+    elif brand_name == 'cogeco':
+        return 'Cogeco'
+    elif brand_name == 'Canada Revenue Agency':
+        return 'Government of Canada'
+    elif brand_name == 'UniCredit Bulbank':
+        return 'UniCredit Bank Aktiengesellschaft'
+    elif brand_name == 'ameli_fr':
+        return 'French Health Insurance'
+    elif brand_name == 'Banco de Credito del Peru':
+        return 'bcp'
+    else:
+        return brand_name
+
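+def l2_norm(x):
+    """
+    l2 normalization
+    :param x: tensor whose first dimension is the batch dimension
+    :return: tensor flattened to (batch, -1) and L2-normalized along dim 1
+    """
+    if len(x.shape):
+        x = x.reshape((x.shape[0], -1))
+    return F.normalize(x, p=2, dim=1)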
\ No newline at end of file