train.py
#!/usr/bin/env python
from model_defs import *
from utils import *
import argparse
import pandas as pd
import os
import numpy as np
import csv
from keras.callbacks import Callback
import time
from keras.optimizers import RMSprop, SGD, Adagrad, Adadelta, Adam
from keras.callbacks import ModelCheckpoint, EarlyStopping, TensorBoard


def run(specs):
    '''Make the output directory for this run (work_dir/save_id)'''
    directory = os.path.join(specs['work_dir'], specs['save_id'])
    if not os.path.exists(directory):
        os.makedirs(directory)

    """logging"""
    logfilename = os.path.join(specs['work_dir'], specs['save_id'], "log.txt")
    with open(logfilename, mode='w') as logfile:
        for key in sorted(list(specs.keys())):
            value = specs[key]
            print("{}: {}".format(key, value))
            logfile.write("{}: {}\n".format(key, value))
            logfile.flush()
    '''define the optimiser and compile'''
    model = specs['model']()
    model.compile(loss='categorical_crossentropy',
                  optimizer=specs['optimisation'],
                  metrics=['accuracy'])

    '''callbacks'''
    checkpointer = ModelCheckpoint(
        filepath=os.path.join(specs['work_dir'], specs['save_id'], 'model.hdf5'),
        verbose=1,
        save_best_only=True)
    earlystopping = EarlyStopping(monitor='val_loss', patience=10, verbose=1)
    tboard = TensorBoard(log_dir=os.path.join(specs['work_dir'], specs['save_id'], 'logs_tb'),
                         histogram_freq=0, write_graph=False, write_images=False)

    c = Cifar_npy_gen(batch_size=specs['batch_size'])
    '''save the model architecture'''
    # serialize model to JSON
    model_json = model.to_json()
    with open(os.path.join(specs['work_dir'], specs['save_id'], "model_arch.json"), "w") as json_file:
        json_file.write(model_json)
    # Can be replaced by NPYGenerators so that the last batch doesn't roll over the training set
    '''call fit_generator'''
    model.fit_generator(
        generator=c.train_gen,
        samples_per_epoch=50000,
        nb_epoch=specs['epochs'],
        validation_data=c.test_gen,
        nb_val_samples=10000,
        callbacks=[checkpointer, earlystopping, tboard],
        verbose=1)

    '''save the final model'''
    model.save(os.path.join(specs['work_dir'], specs['save_id'], "model_final.h5"))
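

# NOTE: the fit_generator call above uses the Keras 1 argument names
# (samples_per_epoch, nb_epoch, nb_val_samples). A rough Keras 2 equivalent,
# as a sketch assuming c.train_gen / c.test_gen yield batches of
# specs['batch_size'] samples, would be:
#
#   model.fit_generator(
#       generator=c.train_gen,
#       steps_per_epoch=50000 // specs['batch_size'],
#       epochs=specs['epochs'],
#       validation_data=c.test_gen,
#       validation_steps=10000 // specs['batch_size'],
#       callbacks=[checkpointer, earlystopping, tboard],
#       verbose=1)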


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Training CIFAR example for LeNet architectures')
    # epochs, batch_size and model ID
    parser.add_argument('--epochs', type=str, default='200',
                        help='number of epochs (the program runs through the whole data set)')
    parser.add_argument('--batchsize', type=str, default='64', help='batch size')
    args = parser.parse_args()

    # arguments from the parser
    epochs = int(args.epochs)
    batch_size = int(args.batchsize)
    model = keras_eg_alldrop

    specs = {
        'model': model,
        'epochs': epochs,
        'batch_size': batch_size,
        'save_id': model.__name__,
        'optimisation': 'adam',
        'work_dir': '/u/ambrish/models'
    }

    # run the model
    run(specs)
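
# Example invocation, assuming model_defs.keras_eg_alldrop, the generators
# from utils, and the CIFAR data expected by Cifar_npy_gen are available
# (values shown are just the argparse defaults):
#
#   python train.py --epochs 200 --batchsize 64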