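"""train.py: builds the network, loss and optimizer for an experiment,
optionally restores the latest checkpoint, and runs the training loop."""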
import logging
import os
import sys
import time

# Make the project sub-packages importable as top-level modules.
sys.path.append('configuration')
sys.path.append('input')
sys.path.append('train')
sys.path.append('output')

from dataset_manager import DatasetManager
from training_manager import TrainManager
from output_manager import OutputManager
import tensorflow as tf
import numpy as np
def restore_session(sess, saver, models_path):
    """Restore the most recent checkpoint from models_path, creating the
    directory tree on the first run. Returns the checkpoint state, or 0 if
    there is nothing to restore."""
    if not os.path.exists(models_path):
        os.mkdir(models_path)
        os.mkdir(models_path + "/train/")
        os.mkdir(models_path + "/val/")
    ckpt = tf.train.get_checkpoint_state(models_path)
    if ckpt:
        print('Restoring from', ckpt.model_checkpoint_path)
        saver.restore(sess, ckpt.model_checkpoint_path)
    else:
        ckpt = 0
    return ckpt
def save_model(saver, sess, models_path, i):
    saver.save(sess, models_path + '/model.ckpt', global_step=i)
    print('Model saved.')
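# Note: passing global_step=i makes tf.train.Saver suffix the checkpoint files
# as 'model.ckpt-<step>', which get_last_iteration below relies on when resuming.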
def get_last_iteration(ckpt):
    if ckpt:
        return int(ckpt.model_checkpoint_path.split('-')[1])
    else:
        return 1
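# Example (hypothetical filename): a checkpoint saved as 'model.ckpt-12000'
# makes get_last_iteration return 12000, so training resumes from that step.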
def train(gpu_number, experiment_name):
    # Import the experiment module to get the configuration.
    conf_module = __import__(experiment_name)
    config = conf_module.configMain()
    manager = DatasetManager(conf_module.configInput())

    # Get the batch tensors that are used throughout training and validation.
    batch_tensor = manager.train.get_batch_tensor()
    batch_tensor_val = manager.validation.get_batch_tensor()

    config_gpu = tf.ConfigProto()
    config_gpu.gpu_options.visible_device_list = gpu_number
    sess = tf.Session(config=config_gpu)
    manager.start_training_queueing(sess)
    manager.start_validation_queueing(sess)

    training_manager = TrainManager(conf_module.configTrain())
    print('building network')
    training_manager.build_network()
    print('building loss')
    training_manager.build_loss()
    print('building optimization')
    training_manager.build_optimization()

    # Initialize the session variables.
    print('initializing variables')
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(tf.global_variables(), max_to_keep=0)

    # Load a previous model if it is configured to restore.
    print('restoring checkpoint')
    ckpt = 0
    if config.restore:
        ckpt = restore_session(sess, saver, config.models_path)

    # Resume from the last iteration encoded in the checkpoint name.
    initial_iteration = get_last_iteration(ckpt)
    output_manager = OutputManager(conf_module.configOutput(), training_manager,
                                   sess, batch_tensor_val)

    print_flag = True
    for i in range(initial_iteration, config.number_iterations):
        start_time = time.time()
        # Save the model every 6000 iterations.
        if i % 6000 == 0:
            save_model(saver, sess, config.models_path, i)
        training_manager.run_train_step(batch_tensor, sess, i)
        duration = time.time() - start_time
        # With the current trained net, let the output manager print and save
        # all the outputs.
        output_manager.print_outputs(i, duration)
        if print_flag:
            # Dump the first batch (only once) to check whether it is balanced.
            np.savetxt('batch_values.txt',
                       output_manager._logger._last_batch_inputs, newline=' ')
            print_flag = False
            # To count the occurrences of each label in the dump, run e.g.:
            #   print(open('batch_values.txt', 'r').read().count(" 1.0"))
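if __name__ == '__main__':
    # Minimal usage sketch (assumption: this script is launched directly).
    # 'experiment_base' is a hypothetical experiment module on one of the
    # paths appended above, expected to expose configMain, configInput,
    # configTrain and configOutput; gpu_number is the string handed to
    # visible_device_list. The original entry point is not shown in this file.
    train(gpu_number='0', experiment_name='experiment_base')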