import os
import pickle
import random

import numpy as np
import torch


def assign_learning_rate(param_group, new_lr):
    # Set the learning rate of a single optimizer param group in place.
    param_group["lr"] = new_lr


def _warmup_lr(base_lr, warmup_length, step):
    # Linearly ramp the learning rate from base_lr / warmup_length up to base_lr.
    return base_lr * (step + 1) / warmup_length


def cosine_lr(optimizer, base_lrs, warmup_length, steps):
    # Return a function that, given the current step, applies linear warmup
    # followed by cosine decay to every param group of `optimizer`.
    if not isinstance(base_lrs, list):
        base_lrs = [base_lrs for _ in optimizer.param_groups]
    assert len(base_lrs) == len(optimizer.param_groups)

    def _lr_adjuster(step):
        for param_group, base_lr in zip(optimizer.param_groups, base_lrs):
            if step < warmup_length:
                lr = _warmup_lr(base_lr, warmup_length, step)
            else:
                # Cosine decay from base_lr to 0 over the post-warmup steps.
                e = step - warmup_length
                es = steps - warmup_length
                lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
            assign_learning_rate(param_group, lr)

    return _lr_adjuster
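
# Example usage (a minimal sketch; the optimizer and step counts below are
# hypothetical):
#
#   scheduler = cosine_lr(optimizer, base_lrs=1e-3, warmup_length=500, steps=10_000)
#   for step in range(10_000):
#       scheduler(step)  # set this step's learning rate on every param group
#       ...              # forward / backward
#       optimizer.step()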


def accuracy(output, target, topk=(1,)):
    # Count how many top-k predictions match the target, for each k in `topk`.
    # Returns raw counts (as floats), not percentages.
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [
        float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
        for k in topk
    ]
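
# Example usage (a sketch; `output` is a [batch, classes] logits tensor and
# `target` a [batch] tensor of class indices, both hypothetical):
#
#   top1, top5 = accuracy(output, target, topk=(1, 5))
#   top1_frac = top1 / target.size(0)  # convert the raw count to a fraction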


def torch_save(classifier, save_path):
    # Save the classifier's state dict, creating parent directories as needed.
    if os.path.dirname(save_path) != "":
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save({"state_dict": classifier.state_dict()}, save_path)
    print("Checkpoint saved to", save_path)
    # Alternative: pickle the whole object instead of its state dict.
    # with open(save_path, 'wb') as f:
    #     pickle.dump(classifier.cpu(), f)


def torch_load(classifier, save_path, device=None):
    # Load a state dict saved by torch_save into `classifier`. Loading is
    # non-strict so partially matching checkpoints still load; any mismatched
    # keys are reported.
    checkpoint = torch.load(save_path)
    missing_keys, unexpected_keys = classifier.load_state_dict(
        checkpoint["state_dict"], strict=False
    )
    if len(missing_keys) > 0 or len(unexpected_keys) > 0:
        print("Missing keys:", missing_keys)
        print("Unexpected keys:", unexpected_keys)
    print("Checkpoint loaded from", save_path)
    # Alternative: unpickle the whole object instead of loading a state dict.
    # with open(save_path, 'rb') as f:
    #     classifier = pickle.load(f)
    if device is not None:
        classifier = classifier.to(device)
    return classifier
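
# Example round trip (a sketch; `model` is any torch.nn.Module and the path is
# hypothetical):
#
#   torch_save(model, "checkpoints/classifier.pt")
#   model = torch_load(model, "checkpoints/classifier.pt", device="cuda")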


def get_logits(inputs, classifier):
    # Move the classifier to the inputs' device (if it supports .to) and
    # return its raw logits.
    assert callable(classifier)
    if hasattr(classifier, "to"):
        classifier = classifier.to(inputs.device)
    return classifier(inputs)


def get_probs(inputs, classifier):
    # sklearn-style classifiers expose predict_proba; use it directly.
    if hasattr(classifier, "predict_proba"):
        probs = classifier.predict_proba(inputs.detach().cpu().numpy())
        return torch.from_numpy(probs)
    # Otherwise assume a torch module and softmax its logits.
    logits = get_logits(inputs, classifier)
    return logits.softmax(dim=1)
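
# Example usage (a sketch; works with a torch module or any object exposing an
# sklearn-style predict_proba, both hypothetical here):
#
#   probs = get_probs(images, classifier)  # [batch, classes], rows sum to 1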


class LabelSmoothing(torch.nn.Module):
    # Cross-entropy with uniform label smoothing: the target class gets weight
    # 1 - smoothing and the remaining mass is spread evenly over all classes.
    def __init__(self, smoothing=0.0):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()
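
# Example usage (a sketch; the smoothing value is illustrative):
#
#   criterion = LabelSmoothing(smoothing=0.1)
#   loss = criterion(logits, targets)  # logits: [batch, classes], targets: [batch]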


def seed_all(seed):
    # Seed every RNG in use (torch, CUDA, numpy, random) and make cuDNN
    # deterministic for reproducibility.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def num_parameters(model):
    # Count trainable (requires_grad) parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
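
# Example usage (a sketch):
#
#   print(f"Trainable parameters: {num_parameters(model):,}")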