forked from xavysp/DexiNed
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_model.py
134 lines (117 loc) · 6.29 KB
/
run_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""DexiNed entry-point script.

Runs DexiNed (Dense Extreme Inception Network for Edge Detection) in either
training or testing mode. Every run-time option is declared in
``config_model()``; review those defaults before training or testing the
model.
"""
__author__ = "Xavier Soria Poma, CVC-UAB"
# NOTE(review): the value below looks like an e-mail-obfuscation placeholder
# left by a web scrape -- confirm the real addresses.
__email__ = "[email protected] / [email protected]"
__homepage__ = "www.cvc.uab.cat/people/xsoria"
__credits__ = ["DexiNed"]
__copyright__ = "MIT License [see LICENSE for details]"  # commented-out alternative: "Copyright 2019, CIMI"
import sys
import argparse
import tensorflow as tf
import utls.dataset_manager as dm
from train import m_trainer
from test import m_tester
import platform
def _str2bool(value):
    """Convert a command-line flag value to ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string -- including ``"False"``. This converter accepts
    the usual textual spellings instead.

    Args:
        value: raw string from the command line (an existing ``bool`` is
            returned unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling
            (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


def config_model(argv=None):
    """Declare and parse every DexiNed run-time option.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (``None``, the default, reads the real command
            line). Useful for tests and programmatic use.

    Returns:
        argparse.Namespace holding all options below. Note that
        ``--gpu-limit`` is exposed as ``args.gpu_limit``.
    """
    in_linux = platform.system() == "Linux"
    base_dir = "/opt/dataset/" if in_linux else "../../dataset/"
    parser = argparse.ArgumentParser(description='Basic details to run HED')
    # dataset config
    parser.add_argument('--train_dataset', default='BIPED', choices=['BIPED', 'BSDS'])
    # argparse does not validate defaults against `choices`, so the default
    # 'CLASSIC' must be listed explicitly to be selectable from the CLI.
    parser.add_argument('--test_dataset', default='CLASSIC',
                        choices=['BIPED', 'BSDS', 'MULTICUE', 'NYUD', 'PASCAL', 'CID', 'DCD', 'CLASSIC'])
    parser.add_argument('--dataset_dir', default=base_dir, type=str)  # default:'/opt/dataset/'
    parser.add_argument('--dataset_augmented', default=True, type=_str2bool)
    parser.add_argument('--train_list', default='train_rgb.lst', type=str)  # BSDS train_pair.lst, SSMIHD train_rgb_pair.lst/train_rgbn_pair.lst
    parser.add_argument('--test_list', default='test_rgb.lst', type=str)  # for NYUD&BSDS:test_pair.lst, biped msi_test.lst/test_rgb.lst
    parser.add_argument('--trained_model_dir', default='train', type=str)  # 'trainV2_RN'
    # SSMIHD_RGBN msi_valid_list.txt and msi_test_list.txt is for unified test
    parser.add_argument('--use_nir', default=False, type=_str2bool)
    parser.add_argument('--use_dataset', default=False, type=_str2bool)  # test: dataset=True single image=FALSE
    # model config
    parser.add_argument('--model_state', default='test', choices=['train', 'test', 'None'])  # always in None
    parser.add_argument('--model_name', default='DXN', choices=['DXN', 'XCP', 'None'])
    parser.add_argument('--use_v1', default=False, type=_str2bool)
    parser.add_argument('--model_purpose', default='edges', choices=['edges', 'restoration', 'None'])
    parser.add_argument('--batch_size_train', default=8, type=int)
    parser.add_argument('--batch_size_val', default=8, type=int)
    parser.add_argument('--batch_size_test', default=1, type=int)
    parser.add_argument('--checkpoint_dir', default='checkpoints', type=str)
    parser.add_argument('--logs_dir', default='logs', type=str)
    parser.add_argument('--learning_rate', default=1e-4, type=float)  # 1e-4=0.0001
    parser.add_argument('--lr_scheduler', default=None, choices=[None, 'asce', 'desc'])  # check here
    parser.add_argument('--learning_rate_decay', default=0.1, type=float)
    parser.add_argument('--weight_decay', default=0.0002, type=float)
    parser.add_argument('--model_weights_path', default='vgg16_.npy')
    parser.add_argument('--train_split', default=0.9, type=float)  # default 0.8
    parser.add_argument('--max_iterations', default=180000, type=int)  # 100000
    parser.add_argument('--learning_decay_interval', default=25000, type=int)  # 25000
    parser.add_argument('--loss_weights', default=1.0, type=float)
    parser.add_argument('--save_interval', default=20000, type=int)  # 50000
    parser.add_argument('--val_interval', default=30, type=int)
    parser.add_argument('--use_subpixel', default=None, type=_str2bool)  # None=upsampling with transp conv
    parser.add_argument('--deep_supervision', default=True, type=_str2bool)
    parser.add_argument('--target_regression', default=True, type=_str2bool)  # true
    # List-valued options need `nargs`; with a bare `type=float`/`type=int`
    # a value given on the CLI would replace the default list with a scalar.
    parser.add_argument('--mean_pixel_values', default=[103.939, 116.779, 123.68, 137.86],
                        type=float, nargs='+')  # [103.939,116.779,123.68]
    # for Nir pixels mean [103.939,116.779,123.68, 137.86]
    parser.add_argument('--channel_swap', default=[2, 1, 0], type=int, nargs='+')
    parser.add_argument('--gpu-limit', default=1.0, type=float)  # read back as args.gpu_limit
    parser.add_argument('--use_trained_model', default=True, type=_str2bool)  # for vgg16
    parser.add_argument('--use_previous_trained', default=False, type=_str2bool)  # for training
    # image configuration
    parser.add_argument('--image_width', default=512, type=int)  # 480 NYUD=560 BIPED=1280 default 400 other 448
    parser.add_argument('--image_height', default=512, type=int)  # 480 for NYUD 425 BIPED=720 default 400
    parser.add_argument('--n_channels', default=3, type=int)  # last ssmihd_xcp trained in 512
    # test config
    parser.add_argument('--test_snapshot', default=149999, type=int)  # BIPED: 149736 BSDS:101179
    # DexiNedv1=149736,DexiNedv2=149999
    parser.add_argument('--testing_threshold', default=0.0, type=float)
    parser.add_argument('--base_dir_results', default='results/edges', type=str)  # default: '/opt/results/edges'
    # single image default=None
    return parser.parse_args(argv)
def get_session(gpu_fraction, num_threads=0):
    """Create the TensorFlow 1.x session used for training or testing.

    Args:
        gpu_fraction: value for ``per_process_gpu_memory_fraction`` (the
            ``--gpu-limit`` CLI option). ``None``, ``0`` and ``1.0`` (the CLI
            default) leave TensorFlow's default GPU options untouched; any
            other value is passed through to ``tf.GPUOptions``.
        num_threads: if non-zero, sets ``intra_op_parallelism_threads``;
            ``0`` (default) leaves the choice to TensorFlow.

    Returns:
        A new ``tf.Session``; the caller is responsible for closing it.
    """
    config_kwargs = {}
    # The original code built GPUOptions but then created the session with a
    # bare ConfigProto(), so --gpu-limit never took effect. Only override the
    # defaults when a real limit was requested.
    if gpu_fraction and gpu_fraction != 1.0:
        config_kwargs['gpu_options'] = tf.GPUOptions(
            per_process_gpu_memory_fraction=gpu_fraction)
    if num_threads:
        config_kwargs['intra_op_parallelism_threads'] = num_threads
    return tf.Session(config=tf.ConfigProto(**config_kwargs))
def main(args):
    """Validate the configuration, then run DexiNed training or testing.

    Args:
        args: the ``argparse.Namespace`` produced by ``config_model()``.

    The process exits via ``sys.exit()`` (after printing a message) when the
    dataset is not augmented, when ``model_state`` is neither ``'train'`` nor
    ``'test'``, or when testing on BIPED with ``image_width <= 700``. All of
    these checks run before any TensorFlow session is created.
    """
    if not args.dataset_augmented:
        # Only for BIPED dataset
        # dm.augment_data(args)
        print("Please visit the webpage of BIPED in:")
        print("https://xavysp.github.io/MBIPED/")
        print("and run the code")
        sys.exit()
    if args.model_state not in ('train', 'test'):
        print("The model state is None, so it will exit...")
        sys.exit()
    # Validate test settings *before* opening the session, so that an early
    # exit does not leave a session (and its GPU memory) open.
    if (args.model_state == 'test' and args.test_dataset == "BIPED"
            and args.image_width <= 700):
        print(' image size is not set in non augmented data')
        sys.exit()

    sess = get_session(args.gpu_limit)
    try:
        if args.model_state == 'train':
            trainer = m_trainer(args)
            trainer.setup()
            trainer.run(sess)
        else:
            tester = m_tester(args)
            tester.setup(sess)
            tester.run(sess)
    finally:
        # Release the session even if setup()/run() raises.
        sess.close()
if __name__ == "__main__":
    # Script entry point: parse the command line, then dispatch to main().
    args = config_model()
    main(args)