-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
123 lines (99 loc) · 3.6 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
from torch import nn
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2
import math
import os
from model import get_model
# Characters a captcha position can take; the index in this list is the class id.
captcha_list = [symbol for symbol in '0123456789abcdefghijklmnopqrstuvwxyz_']
# Number of character positions per captcha.
captcha_length = 6

# Image -> tensor conversion pipeline.
loader = transforms.Compose([transforms.ToTensor()])
# Tensor -> PIL image conversion.
unloader = transforms.ToPILImage()
def get_accuracy(device, data_loader, test_model_name, attack_dataset_name, attack_model_name):
    """Evaluate a model on ``data_loader`` and return its mean batch accuracy.

    The network is built with ``get_model(test_model_name)`` and, when the
    file ``./model/<test_model_name>`` exists, its weights are restored from
    the ``'model_state_dict'`` entry of that checkpoint.

    Args:
        device: Device the network and every batch are moved to.
        data_loader: Iterable yielding ``(inputs, labels)`` batches.
        test_model_name: Name of the model under test; also the checkpoint
            file name inside ``./model/``.
        attack_dataset_name: Only used in the printed report.
        attack_model_name: Only used in the printed report.

    Returns:
        The unweighted mean of the per-batch accuracies returned by
        ``calculat_acc``, or ``0.0`` when ``data_loader`` yields no batches.
    """
    import warnings

    net = get_model(test_model_name)
    net.to(device)

    checkpoint_path = './model/' + test_model_name  # The path of pretrained models
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(checkpoint['model_state_dict'])
    else:
        # Keep going (as before), but make it visible that the reported
        # accuracy belongs to a network without restored weights.
        warnings.warn(
            'Checkpoint {} not found; evaluating {} without pretrained weights'.format(
                checkpoint_path, test_model_name))
    net.eval()

    acc = 0
    count = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            acc += calculat_acc(outputs, labels)
            count += 1

    if count == 0:
        # An empty loader used to raise ZeroDivisionError here.
        warnings.warn('data_loader yielded no batches; reporting an accuracy of 0.0')
        output = 0.0
    else:
        output = acc / count
    print('Accuracy for {} to {}_{} is: {}'.format(test_model_name, attack_model_name, attack_dataset_name, output))
    return output
# calculate RMSD
def get_RMSD(attack_model_name, attack_dataset_name):
    """Return the mean RMSD between original images and their attacked versions.

    Originals are read from ``./data/test/`` and attacked images from
    ``./data/attack_image/<attack_model_name>/<attack_dataset_name>/``.
    Files whose name contains ``'.png'`` and that exist under the same name
    in both directories are compared; every other entry is ignored.

    For each pair the original is resized to 140x44 (width x height), the
    squared difference is summed over the whole array, divided by the number
    of pixels (height * width) and square-rooted.
    NOTE(review): the attacked image is not resized, so it is assumed to
    already have the same shape as the resized original - confirm.

    Args:
        attack_model_name: Sub-directory of ``./data/attack_image/``.
        attack_dataset_name: Sub-directory below ``attack_model_name``.

    Returns:
        The average per-pair RMSD over all compared pairs, or ``0.0`` when
        no pair could be compared.
    """
    ori_path = './data/test/'
    attack_path = './data/attack_image/' + attack_model_name + '/' + attack_dataset_name + '/'

    # os.listdir returns entries in arbitrary order, so pair files by name
    # with a set lookup instead of walking both listings in lockstep.
    attack_names = set(os.listdir(attack_path))

    accumulated_err = 0.0
    compared = 0
    # Sorted only to make the summation order deterministic.
    for name in sorted(os.listdir(ori_path)):
        if '.png' not in name or name not in attack_names:
            continue
        ori_img = cv2.imread(ori_path + name)
        ori_img = cv2.resize(ori_img, (140, 44))
        att_img = cv2.imread(attack_path + name)

        diff = ori_img.astype("float") - att_img.astype("float")
        err = np.sum(diff ** 2)
        err /= float(ori_img.shape[0] * ori_img.shape[1])
        accumulated_err += math.sqrt(err)
        compared += 1

    if compared == 0:
        # Nothing to average; avoids a ZeroDivisionError on empty directories.
        return 0.0
    # Average over the pairs actually compared, not over every directory entry.
    return accumulated_err / compared
def calculat_acc(output, target):
    """Return the percentage of captchas in a batch that are fully correct.

    Both tensors are reshaped to rows of ``len(captcha_list)`` scores, one
    row per character position, and reduced to class indices with argmax
    (the prediction goes through a softmax first). The indices are then
    regrouped into rows of ``captcha_length`` characters; a captcha counts
    as correct only when every character matches.

    Args:
        output: Model scores for the batch.
        target: Ground-truth scores for the batch, same layout as ``output``.

    Returns:
        Share of exactly matching captchas, scaled to the range 0-100.
    """
    num_classes = len(captcha_list)

    true_chars = torch.argmax(target.view(-1, num_classes), dim=1)
    probabilities = nn.functional.softmax(output.view(-1, num_classes), dim=1)
    pred_chars = torch.argmax(probabilities, dim=1)

    # One row per captcha, one column per character position.
    pred_words = pred_chars.view(-1, captcha_length)
    true_words = true_chars.view(-1, captcha_length)

    correct = sum(
        1 for truth, guess in zip(true_words, pred_words) if torch.equal(truth, guess)
    )
    return correct / pred_words.size()[0] * 100
def tensor_to_PIL(tensor):
    """Convert a tensor into a PIL image.

    Works on a CPU copy of ``tensor`` so the argument is left unmodified,
    drops dimension 0 when it has size 1, prints the resulting shape and
    passes the result through the module-level ``unloader``.
    """
    squeezed = tensor.cpu().clone().squeeze(0)
    print(squeezed.shape)
    return unloader(squeezed)
def imshow(tensor, title=None):
    """Display a tensor as an image with matplotlib.

    Args:
        tensor: Image tensor; a CPU copy is taken and dimension 0 is dropped
            when it has size 1 before conversion with ``unloader``.
        title: Optional figure title; omitted when ``None``.
    """
    squeezed = tensor.cpu().clone().squeeze(0)
    print(squeezed.shape)
    plt.imshow(unloader(squeezed))
    if title is not None:
        plt.title(title)
    # Brief pause so that the plot gets a chance to refresh.
    plt.pause(0.001)
def make_dir(path):
    """Create directory ``path`` (including parents) if it does not exist yet.

    Does nothing when ``path`` already exists. The creation is attempted
    directly rather than guarded by an existence check, so a directory that
    appears between check and creation (e.g. made by a concurrent process)
    no longer causes a failure.

    Args:
        path: Directory to create.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present - same outcome as the former "exists" check.
        pass
# def save_history(history, model_name, save_path):
# df = pd.DataFrame.from_dict(history)
# df.to_csv(save_path + "_loss.csv", header=True)
#
# plt.figure(figsize=(6,4))
# plt.plot(df["epoch"], df["train_loss"])
# plt.xlabel("Number of Epochs")
# plt.ylabel("Training Loss")
# plt.title("Adversarial Attack Trained on %s" % model_name)
# plt.savefig(save_path + "_loss.png")
# plt.close()