# utils.py (from hhc1997/MSCN)
import json
import math
import random
import shutil

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.optim.adam import Adam


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name="", fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
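
# Illustrative sketch (not part of the original module): typical use of
# AverageMeter in a training loop, updating with a per-batch loss weighted
# by the batch size.
def _demo_average_meter():
    losses = AverageMeter("Loss", ":.4e")
    for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
        losses.update(batch_loss, n=batch_size)
    print(losses)  # "Loss 5.0000e-01 (7.4000e-01)"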

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"

def save_config(opt, file_path):
    with open(file_path, "w") as f:
        json.dump(opt.__dict__, f, indent=2)
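
# Illustrative sketch (not part of the original module): save_config accepts any
# object whose settings live in __dict__, e.g. the Namespace that argparse
# returns, and dumps it as pretty-printed JSON.
def _demo_save_config():
    import argparse
    opt = argparse.Namespace(lr=1e-3, batch_size=128)
    save_config(opt, "config.json")  # writes {"lr": 0.001, "batch_size": 128}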

def save_checkpoint(state, is_best, filename="checkpoint.pth.tar", prefix=""):
    tries = 15
    error = None
    # Retry to deal with unstable I/O; usually not necessary.
    while tries:
        try:
            torch.save(state, prefix + filename)
            if is_best:
                shutil.copyfile(prefix + filename, prefix + "model_best.pth.tar")
        except IOError as e:
            error = e
            tries -= 1
        else:
            break
        print("saving model {} failed, {} tries remaining".format(filename, tries))
        if not tries:
            raise error
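
# Illustrative sketch (not part of the original module): a typical checkpoint
# dict. "model", "optimizer", and the "runs/" prefix are hypothetical; is_best
# would come from comparing the current validation score to the best so far.
def _demo_save_checkpoint(model, optimizer, epoch, is_best):
    state = {
        "epoch": epoch,
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(state, is_best, filename="checkpoint.pth.tar", prefix="runs/")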

def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def init_seeds(seed=0, cuda_deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Speed-reproducibility tradeoff: https://pytorch.org/docs/stable/notes/randomness.html
    if cuda_deterministic:  # slower, more reproducible
        cudnn.deterministic = True
        cudnn.benchmark = False
    else:  # faster, less reproducible
        cudnn.deterministic = False
        cudnn.benchmark = True
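
# Illustrative sketch (not part of the original module): seed everything once at
# startup; cuda_deterministic=True trades cuDNN autotuning speed for
# run-to-run reproducibility.
def _demo_seeding():
    init_seeds(seed=42, cuda_deterministic=True)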

class MetaAdam(Adam):
    """Adam subclass that rewrites the wrapped network's parameters with a
    differentiable update, as used in meta-learning inner loops."""

    def __init__(self, net, *args, **kwargs):
        super(MetaAdam, self).__init__(*args, **kwargs)
        self.net = net

    def set_parameter(self, current_module, name, parameters):
        # Walk the module tree along the dotted parameter name and overwrite
        # the leaf entry, so the new (possibly non-leaf) tensor keeps its graph.
        if '.' in name:
            name_split = name.split('.')
            module_name = name_split[0]
            rest_name = '.'.join(name_split[1:])
            for children_name, children in current_module.named_children():
                if module_name == children_name:
                    self.set_parameter(children, rest_name, parameters)
                    break
        else:
            current_module._parameters[name] = parameters

    def meta_step(self, grads):
        # Note: despite subclassing Adam, this applies a plain SGD-style step
        # (theta <- theta - lr * grad) that stays differentiable w.r.t. grads.
        group = self.param_groups[0]
        lr = group['lr']
        for (name, parameter), grad in zip(self.net.named_parameters(), grads):
            parameter.detach_()
            self.set_parameter(self.net, name, parameter.add(grad, alpha=-lr))
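
# Illustrative sketch (not part of the original module) of the usual
# meta-learning inner step: compute gradients of a meta-objective with
# create_graph=True so the update stays differentiable, then let meta_step
# rewrite the network's parameters in place. "net" and "meta_loss" are
# hypothetical.
def _demo_meta_step(net, meta_loss):
    optimizer = MetaAdam(net, net.parameters(), lr=1e-3)
    grads = torch.autograd.grad(meta_loss, net.parameters(), create_graph=True)
    optimizer.meta_step(grads)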