Skip to content

Commit

Permalink
描述你对这三个文件的修改
Browse files Browse the repository at this point in the history
  • Loading branch information
hanwenxu1 committed Dec 2, 2024
1 parent 9086ae4 commit c557ff8
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
22 changes: 11 additions & 11 deletions logo_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def check_domain_brand_inconsistency(logo_boxes,
domain_map_path: str,
model, logo_feat_list,
file_name_list, shot_path: str,
url: str, ts: float,
url: str, similarity_threshold: float,
topk: float = 3):
# targetlist domain list
with open(domain_map_path, 'rb') as handle:
Expand All @@ -36,7 +36,7 @@ def check_domain_brand_inconsistency(logo_boxes,
bbox = [float(min_x), float(min_y), float(max_x), float(max_y)]
matched_target, matched_domain, this_conf = pred_brand(model, domain_map,
logo_feat_list, file_name_list,
shot_path, bbox, t_s=ts, grayscale=False,
shot_path, bbox, similarity_threshold=similarity_threshold, grayscale=False,
do_aspect_ratio_check=False,
do_resolution_alignment=False)
# print(target_this, domain_this, this_conf)
Expand All @@ -53,6 +53,11 @@ def check_domain_brand_inconsistency(logo_boxes,


def load_model_weights(num_classes: int, weights_path: str):
'''
:param num_classes: number of protected brands
:param weights_path: siamese weights
:return model: siamese model
'''
# Initialize model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = KNOWN_MODELS["BiT-M-R50x1"](head_size=num_classes, zero_head=True)
Expand All @@ -77,11 +82,8 @@ def load_model_weights(num_classes: int, weights_path: str):
def cache_reference_list(model, targetlist_path: str, grayscale=False):
'''
cache the embeddings of the reference list
:param num_classes: number of protected brands
:param weights_path: siamese weights
:param targetlist_path: targetlist folder
:param grayscale: convert logo to grayscale or not, default is RGB
:return model: siamese model
:return logo_feat_list: targetlist embeddings
:return file_name_list: targetlist paths
'''
Expand Down Expand Up @@ -112,8 +114,6 @@ def get_embedding(img, model, grayscale=False):
Inference for a single image
:param img: image path in str or image in PIL.Image
:param model: model to make inference
:param imshow: enable display of image or not
:param title: title of displayed image
:param grayscale: convert image to grayscale or not
:return feature embedding of shape (2048,)
'''
Expand Down Expand Up @@ -148,7 +148,7 @@ def get_embedding(img, model, grayscale=False):
return logo_feat


def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str, gt_bbox, t_s,
def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str, gt_bbox, similarity_threshold,
grayscale=False,
do_resolution_alignment=True,
do_aspect_ratio_check=True):
Expand All @@ -160,7 +160,7 @@ def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str
:param file_name_list: reference logo paths
:param shot_path: path to the screenshot
:param gt_bbox: 1x4 np.ndarray/list/tensor bounding box coords
:param t_s: similarity threshold for siamese
:param similarity_threshold: similarity threshold for siamese
:param do_resolution_alignment: if the similarity does not exceed the threshold, do we align their resolutions to have a retry
:param do_aspect_ratio_check: once two logos are similar, whether we want to a further check on their aspect ratios
:param grayscale: convert image(cropped) to grayscale or not
Expand Down Expand Up @@ -201,7 +201,7 @@ def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str
continue

## If the largest similarity exceeds threshold
if top3_simlist[j] >= t_s:
if top3_simlist[j] >= similarity_threshold:
predicted_brand = top3_brandlist[j]
predicted_domain = top3_domainlist[j]
final_sim = top3_simlist[j]
Expand All @@ -213,7 +213,7 @@ def pred_brand(model, domain_map, logo_feat_list, file_name_list, shot_path: str
img_feat = get_embedding(cropped, model, grayscale=grayscale)
logo_feat = get_embedding(candidate_logo, model, grayscale=grayscale)
final_sim = logo_feat.dot(img_feat)
if final_sim >= t_s:
if final_sim >= similarity_threshold:
predicted_brand = top3_brandlist[j]
predicted_domain = top3_domainlist[j]
else:
Expand Down
2 changes: 2 additions & 0 deletions logo_recog.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def pred_rcnn(im, predictor):
if im.shape[-1] == 4:
im = cv2.cvtColor(im, cv2.COLOR_BGRA2BGR)
else:
print(f"Image at path {im} is None")
return None

outputs = predictor(im)
Expand Down Expand Up @@ -63,6 +64,7 @@ def vis(img_path, pred_boxes):

check = cv2.imread(img_path)
if pred_boxes is None or len(pred_boxes) == 0:
print("Pred_boxes is None or the length of pred_boxes is 0")
return check
pred_boxes = pred_boxes.numpy() if not isinstance(pred_boxes, np.ndarray) else pred_boxes

Expand Down
2 changes: 1 addition & 1 deletion phishpedia.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_orig_phishpedia(self, url, screenshot_path, html_path):
file_name_list=self.LOGO_FILES,
url=url,
shot_path=screenshot_path,
ts=self.SIAMESE_THRE,
similarity_threshold=self.SIAMESE_THRE,
topk=1)
logo_match_time = time.time() - start_time

Expand Down

0 comments on commit c557ff8

Please sign in to comment.