diff --git a/README.md b/README.md index 857d249..ed38594 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ ## Framework - + ```Input```: A URL and its screenshot ```Output```: Phish/Benign, Phishing target - Step 1: Enter Deep Object Detection Model, get predicted logos and inputs (inputs are not used for later prediction, just for explanation) @@ -40,88 +40,54 @@ ## Project structure ``` -- src - - adv_attack: adversarial attacking scripts - - detectron2_pedia: training script for object detector - |_ output - |_ rcnn_2 - |_ rcnn_bet365.pth - - siamese_pedia: inference script for siamese - |_ siamese_retrain: training script for siamese - |_ expand_targetlist - |_ 1&1 Ionos - |_ ... - |_ domain_map.pkl - |_ resnetv2_rgb_new.pth.tar - - siamese.py: main script for siamese - - pipeline_eval.py: evaluation script for general experiment - -- tele: telegram scripts to vote for phishing -- phishpedia_config.py: config script for phish-discovery experiment -- phishpedia_main.py: main script for phish-discovery experiment +- logo_recog.py: Deep Object Detection Model +- logo_matching.py: Deep Siamese Model +- configs.yaml: Configuration file +- phishpedia.py: Main script ``` ## Instructions Requirements: -- CUDA 11 - Anaconda installed, please refer to the official installation guide: https://docs.anaconda.com/free/anaconda/install/index.html 1. Create a local clone of Phishpedia -``` +```bash git clone https://github.com/lindsey98/Phishpedia.git ``` 2. Setup -``` -cd Phishpedia/ +```bash chmod +x ./setup.sh ./setup.sh ``` -If you encounter any problem in downloading the models, you can manually download them from here https://huggingface.co/Kelsey98/Phishpedia. And put them into the corresponding conda environment. 3. ``` -conda activate myenv +conda activate phishpedia ``` -Run in Python to test a single website -```python -from phishpedia.phishpedia_main import test -import matplotlib.pyplot as plt -from phishpedia.phishpedia_config import load_config - -url = open("phishpedia/datasets/test_sites/accounts.g.cdcde.com/info.txt").read().strip() -screenshot_path = "phishpedia/datasets/test_sites/accounts.g.cdcde.com/shot.png" -ELE_MODEL, SIAMESE_THRE, SIAMESE_MODEL, LOGO_FEATS, LOGO_FILES, DOMAIN_MAP_PATH = load_config(None) - -phish_category, pred_target, plotvis, siamese_conf, pred_boxes = test(url=url, screenshot_path=screenshot_path, - ELE_MODEL=ELE_MODEL, - SIAMESE_THRE=SIAMESE_THRE, - SIAMESE_MODEL=SIAMESE_MODEL, - LOGO_FEATS=LOGO_FEATS, - LOGO_FILES=LOGO_FILES, - DOMAIN_MAP_PATH=DOMAIN_MAP_PATH - ) - -print('Phishing (1) or Benign (0) ?', phish_category) -print('What is its targeted brand if it is a phishing ?', pred_target) -print('What is the siamese matching confidence ?', siamese_conf) -print('Where is the predicted logo (in [x_min, y_min, x_max, y_max])?', pred_boxes) -plt.imshow(plotvis[:, :, ::-1]) -plt.title("Predicted screenshot with annotations") -plt.show() +4. Run in bash +```bash +python phishpedia.py --folder ``` -Or run in bash +The testing folder should be in the structure of: + ``` -python run.py --folder --results +test_site_1 +|__ info.txt (Write the URL) +|__ shot.png (Save the screenshot) +test_site_2 +|__ info.txt (Write the URL) +|__ shot.png (Save the screenshot) +...... ``` ## Miscellaneous - In our paper, we also implement several phishing detection and identification baselines, see [here](https://github.com/lindsey98/PhishingBaseline) - The logo targetlist described in our paper includes 181 brands, we have further expanded the targetlist to include 277 brands in this code repository - For the phish discovery experiment, we obtain feed from [Certstream phish_catcher](https://github.com/x0rz/phishing_catcher), we lower the score threshold to be 40 to process more suspicious websites, readers can refer to their repo for details -- We use Scrapy for website crawling [Repo here](https://github.com/lindsey98/MyScrapy.git) +- We use Scrapy for website crawling ## Citation If you find our work useful in your research, please consider citing our paper by: diff --git a/configs.py b/configs.py new file mode 100644 index 0000000..50be87c --- /dev/null +++ b/configs.py @@ -0,0 +1,63 @@ +# Global configuration +import subprocess +from typing import Union +import yaml +from logo_matching import cache_reference_list, load_model_weights +from logo_recog import config_rcnn +import os +import numpy as np + +def get_absolute_path(relative_path): + base_path = os.path.dirname(__file__) + return os.path.abspath(os.path.join(base_path, relative_path)) + +def load_config(reload_targetlist=False): + + with open(os.path.join(os.path.dirname(__file__), 'configs.yaml')) as file: + configs = yaml.load(file, Loader=yaml.FullLoader) + + # Iterate through the configuration and update paths + for section, settings in configs.items(): + for key, value in settings.items(): + if 'PATH' in key and isinstance(value, str): # Check if the key indicates a path + absolute_path = get_absolute_path(value) + configs[section][key] = absolute_path + + ELE_CFG_PATH = configs['ELE_MODEL']['CFG_PATH'] + ELE_WEIGHTS_PATH = configs['ELE_MODEL']['WEIGHTS_PATH'] + ELE_CONFIG_THRE = configs['ELE_MODEL']['DETECT_THRE'] + ELE_MODEL = config_rcnn(ELE_CFG_PATH, + ELE_WEIGHTS_PATH, + conf_threshold=ELE_CONFIG_THRE) + + # siamese model + SIAMESE_THRE = configs['SIAMESE_MODEL']['MATCH_THRE'] + + print('Load protected logo list') + targetlist_zip_path = configs['SIAMESE_MODEL']['TARGETLIST_PATH'] + targetlist_dir = os.path.dirname(targetlist_zip_path) + zip_file_name = os.path.basename(targetlist_zip_path) + targetlist_folder = zip_file_name.split('.zip')[0] + full_targetlist_folder_dir = os.path.join(targetlist_dir, targetlist_folder) + + if reload_targetlist or targetlist_zip_path.endswith('.zip') and not os.path.isdir(full_targetlist_folder_dir): + os.makedirs(full_targetlist_folder_dir, exist_ok=True) + subprocess.run(f'unzip -o "{targetlist_zip_path}" -d "{full_targetlist_folder_dir}"', shell=True) + + SIAMESE_MODEL = load_model_weights( num_classes=configs['SIAMESE_MODEL']['NUM_CLASSES'], + weights_path=configs['SIAMESE_MODEL']['WEIGHTS_PATH']) + + if reload_targetlist or (not os.path.exists(os.path.join(os.path.dirname(__file__), 'LOGO_FEATS.npy'))): + LOGO_FEATS, LOGO_FILES = cache_reference_list(model=SIAMESE_MODEL, + targetlist_path=full_targetlist_folder_dir) + print('Finish loading protected logo list') + np.save(os.path.join(os.path.dirname(__file__),'LOGO_FEATS.npy'), LOGO_FEATS) + np.save(os.path.join(os.path.dirname(__file__),'LOGO_FILES.npy'), LOGO_FILES) + + else: + LOGO_FEATS, LOGO_FILES = np.load(os.path.join(os.path.dirname(__file__),'LOGO_FEATS.npy')), \ + np.load(os.path.join(os.path.dirname(__file__),'LOGO_FILES.npy')) + + DOMAIN_MAP_PATH = configs['SIAMESE_MODEL']['DOMAIN_MAP_PATH'] + + return ELE_MODEL, SIAMESE_THRE, SIAMESE_MODEL, LOGO_FEATS, LOGO_FILES, DOMAIN_MAP_PATH \ No newline at end of file diff --git a/configs.yaml b/configs.yaml new file mode 100644 index 0000000..6b318ae --- /dev/null +++ b/configs.yaml @@ -0,0 +1,11 @@ +ELE_MODEL: # element recognition model -- logo only + CFG_PATH: models/faster_rcnn.yaml # os.path.join(os.path.dirname(__file__), xxx) + WEIGHTS_PATH: models/rcnn_bet365.pth + DETECT_THRE: 0.05 + +SIAMESE_MODEL: + NUM_CLASSES: 277 # number of brands, users don't need to modify this even the targetlist is expanded + MATCH_THRE: 0.87 # FIXME: threshold is 0.87 in phish-discovery? + WEIGHTS_PATH: models/resnetv2_rgb_new.pth.tar + TARGETLIST_PATH: models/expand_targetlist.zip + DOMAIN_MAP_PATH: models/domain_map.pkl \ No newline at end of file diff --git a/datasets/.DS_Store b/datasets/.DS_Store new file mode 100644 index 0000000..787bd30 Binary files /dev/null and b/datasets/.DS_Store differ diff --git a/datasets/overview.png b/datasets/overview.png new file mode 100644 index 0000000..74ed439 Binary files /dev/null and b/datasets/overview.png differ diff --git a/datasets/test_sites/.DS_Store b/datasets/test_sites/.DS_Store new file mode 100644 index 0000000..ccf2878 Binary files /dev/null and b/datasets/test_sites/.DS_Store differ diff --git a/datasets/test_sites/accounts.g.cdcde.com/.DS_Store b/datasets/test_sites/accounts.g.cdcde.com/.DS_Store new file mode 100644 index 0000000..9c34498 Binary files /dev/null and b/datasets/test_sites/accounts.g.cdcde.com/.DS_Store differ diff --git a/datasets/test_sites/accounts.g.cdcde.com/html.txt b/datasets/test_sites/accounts.g.cdcde.com/html.txt new file mode 100644 index 0000000..b8c0e99 --- /dev/null +++ b/datasets/test_sites/accounts.g.cdcde.com/html.txt @@ -0,0 +1,4106 @@ +Sign in - Google Accounts
\ No newline at end of file diff --git a/datasets/test_sites/accounts.g.cdcde.com/info.txt b/datasets/test_sites/accounts.g.cdcde.com/info.txt new file mode 100644 index 0000000..e797e4f --- /dev/null +++ b/datasets/test_sites/accounts.g.cdcde.com/info.txt @@ -0,0 +1 @@ +https://accounts.g.cdcde.com/ServiceLogin?passive=1209600&osid=1&continue=https://plus.g.cdcde.com/&followup=https://plus.g.cdcde.com/ \ No newline at end of file diff --git a/datasets/test_sites/accounts.g.cdcde.com/shot.png b/datasets/test_sites/accounts.g.cdcde.com/shot.png new file mode 100644 index 0000000..a6660c3 Binary files /dev/null and b/datasets/test_sites/accounts.g.cdcde.com/shot.png differ diff --git a/logo_matching.py b/logo_matching.py new file mode 100644 index 0000000..ba65bdf --- /dev/null +++ b/logo_matching.py @@ -0,0 +1,218 @@ +from PIL import Image, ImageOps +from torchvision import transforms +from utils import brand_converter, resolution_alignment, l2_norm +from models import KNOWN_MODELS +import torch +import os +import numpy as np +from collections import OrderedDict +from tqdm import tqdm +from tldextract import tldextract +import pickle + +def check_domain_brand_inconsistency(logo_boxes, + domain_map_path: str, + model, logo_feat_list, + file_name_list, shot_path: str, + url: str, ts: float, + topk: float=3): + + # targetlist domain list + with open(domain_map_path, 'rb') as handle: + domain_map = pickle.load(handle) + + print('number of logo boxes:', len(logo_boxes)) + matched_target, matched_domain, matched_coord, this_conf = None, None, None, None + + if len(logo_boxes) > 0: + # siamese prediction for logo box + for i, coord in enumerate(logo_boxes): + + if i == topk: + break + + min_x, min_y, max_x, max_y = coord + bbox = [float(min_x), float(min_y), float(max_x), float(max_y)] + matched_target, matched_domain, this_conf = pred_brand(model, domain_map, + logo_feat_list, file_name_list, + shot_path, bbox, t_s=ts, grayscale=False) + # print(target_this, domain_this, this_conf) + # domain matcher to avoid FP + if matched_target is not None: + matched_coord = coord + if tldextract.extract(url).domain in matched_domain: # domain and brand are consistent + matched_target = None + else: + break # break if target is matched + + return brand_converter(matched_target), matched_domain, matched_coord, this_conf + +def load_model_weights(num_classes: int, weights_path: str): + # Initialize model + device = 'cuda' if torch.cuda.is_available() else 'cpu' + model = KNOWN_MODELS["BiT-M-R50x1"](head_size=num_classes, zero_head=True) + + # Load weights + weights = torch.load(weights_path, map_location='cpu') + weights = weights['model'] if 'model' in weights.keys() else weights + new_state_dict = OrderedDict() + for k, v in weights.items(): + name = k.split('module.')[1] + new_state_dict[name] = v + + model.load_state_dict(new_state_dict) + model.to(device) + model.eval() + return model + +def cache_reference_list(model, targetlist_path: str, grayscale=False): + ''' + cache the embeddings of the reference list + :param num_classes: number of protected brands + :param weights_path: siamese weights + :param targetlist_path: targetlist folder + :param grayscale: convert logo to grayscale or not, default is RGB + :return model: siamese model + :return logo_feat_list: targetlist embeddings + :return file_name_list: targetlist paths + ''' + + # Prediction for targetlists + logo_feat_list = [] + file_name_list = [] + + for target in tqdm(os.listdir(targetlist_path)): + if target.startswith('.'): # skip hidden files + continue + for logo_path in os.listdir(os.path.join(targetlist_path, target)): + if logo_path.endswith('.png') or logo_path.endswith('.jpeg') or logo_path.endswith('.jpg') or logo_path.endswith('.PNG') \ + or logo_path.endswith('.JPG') or logo_path.endswith('.JPEG'): + if logo_path.startswith('loginpage') or logo_path.startswith('homepage'): # skip homepage/loginpage + continue + logo_feat_list.append(get_embedding(img=os.path.join(targetlist_path, target, logo_path), + model=model, grayscale=grayscale)) + file_name_list.append(str(os.path.join(targetlist_path, target, logo_path))) + + return np.asarray(logo_feat_list), np.asarray(file_name_list) + +@torch.no_grad() +def get_embedding(img, model, grayscale=False): + ''' + Inference for a single image + :param img: image path in str or image in PIL.Image + :param model: model to make inference + :param imshow: enable display of image or not + :param title: title of displayed image + :param grayscale: convert image to grayscale or not + :return feature embedding of shape (2048,) + ''' + # img_size = 224 + img_size = 128 + mean = [0.5, 0.5, 0.5] + std = [0.5, 0.5, 0.5] + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + img_transforms = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ]) + + img = Image.open(img) if isinstance(img, str) else img + img = img.convert("L").convert("RGB") if grayscale else img.convert("RGB") + + ## Resize the image while keeping the original aspect ratio + pad_color = 255 if grayscale else (255, 255, 255) + img = ImageOps.expand(img, ( + (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2, + (max(img.size) - img.size[0]) // 2, (max(img.size) - img.size[1]) // 2), fill=pad_color) + + img = img.resize((img_size, img_size)) + + # Predict the embedding + with torch.no_grad(): + img = img_transforms(img) + img = img[None, ...].to(device) + logo_feat = model.features(img) + logo_feat = l2_norm(logo_feat).squeeze(0).cpu().numpy() # L2-normalization final shape is (2048,) + + return logo_feat + + +def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str, gt_bbox, t_s, grayscale=False): + ''' + Return predicted brand for one cropped image + :param model: model to use + :param domain_map: brand-domain dictionary + :param logo_feat_list: reference logo feature embeddings + :param file_name_list: reference logo paths + :param shot_path: path to the screenshot + :param gt_bbox: 1x4 np.ndarray/list/tensor bounding box coords + :param t_s: similarity threshold for siamese + :param grayscale: convert image(cropped) to grayscale or not + :return: predicted target, predicted target's domain + ''' + + try: + img = Image.open(shot_path) + except OSError: # if the image cannot be identified, return nothing + print('Screenshot cannot be open') + return None, None, None + + ## get predicted box --> crop from screenshot + cropped = img.crop((gt_bbox[0], gt_bbox[1], gt_bbox[2], gt_bbox[3])) + img_feat = get_embedding(cropped, model, grayscale=grayscale) + + ## get cosine similarity with every protected logo + sim_list = logo_feat_list @ img_feat.T # take dot product for every pair of embeddings (Cosine Similarity) + pred_brand_list = file_name_list + + assert len(sim_list) == len(pred_brand_list) + + ## get top 3 brands + idx = np.argsort(sim_list)[::-1][:3] + pred_brand_list = np.array(pred_brand_list)[idx] + sim_list = np.array(sim_list)[idx] + + # top1,2,3 candidate logos + top3_logolist = [Image.open(x) for x in pred_brand_list] + top3_brandlist = [brand_converter(os.path.basename(os.path.dirname(x))) for x in pred_brand_list] + top3_domainlist = [domain_map[x] for x in top3_brandlist] + top3_simlist = sim_list + + for j in range(3): + predicted_brand, predicted_domain = None, None + + ## If we are trying those lower rank logo, the predicted brand of them should be the same as top1 logo, otherwise might be false positive + if top3_brandlist[j] != top3_brandlist[0]: + continue + + ## If the largest similarity exceeds threshold + if top3_simlist[j] >= t_s: + predicted_brand = top3_brandlist[j] + predicted_domain = top3_domainlist[j] + final_sim = top3_simlist[j] + + ## Else if not exceed, try resolution alignment, see if can improve + else: + cropped, candidate_logo = resolution_alignment(cropped, top3_logolist[j]) + img_feat = get_embedding(cropped, model, grayscale=grayscale) + logo_feat = get_embedding(candidate_logo, model, grayscale=grayscale) + final_sim = logo_feat.dot(img_feat) + if final_sim >= t_s: + predicted_brand = top3_brandlist[j] + predicted_domain = top3_domainlist[j] + else: + break # no hope, do not try other lower rank logos + + ## If there is a prediction, do aspect ratio check + if predicted_brand is not None: + ratio_crop = cropped.size[0] / cropped.size[1] + ratio_logo = top3_logolist[j].size[0] / top3_logolist[j].size[1] + # aspect ratios of matched pair must not deviate by more than factor of 2.5 + if max(ratio_crop, ratio_logo) / min(ratio_crop, ratio_logo) > 2.5: + continue # did not pass aspect ratio check, try other + # If pass aspect ratio check, report a match + else: + return predicted_brand, predicted_domain, final_sim + + return None, None, top3_simlist[0] \ No newline at end of file diff --git a/logo_recog.py b/logo_recog.py new file mode 100644 index 0000000..10eb8a8 --- /dev/null +++ b/logo_recog.py @@ -0,0 +1,80 @@ +from detectron2.config import get_cfg +from detectron2.engine import DefaultPredictor +import cv2 +import numpy as np +import torch + +def pred_rcnn(im, predictor): + ''' + Perform inference for RCNN + :param im: + :param predictor: + :return: + ''' + im = cv2.imread(im) + + if im is not None: + if im.shape[-1] == 4: + im = cv2.cvtColor(im, cv2.COLOR_BGRA2BGR) + else: + return None, None, None, None + + outputs = predictor(im) + + instances = outputs['instances'] + pred_classes = instances.pred_classes # tensor + pred_boxes = instances.pred_boxes # Boxes object + + logo_boxes = pred_boxes[pred_classes == 1].tensor + input_boxes = pred_boxes[pred_classes == 0].tensor + + scores = instances.scores # tensor + logo_scores = scores[pred_classes == 1] + input_scores = scores[pred_classes == 0] + + return logo_boxes, logo_scores, input_boxes, input_scores + + +def config_rcnn(cfg_path, weights_path, conf_threshold): + ''' + Configure weights and confidence threshold + :param cfg_path: + :param weights_path: + :param conf_threshold: + :return: + ''' + cfg = get_cfg() + cfg.merge_from_file(cfg_path) + cfg.MODEL.WEIGHTS = weights_path + cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = conf_threshold + # uncomment if you installed detectron2 cpu version + if not torch.cuda.is_available(): + cfg.MODEL.DEVICE = 'cpu' + + # Initialize model + predictor = DefaultPredictor(cfg) + return predictor + + +def vis(img_path, pred_boxes): + ''' + Visualize rcnn predictions + :param img_path: str + :param pred_boxes: torch.Tensor of shape Nx4, bounding box coordinates in (x1, y1, x2, y2) + :param pred_classes: torch.Tensor of shape Nx1 0 for logo, 1 for input, 2 for button, 3 for label(text near input), 4 for block + :return None + ''' + + check = cv2.imread(img_path) + if pred_boxes is None or len(pred_boxes) == 0: + return check + pred_boxes = pred_boxes.numpy() if not isinstance(pred_boxes, np.ndarray) else pred_boxes + + # draw rectangle + for j, box in enumerate(pred_boxes): + if j == 0: + cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (255, 255, 0), 2) + else: + cv2.rectangle(check, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (36, 255, 12), 2) + + return check \ No newline at end of file diff --git a/models.py b/models.py new file mode 100644 index 0000000..34d3992 --- /dev/null +++ b/models.py @@ -0,0 +1,195 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Bottleneck ResNet v2 with GroupNorm and Weight Standardization.""" + +from collections import OrderedDict # pylint: disable=g-importing-member + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class StdConv2d(nn.Conv2d): + + def forward(self, x): + w = self.weight + v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False) + w = (w - m) / torch.sqrt(v + 1e-10) + return F.conv2d(x, w, self.bias, self.stride, self.padding, + self.dilation, self.groups) + + +def conv3x3(cin, cout, stride=1, groups=1, bias=False): + return StdConv2d(cin, cout, kernel_size=3, stride=stride, + padding=1, bias=bias, groups=groups) + + +def conv1x1(cin, cout, stride=1, bias=False): + return StdConv2d(cin, cout, kernel_size=1, stride=stride, + padding=0, bias=bias) + + +def tf2th(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(conv_weights) + + +class PreActBottleneck(nn.Module): + """Pre-activation (v2) bottleneck block. + + Follows the implementation of "Identity Mappings in Deep Residual Networks": + https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua + + Except it puts the stride on 3x3 conv when available. + """ + + def __init__(self, cin, cout=None, cmid=None, stride=1): + super().__init__() + cout = cout or cin + cmid = cmid or cout//4 + + self.gn1 = nn.GroupNorm(32, cin) + self.conv1 = conv1x1(cin, cmid) + self.gn2 = nn.GroupNorm(32, cmid) + self.conv2 = conv3x3(cmid, cmid, stride) # Original code has it on conv1!! + self.gn3 = nn.GroupNorm(32, cmid) + self.conv3 = conv1x1(cmid, cout) + self.relu = nn.ReLU(inplace=True) + + if (stride != 1 or cin != cout): + # Projection also with pre-activation according to paper. + self.downsample = conv1x1(cin, cout, stride) + + def forward(self, x): + out = self.relu(self.gn1(x)) + + # Residual branch + residual = x + if hasattr(self, 'downsample'): + residual = self.downsample(out) + + # Unit's branch + out = self.conv1(out) + out = self.conv2(self.relu(self.gn2(out))) + out = self.conv3(self.relu(self.gn3(out))) + + return out + residual + + def load_from(self, weights, prefix=''): + convname = 'standardized_conv2d' + with torch.no_grad(): + self.conv1.weight.copy_(tf2th(weights[f'{prefix}a/{convname}/kernel'])) + self.conv2.weight.copy_(tf2th(weights[f'{prefix}b/{convname}/kernel'])) + self.conv3.weight.copy_(tf2th(weights[f'{prefix}c/{convname}/kernel'])) + self.gn1.weight.copy_(tf2th(weights[f'{prefix}a/group_norm/gamma'])) + self.gn2.weight.copy_(tf2th(weights[f'{prefix}b/group_norm/gamma'])) + self.gn3.weight.copy_(tf2th(weights[f'{prefix}c/group_norm/gamma'])) + self.gn1.bias.copy_(tf2th(weights[f'{prefix}a/group_norm/beta'])) + self.gn2.bias.copy_(tf2th(weights[f'{prefix}b/group_norm/beta'])) + self.gn3.bias.copy_(tf2th(weights[f'{prefix}c/group_norm/beta'])) + if hasattr(self, 'downsample'): + w = weights[f'{prefix}a/proj/{convname}/kernel'] + self.downsample.weight.copy_(tf2th(w)) + + +class ResNetV2(nn.Module): + """Implementation of Pre-activation (v2) ResNet mode.""" + + def __init__(self, block_units, width_factor, head_size=21843, zero_head=False): + super().__init__() + wf = width_factor # shortcut 'cause we'll use it a lot. + + # The following will be unreadable if we split lines. + # pylint: disable=line-too-long + self.root = nn.Sequential(OrderedDict([ + ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)), + ('pad', nn.ConstantPad2d(1, 0)), + ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)), + # The following is subtly not the same! + # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ])) + + self.body = nn.Sequential(OrderedDict([ + ('block1', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] + + [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)], + ))), + ('block2', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] + + [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)], + ))), + ('block3', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] + + [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)], + ))), + ('block4', nn.Sequential(OrderedDict( + [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] + + [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)], + ))), + ])) + # pylint: enable=line-too-long + + self.zero_head = zero_head + self.head = nn.Sequential(OrderedDict([ + ('gn', nn.GroupNorm(32, 2048*wf)), + ('relu', nn.ReLU(inplace=True)), + ('avg', nn.AdaptiveAvgPool2d(output_size=1)), + ('conv', nn.Conv2d(2048*wf, head_size, kernel_size=1, bias=True)), + ])) + + def features(self, x): + x = self.head[:-1](self.body(self.root(x))) + + return x.squeeze(-1).squeeze(-1) + + def forward(self, x): + x = self.head(self.body(self.root(x))) + assert x.shape[-2:] == (1, 1) # We should have no spatial shape left. + return x[...,0,0] + + def load_from(self, weights, prefix='resnet/'): + with torch.no_grad(): + self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) # pylint: disable=line-too-long + self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma'])) + self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta'])) + if self.zero_head: + nn.init.zeros_(self.head.conv.weight) + nn.init.zeros_(self.head.conv.bias) + else: + self.head.conv.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel'])) # pylint: disable=line-too-long + self.head.conv.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias'])) + + for bname, block in self.body.named_children(): + for uname, unit in block.named_children(): + unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/') + + +KNOWN_MODELS = OrderedDict([ + ('BiT-M-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), + ('BiT-M-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), + ('BiT-M-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), + ('BiT-M-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), + ('BiT-M-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), + ('BiT-M-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), + ('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)), + ('BiT-S-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)), + ('BiT-S-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)), + ('BiT-S-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)), + ('BiT-S-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)), + ('BiT-S-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)), +]) \ No newline at end of file diff --git a/parse_results.py b/parse_results.py new file mode 100644 index 0000000..45d197f --- /dev/null +++ b/parse_results.py @@ -0,0 +1,40 @@ +import os.path +import shutil +from datetime import datetime + +def get_pos_site(result_txt): + + df = [x.strip().split('\t') for x in open(result_txt, encoding='ISO-8859-1').readlines()] + df_pos = [x for x in df if (x[2] == '1')] + df_pos = [x for x in df_pos if x[3] not in ['Google', 'Webmail Provider', + 'WhatsApp', 'Luno']] + return df_pos + + + +if __name__ == '__main__': + # today = datetime.now().strftime('%Y%m%d') + today = '20231223' + result_txt = f'/home/ruofan/git_space/PhishEmail/datasets/{today}_results.txt' + df_pos = get_pos_site(result_txt) + print(len(df_pos)) + + pos_result_txt = f'/home/ruofan/git_space/PhishEmail/datasets/{today}_pos.txt' + pos_result_dir = f'/home/ruofan/git_space/PhishEmail/datasets/sjtu_phish_pos/{today}' + os.makedirs(pos_result_dir, exist_ok=True) + + for x in df_pos: + url = x[1] + if os.path.exists(pos_result_txt) and url in open(pos_result_txt).read(): + pass + else: + with open(pos_result_txt, 'a+') as f: + f.write(url+'\n') + + # try: + shutil.copytree(os.path.join(f'/home/ruofan/git_space/PhishEmail/datasets/sjtu_phish/{today}', x[0]), + os.path.join(pos_result_dir, x[0])) + # except FileExistsError: + # pass + + print(df_pos) \ No newline at end of file diff --git a/phishpedia.py b/phishpedia.py new file mode 100644 index 0000000..99ed0a2 --- /dev/null +++ b/phishpedia.py @@ -0,0 +1,184 @@ + +import time +import sys +from datetime import datetime +import argparse +import os +import torch +from lib.phishpedia.configs import load_config +from tldextract import tldextract +import cv2 +from logo_recog import pred_rcnn, vis +from logo_matching import check_domain_brand_inconsistency +from text_recog import check_email_credential_taking +import pickle +from tqdm import tqdm +from paddleocr import PaddleOCR +import re + +os.environ['KMP_DUPLICATE_LIB_OK']='True' + +class PhishpediaWrapper: + _caller_prefix = "PhishpediaWrapper" + _DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' + + def __init__(self): + self._load_config() + self._to_device() + + def _load_config(self): + self.ELE_MODEL, self.SIAMESE_THRE, self.SIAMESE_MODEL, \ + self.LOGO_FEATS, self.LOGO_FILES, \ + self.DOMAIN_MAP_PATH = load_config() + + self.OCR_MODEL = PaddleOCR(use_angle_cls=True, + lang='ch', + show_log=False, + use_gpu=torch.cuda.is_available()) # need to run only once to download and load model into memory + print(f'Length of reference list = {len(self.LOGO_FEATS)}') + + def _to_device(self): + self.SIAMESE_MODEL.to(self._DEVICE) + + + '''Phishpedia''' + def test_orig_phishpedia(self, url, screenshot_path): + # 0 for benign, 1 for phish, default is benign + phish_category = 0 + pred_target = None + matched_domain = None + siamese_conf = None + logo_recog_time = 0 + logo_match_time = 0 + print("Entering phishpedia") + + ####################### Step1: Logo detector ############################################## + start_time = time.time() + pred_boxes, _, _, _ = pred_rcnn(im=screenshot_path, predictor=self.ELE_MODEL) + logo_recog_time = time.time() - start_time + + if pred_boxes is not None: + pred_boxes = pred_boxes.detach().cpu().numpy() + plotvis = vis(screenshot_path, pred_boxes) + + # If no element is reported + if pred_boxes is None or len(pred_boxes) == 0: + print('No logo is detected') + return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time + + print('Entering siamese') + + ######################## Step2: Siamese (Logo matcher) ######################################## + start_time = time.time() + pred_target, matched_domain, matched_coord, siamese_conf = check_domain_brand_inconsistency(logo_boxes=pred_boxes, + domain_map_path=self.DOMAIN_MAP_PATH, + model=self.SIAMESE_MODEL, + logo_feat_list=self.LOGO_FEATS, + file_name_list=self.LOGO_FILES, + url=url, + shot_path=screenshot_path, + ts=self.SIAMESE_THRE) + logo_match_time = time.time() - start_time + + if pred_target is None: + ### ask for email + ask_for_emails, matched_email = check_email_credential_taking(self.OCR_MODEL, screenshot_path) + if ask_for_emails and (tldextract.extract(url).domain not in matched_email.replace('@', '')): # domain and brand are consistent + matched_target = matched_email.replace('@', '') + cv2.putText(plotvis, "No logo but asks for email credentials", + (20, 20), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2) + print('No logo but asks for specific email credentials') + return 1, matched_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time + + print('Did not match to any brand, report as benign') + return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time + + else: + phish_category = 1 + # Visualize, add annotations + cv2.putText(plotvis, "Target: {} with confidence {:.4f}".format(pred_target, siamese_conf), + (int(matched_coord[0] + 20), int(matched_coord[1] + 20)), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 0), 2) + + + return phish_category, pred_target, matched_domain, plotvis, siamese_conf, pred_boxes, logo_recog_time, logo_match_time + + +if __name__ == '__main__': + + + + '''update domain map''' + # with open('./lib/phishpedia/models/domain_map.pkl', "rb") as handle: + # domain_map = pickle.load(handle) + # + # domain_map['weibo'] = ['sina', 'weibo'] + # + # with open('./lib/phishpedia/models/domain_map.pkl', "wb") as handle: + # pickle.dump(domain_map, handle) + # exit() + + '''run''' + today = datetime.now().strftime('%Y%m%d') + + parser = argparse.ArgumentParser() + parser.add_argument("--folder", required=True, type=str) + parser.add_argument("--output_txt", default=f'{today}_results.txt', help="Output txt path") + args = parser.parse_args() + + request_dir = args.folder + phishpedia_cls = PhishpediaWrapper() + result_txt = args.output_txt + + os.makedirs(request_dir, exist_ok=True) + + for folder in tqdm(os.listdir(request_dir)): + html_path = os.path.join(request_dir, folder, "html.txt") + screenshot_path = os.path.join(request_dir, folder, "shot.png") + info_path = os.path.join(request_dir, folder, 'info.json') + + if not os.path.exists(screenshot_path): + continue + + url = eval(open(info_path).read())['url'] + + if os.path.exists(result_txt) and url in open(result_txt, encoding='ISO-8859-1').read(): + continue + + _forbidden_suffixes = r"\.(mp3|wav|wma|ogg|mkv|zip|tar|xz|rar|z|deb|bin|iso|csv|tsv|dat|txt|css|log|sql|xml|sql|mdb|apk|bat|bin|exe|jar|wsf|fnt|fon|otf|ttf|ai|bmp|gif|ico|jp(e)?g|png|ps|psd|svg|tif|tiff|cer|rss|key|odp|pps|ppt|pptx|c|class|cpp|cs|h|java|sh|swift|vb|odf|xlr|xls|xlsx|bak|cab|cfg|cpl|cur|dll|dmp|drv|icns|ini|lnk|msi|sys|tmp|3g2|3gp|avi|flv|h264|m4v|mov|mp4|mp(e)?g|rm|swf|vob|wmv|doc(x)?|odt|rtf|tex|txt|wks|wps|wpd)$" + if re.search(_forbidden_suffixes, url, re.IGNORECASE): + continue + + phish_category, pred_target, matched_domain, \ + plotvis, siamese_conf, pred_boxes, \ + logo_recog_time, logo_match_time = phishpedia_cls.test_orig_phishpedia(url, screenshot_path) + + try: + with open(result_txt, "a+", encoding='ISO-8859-1') as f: + f.write(folder + "\t") + f.write(url + "\t") + f.write(str(phish_category) + "\t") + f.write(str(pred_target) + "\t") # write top1 prediction only + f.write(str(matched_domain) + "\t") + f.write(str(siamese_conf) + "\t") + f.write(str(round(logo_recog_time, 4)) + "\t") + f.write(str(round(logo_match_time, 4)) + "\n") + except UnicodeError: + with open(result_txt, "a+", encoding='utf-8') as f: + f.write(folder + "\t") + f.write(url + "\t") + f.write(str(phish_category) + "\t") + f.write(str(pred_target) + "\t") # write top1 prediction only + f.write(str(matched_domain) + "\t") + f.write(str(siamese_conf) + "\t") + f.write(str(round(logo_recog_time, 4)) + "\t") + f.write(str(round(logo_match_time, 4)) + "\n") + if phish_category: + os.makedirs(os.path.join(request_dir, folder), exist_ok=True) + cv2.imwrite(os.path.join(request_dir, folder, "predict.png"), plotvis) + + + # import matplotlib.pyplot as plt + # plt.imshow(cropped) + # plt.savefig('./debug.png') \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index d9982f8..12d5922 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,14 @@ -torchsummary scipy tldextract opencv-python pandas numpy tqdm -Pillow +Pillow==8.4.0 pathlib fvcore pycocotools scikit-learn -advertorch -python-telegram-bot -gspread -oauth2client lxml - - - - - +gdown +paddleocr>=2.0.1 \ No newline at end of file diff --git a/setup.sh b/setup.sh old mode 100755 new mode 100644 index 6269f44..dd0000f --- a/setup.sh +++ b/setup.sh @@ -1,21 +1,39 @@ #!/bin/bash +retry_count=3 # Number of retries + +download_with_retry() { + local file_id=$1 + local file_name=$2 + local count=0 + + until [ $count -ge $retry_count ] + do + gdown --id "$file_id" -O "$file_name" && break # attempt to download and break if successful + count=$((count+1)) + echo "Retry $count of $retry_count..." + sleep 1 # wait for 5 seconds before retrying + done + + if [ $count -ge $retry_count ]; then + echo "Failed to download $file_name after $retry_count attempts." + fi +} FILEDIR=$(pwd) CONDA_BASE=$(conda info --base) source "$CONDA_BASE/etc/profile.d/conda.sh" -conda info --envs | grep -w "myenv" > /dev/null +conda info --envs | grep -w "phishpedia" > /dev/null if [ $? -eq 0 ]; then - echo "Activating Conda environment myenv" - conda activate myenv + echo "Activating Conda environment phishpedia" + conda activate phishpedia else - echo "Creating and activating new Conda environment $ENV_NAME with Python 3.8" - conda create -n myenv python=3.8 - conda activate myenv + echo "Creating and activating new Conda environment phishpedia with Python 3.8" + conda create -n phishpedia python=3.8 + conda activate phishpedia fi -pip install -r requirements.txt OS=$(uname -s) @@ -23,39 +41,66 @@ if [[ "$OS" == "Darwin" ]]; then echo "Installing PyTorch and torchvision for macOS." pip install torch==1.9.0 torchvision==0.10.0 torchaudio==0.9.0 python -m pip install detectron2 -f "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.9/index.html" + python -m pip install paddlepaddle==2.5.1 -i https://mirror.baidu.com/pypi/simple else # Check if NVIDIA GPU is available for Linux and Windows - if command -v nvcc &> /dev/null; then + if command -v nvcc || command -v nvidia-smi &> /dev/null; then echo "CUDA is detected, installing GPU-supported PyTorch and torchvision." pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f "https://download.pytorch.org/whl/torch_stable.html" python -m pip install detectron2 -f "https://dl.fbaipublicfiles.com/detectron2/wheels/cu111/torch1.9/index.html" + python -m pip install paddlepaddle-gpu==2.5.1 -i https://mirror.baidu.com/pypi/simple else echo "No CUDA detected, installing CPU-only PyTorch and torchvision." pip install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f "https://download.pytorch.org/whl/torch_stable.html" python -m pip install detectron2 -f "https://dl.fbaipublicfiles.com/detectron2/wheels/cpu/torch1.9/index.html" + python -m pip install paddlepaddle==2.5.1 -i https://mirror.baidu.com/pypi/simple fi fi +pip install -r requirements.txt + ## Download models -pip install -v . -package_location=$(pip show phishpedia | grep Location | awk '{print $2}') +echo "Going to the directory of package Phishpedia in Conda environment myenv." +mkdir -p models/ +cd models/ + -if [ -z "$package_location" ]; then - echo "Package Phishpedia not found in the Conda environment myenv." - exit 1 +# RCNN model weights +if [ -f "rcnn_bet365.pth" ]; then + echo "RCNN model weights exists... skip" else - echo "Going to the directory of package Phishpedia in Conda environment myenv." - cd "$package_location/phishpedia/src/detectron2_pedia/output/rcnn_2" || exit - pip install gdown - gdown --id 1tE2Mu5WC8uqCxei3XqAd7AWaP5JTmVWH - cd "$package_location/phishpedia/src/siamese_pedia/" || exit - gdown --id 1H0Q_DbdKPLFcZee8I14K62qV7TTy7xvS - gdown --id 1fr5ZxBKyDiNZ_1B6rRAfZbAHBBoUjZ7I - gdown --id 1qSdkSSoCYUkZMKs44Rup_1DPBxHnEKl1 + download_with_retry 1tE2Mu5WC8uqCxei3XqAd7AWaP5JTmVWH rcnn_bet365.pth fi -# Replace the placeholder in the YAML template +# Faster RCNN config +if [ -f "faster_rcnn.yaml" ]; then + echo "RCNN model config exists... skip" +else + download_with_retry 1Q6lqjpl4exW7q_dPbComcj0udBMDl8CW faster_rcnn.yaml +fi -sed "s|CONDA_ENV_PATH_PLACEHOLDER|$package_location/phishpedia|g" "$FILEDIR/phishpedia/configs_template.yaml" > "$FILEDIR/phishpedia/configs.yaml" +# Siamese model weights +if [ -f "resnetv2_rgb_new.pth.tar" ]; then + echo "Siamese model weights exists... skip" +else + download_with_retry 1H0Q_DbdKPLFcZee8I14K62qV7TTy7xvS resnetv2_rgb_new.pth.tar +fi + +# Reference list +if [ -f "expand_targetlist.zip" ]; then + echo "Reference list exists... skip" +else + download_with_retry 1fr5ZxBKyDiNZ_1B6rRAfZbAHBBoUjZ7I expand_targetlist.zip +fi -echo "All packages installed successfully!" +# Domain map +if [ -f "domain_map.pkl" ]; then + echo "Domain map exists... skip" +else + download_with_retry 1qSdkSSoCYUkZMKs44Rup_1DPBxHnEKl1 domain_map.pkl +fi + + + +# Replace the placeholder in the YAML template +echo "All packages installed successfully!" \ No newline at end of file diff --git a/text_recog.py b/text_recog.py new file mode 100644 index 0000000..eb610ba --- /dev/null +++ b/text_recog.py @@ -0,0 +1,31 @@ +import re + +def pred_text_in_image(OCR_MODEL, shot_path): + + result = OCR_MODEL.ocr(shot_path, cls=True) + if result is None or result[0] is None: + return '' + + most_fit_results = result[0] + ocr_text = [line[1][0] for line in most_fit_results] + detected_text = ' '.join(ocr_text) + + return detected_text + +def check_email_credential_taking(OCR_MODEL, shot_path): + detected_text = pred_text_in_image(OCR_MODEL, shot_path) + if len(detected_text) > 0: + return rule_matching(detected_text) + return False, None + +def rule_matching(detected_text): + email_login_pattern = r'邮箱.*登录|邮箱.*登陆|邮件.*登录|邮件.*登陆' + specified_email_pattern = r'@[\w.-]+\.\w+' + + if re.findall(email_login_pattern, detected_text): + find_email = re.findall(specified_email_pattern, detected_text) + if find_email: + return True, find_email[0] + + return False, None + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..bfc41d8 --- /dev/null +++ b/utils.py @@ -0,0 +1,134 @@ +import torch.nn.functional as F +from PIL import Image +import math + +def resolution_alignment(img1, img2): + ''' + Resize two images according to the minimum resolution between the two + :param img1: first image in PIL.Image + :param img2: second image in PIL.Image + :return: resized img1 in PIL.Image, resized img2 in PIL.Image + ''' + w1, h1 = img1.size + w2, h2 = img2.size + w_min, h_min = min(w1, w2), min(h1, h2) + if w_min == 0 or h_min == 0: ## something wrong, stop resizing + return img1, img2 + if w_min < h_min: + img1_resize = img1.resize((int(w_min), math.ceil(h1 * (w_min/w1)))) # ceiling to prevent rounding to 0 + img2_resize = img2.resize((int(w_min), math.ceil(h2 * (w_min/w2)))) + else: + img1_resize = img1.resize((math.ceil(w1 * (h_min/h1)), int(h_min))) + img2_resize = img2.resize((math.ceil(w2 * (h_min/h2)), int(h_min))) + return img1_resize, img2_resize + +def brand_converter(brand_name): + ''' + Helper function to deal with inconsistency in brand naming + ''' + if brand_name == 'Adobe Inc.' or brand_name == 'Adobe Inc': + return 'Adobe' + elif brand_name == 'ADP, LLC' or brand_name == 'ADP, LLC.': + return 'ADP' + elif brand_name == 'Amazon.com Inc.' or brand_name == 'Amazon.com Inc': + return 'Amazon' + elif brand_name == 'Americanas.com S,A Comercio Electrnico': + return 'Americanas.com S' + elif brand_name == 'AOL Inc.' or brand_name == 'AOL Inc': + return 'AOL' + elif brand_name == 'Apple Inc.' or brand_name == 'Apple Inc': + return 'Apple' + elif brand_name == 'AT&T Inc.' or brand_name == 'AT&T Inc': + return 'AT&T' + elif brand_name == 'Banco do Brasil S.A.': + return 'Banco do Brasil S.A' + elif brand_name == 'Credit Agricole S.A.': + return 'Credit Agricole S.A' + elif brand_name == 'DGI (French Tax Authority)': + return 'DGI French Tax Authority' + elif brand_name == 'DHL Airways, Inc.' or brand_name == 'DHL Airways, Inc' or brand_name == 'DHL': + return 'DHL Airways' + elif brand_name == 'Dropbox, Inc.' or brand_name == 'Dropbox, Inc': + return 'Dropbox' + elif brand_name == 'eBay Inc.' or brand_name == 'eBay Inc': + return 'eBay' + elif brand_name == 'Facebook, Inc.' or brand_name == 'Facebook, Inc': + return 'Facebook' + elif brand_name == 'Free (ISP)': + return 'Free ISP' + elif brand_name == 'Google Inc.' or brand_name == 'Google Inc': + return 'Google' + elif brand_name == 'Mastercard International Incorporated': + return 'Mastercard International' + elif brand_name == 'Netflix Inc.' or brand_name == 'Netflix Inc': + return 'Netflix' + elif brand_name == 'PayPal Inc.' or brand_name == 'PayPal Inc': + return 'PayPal' + elif brand_name == 'Royal KPN N.V.': + return 'Royal KPN N.V' + elif brand_name == 'SF Express Co.': + return 'SF Express Co' + elif brand_name == 'SNS Bank N.V.': + return 'SNS Bank N.V' + elif brand_name == 'Square, Inc.' or brand_name == 'Square, Inc': + return 'Square' + elif brand_name == 'Webmail Providers': + return 'Webmail Provider' + elif brand_name == 'Yahoo! Inc' or brand_name == 'Yahoo! Inc.': + return 'Yahoo!' + elif brand_name == 'Microsoft OneDrive' or brand_name == 'Office365' or brand_name == 'Outlook': + return 'Microsoft' + elif brand_name == 'Global Sources (HK)': + return 'Global Sources HK' + elif brand_name == 'T-Online': + return 'Deutsche Telekom' + elif brand_name == 'Airbnb, Inc': + return 'Airbnb, Inc.' + elif brand_name == 'azul': + return 'Azul' + elif brand_name == 'Raiffeisen Bank S.A': + return 'Raiffeisen Bank S.A.' + elif brand_name == 'Twitter, Inc' or brand_name == 'Twitter': + return 'Twitter, Inc.' + elif brand_name == 'capital_one': + return 'Capital One Financial Corporation' + elif brand_name == 'la_banque_postale': + return 'La Banque postale' + elif brand_name == 'db': + return 'Deutsche Bank AG' + elif brand_name == 'Swiss Post' or brand_name == 'PostFinance': + return 'PostFinance' + elif brand_name == 'grupo_bancolombia': + return 'Bancolombia' + elif brand_name == 'barclays': + return 'Barclays Bank Plc' + elif brand_name == 'gov_uk': + return 'Government of the United Kingdom' + elif brand_name == 'Aruba S.p.A': + return 'Aruba S.p.A.' + elif brand_name == 'TSB Bank Plc': + return 'TSB Bank Limited' + elif brand_name == 'strato': + return 'Strato AG' + elif brand_name == 'cogeco': + return 'Cogeco' + elif brand_name == 'Canada Revenue Agency': + return 'Government of Canada' + elif brand_name == 'UniCredit Bulbank': + return 'UniCredit Bank Aktiengesellschaft' + elif brand_name == 'ameli_fr': + return 'French Health Insurance' + elif brand_name == 'Banco de Credito del Peru': + return 'bcp' + else: + return brand_name + +def l2_norm(x): + """ + l2 normalization + :param x: + :return: + """ + if len(x.shape): + x = x.reshape((x.shape[0], -1)) + return F.normalize(x, p=2, dim=1) \ No newline at end of file