-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconfig.py
executable file
·129 lines (107 loc) · 3.54 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import pprint
import socket
import platform
import copy
import pdb
class AttrDict():
    """A nested configuration container with attribute-style access.

    Reading an attribute that does not exist creates, stores and returns a
    fresh empty ``AttrDict`` child, so a hierarchy can be built by plain
    assignment such as ``cfg.TRAIN.SEED = 1``.  After :meth:`freeze` has been
    called that auto-creation is disabled and assigning a name that does not
    already exist raises ``AttributeError``.  This avoids accidental creation
    of new hierarchies (e.g. through a typo).
    """

    # Class-level default; ``freeze`` shadows it with an instance attribute.
    _freezed = False

    def __getattr__(self, name):
        """Create, store and return an empty child for an unknown ``name``.

        Python only calls this when normal attribute lookup fails.

        Raises:
            AttributeError: if this node is frozen, or if ``name`` starts
                with an underscore.
        """
        if self._freezed:
            raise AttributeError(name)
        if name.startswith('_'):
            # Do not mess with internals. Otherwise copy/pickle will fail
            raise AttributeError(name)
        ret = AttrDict()
        setattr(self, name, ret)
        return ret

    def __setattr__(self, name, value):
        """Set ``name``; once frozen, only already-existing names may be set.

        Raises:
            AttributeError: if frozen and ``name`` is not already an instance
                attribute of this node.
        """
        if self._freezed and name not in self.__dict__:
            raise AttributeError(
                "Config was freezed! Unknown config: {}".format(name))
        super().__setattr__(name, value)

    def __str__(self):
        """Pretty-print the config as a nested dict."""
        return pprint.pformat(self.to_dict(), indent=1, width=100, compact=True)

    __repr__ = __str__

    def to_dict(self):
        """Convert to a nested dict, skipping underscore-prefixed attributes."""
        return {k: v.to_dict() if isinstance(v, AttrDict) else v
                for k, v in self.__dict__.items() if not k.startswith('_')}

    def update_args(self, args):
        """Update leaf values from command-line style ``KEY=VALUE`` strings.

        E.g., ``args = ['TRAIN.BATCH_SIZE=1', 'TRAIN.INIT_G_LR=0.1']``.

        If the current value of the target key is a ``str``, the new value is
        stored verbatim.  Otherwise the text is passed through ``eval`` so
        numbers, booleans, lists and tuples can be given on the command line.

        Args:
            args: a list or tuple of ``'A.B.C=value'`` strings.

        Raises:
            TypeError: if ``args`` is not a list or tuple.
            ValueError: if an entry contains no ``=``.
            AttributeError: if a dotted path component does not name an
                existing sub-config, or (when frozen) the leaf key is unknown.
        """
        # Explicit checks instead of ``assert`` so validation survives ``-O``.
        if not isinstance(args, (tuple, list)):
            raise TypeError(
                "args must be a list or tuple, got {}".format(
                    type(args).__name__))
        for cfg in args:
            if '=' not in cfg:
                raise ValueError("Expected KEY=VALUE, got: {}".format(cfg))
            keys, v = cfg.split('=', maxsplit=1)
            *parents, key = keys.split('.')
            dic = self
            for k in parents:
                # Look in the instance dict rather than dir(): dir() also lists
                # methods such as ``freeze``, which are not config sections.
                child = vars(dic).get(k)
                if not isinstance(child, AttrDict):
                    raise AttributeError("Unknown config key: {}".format(keys))
                dic = child
            oldv = getattr(dic, key)
            if not isinstance(oldv, str):
                # NOTE(security): eval executes arbitrary code; only pass
                # trusted command-line input to update_args.
                v = eval(v)
            setattr(dic, key, v)

    def freeze(self, freezed=True):
        """Freeze (or, with ``freezed=False``, unfreeze) this node and all children."""
        self._freezed = freezed
        for v in self.__dict__.values():
            if isinstance(v, AttrDict):
                v.freeze(freezed)

    # Avoid silent bugs: comparing configs is deliberately unsupported, so an
    # accidental ``cfg == other`` fails loudly instead of falling back to the
    # default identity comparison.  (Defining __eq__ without __hash__ also
    # makes instances unhashable.)
    def __eq__(self, _):
        raise NotImplementedError()

    def __ne__(self, _):
        raise NotImplementedError()
# Root configuration object.  ``_C`` is only a shorter name for the same
# object, which keeps the many assignments below compact.
config = AttrDict()
_C = config

# ---- training --------------------------------------------------------------
# Temporary per-section aliases; they are deleted at the end of this block so
# they do not remain in the module namespace.
_train = _C.TRAIN
_train.DATA_ROOT = 'data/sample_train'
_train.SEED = 100
_train.MAX_EPOCH = 1000
_train.LOSS_WEIGHTS = {
    'scale_flow': 10,
    'scale_occ': 1,
    'pix': 10,
    'vgg_feature': 10,
    'vgg_style': 10,
    'feat_match': 10,
    'GAN_gen': 1,
    'GAN_discrim': 1,
}
_train.STEPS_PER_EPOCH = 10000
_train.SAVE_PER_K_EPOCHS = 1
_train.SUMMARY_PERIOD = 1000
_train.INIT_G_LR = 2e-4
_train.INIT_D_LR = 2e-4
_train.G_PERIOD = 1
_train.D_PERIOD = 1
_train.PARAM_UPDATE_PERIOD = 1000
_train.WEIGHT_DECAY = 0
_train.IMG_SIZE = (256, 256)
_train.BATCH_SIZE = 8
_train.COLOR_JITTER = True
_train.PIX_LOSS = True  # pixelwise loss
_train.GAN_LOSS = 'LS'  # one of: Hinge, LS, None
_train.SCALE_FLOW_LOSS = False

_vgg = _train.VGG_LOSS
_vgg.NUM_PYRAMIDS = 4
_vgg.FEATURE_LAYERS = [0, 1, 2, 3, 4]
_vgg.STYLE_LAYERS = []

# ---- model -----------------------------------------------------------------
_model = _C.MODEL
_model.GUIDANCE = 'neural_codes'  # one of: neural_codes, geom_disp, both
_model.NUM_LEVELS = 6
_model.OCCLUSION_AWARE = False
_model.USE_ALPHA_BLEND = False
_model.SHRINK_RATIO = 0.5
_model.SPECTRAL_NORM = True
_model.NORM_METHOD = 'BN'
_model.ENFORCE_BIAS = True
_model.D_LOGITS_LAYERS = [-1]
_model.D_SCALES = ['x1']
_model.D_DROPOUT = False

# ---- test ------------------------------------------------------------------
_test = _C.TEST
_test.SEED = 100
_test.IMG_SIZE = (256, 256)
_test.BATCH_SIZE = 1

# Lock the schema: from here on only existing keys may be reassigned.
_C.freeze()
del _train, _vgg, _model, _test