-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathtrain_vgg19_distibuted.py
89 lines (77 loc) · 3.82 KB
/
train_vgg19_distibuted.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
Simple tester for distributed TensorFlow training.
"""
from os import environ
import numpy as np
import tensorflow as tf
from Networking import ClusterGen
from dataSetGenerator import append
from dataSetGenerator import loadClasses
from dataSetGenerator import picShow
from vgg19 import vgg19_trainable as vgg19
# Training images, their labels and the class-name list, loaded from disk.
batch = np.load("DataSets/UCMerced_LandUse_dataTrain.npy")
labels = np.load("DataSets/UCMerced_LandUse_labelsTrain.npy")
classes = loadClasses("DataSets/UCMerced_LandUse.txt")
classes_num = len(classes)
rib = batch.shape[1]  # picture side length, read from axis 1 of the image array

# Host names taking part in the cluster, grouped by role.
workers = ['DESKTOP-07HFBQN', 'FOUZI-PC']
pss = ['DELL-MINI']

# Resolve this machine's job name and task index from its host name.
# The variable is read once with .get() so that an unset COMPUTERNAME is
# reported the same way as an unlisted host, instead of as a bare KeyError.
host = environ.get('COMPUTERNAME')
if host in workers:
    job, index = 'worker', workers.index(host)
elif host in pss:
    job, index = 'ps', pss.index(host)
else:
    # Stop here: creating the server below with job_name=None and
    # task_index=None would fail with a far less helpful error.
    raise SystemExit(
        "error JOB: host {!r} is listed neither in workers nor in pss".format(host))

cluster = tf.train.ClusterSpec(ClusterGen(workers, pss))
server = tf.train.Server(cluster, job_name=job, task_index=index)
# Build the shared model graph; every op created here is pinned to a ps task.
#
# `index` is this process's task index inside its OWN job.  For a worker it
# can be larger than the last ps task index (two workers, one ps), which would
# name a device that is not part of the cluster, so it is wrapped into the
# valid ps range.
ps_task = index % len(pss)

# Alternative placement, left disabled:
# with tf.device(tf.train.replica_device_setter(
#         worker_device="/job:ps/task:"+str(index),
#         cluster=cluster)):
with tf.device("/job:ps/task:{}".format(ps_task)):
    # Inputs fed at run time: square RGB images, one-hot labels, train flag.
    images = tf.placeholder(tf.float32, [None, rib, rib, 3])
    true_out = tf.placeholder(tf.float32, [None, classes_num])
    train_mode = tf.placeholder(tf.bool)

    # Prefer previously saved weights; fall back to building the network
    # without a weights file when they cannot be loaded.  `except Exception`
    # keeps that fallback but no longer swallows KeyboardInterrupt/SystemExit.
    weights_path = 'Weights/VGG19_{}C.npy'.format(classes_num)
    try:
        vgg = vgg19.Vgg19(weights_path, classes_num)
    except Exception as err:
        print("Could not load {} ({}); building Vgg19 without a weights file".format(
            weights_path, err))
        vgg = vgg19.Vgg19(None, classes_num)
    vgg.build(images, train_mode)

    # The global step is advanced explicitly through `inc_global_step`.
    global_step = tf.train.get_or_create_global_step()
    inc_global_step = tf.assign(global_step, global_step + 1)
# Role dispatch: workers run the training loop, parameter servers just serve.
if job == 'worker':
    # Alternative placement, left disabled:
    # with tf.device(tf.train.replica_device_setter(
    #         worker_device="/job:worker/task:"+str(index),
    #         cluster=cluster)):
    with tf.device("/job:worker/task:{}".format(index)):
        # Sum-of-squares loss between the network output and one-hot labels.
        cost = tf.reduce_sum((vgg.prob - true_out) ** 2)
        train = tf.train.GradientDescentOptimizer(0.0001).minimize(cost)
        # Accuracy compares the predicted class with the labelled class for
        # each sample, so both argmax calls run along the class axis (1).
        # Taking argmax of the scalar `cost` cannot yield a class index.
        predicted_class = tf.argmax(vgg.prob, 1)
        true_class = tf.argmax(true_out, 1)
        correct_prediction = tf.equal(predicted_class, true_class)
        acc = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    batch_size = 5
    batche_num = len(batch)  # number of training samples
    # tf.train.global_step(sess, tf.Variable(10, trainable=False, name='global_step'))

    # global_step is bumped once per pass over the data (see inc_global_step
    # below), so last_step=2 bounds the number of passes, not mini-batches.
    hooks = [tf.train.StopAtStepHook(last_step=2)]
    with tf.train.MonitoredTrainingSession(master=server.target,
                                           is_chief=(index == 0),
                                           hooks=hooks) as sess:
        while not sess.should_stop():
            # NOTE(review): if this increment reaches last_step the hook
            # requests a stop, and later sess.run calls in this pass are then
            # expected to raise RuntimeError -- confirm and restructure.
            epoch = sess.run(inc_global_step)
            costs = []
            accs = []
            print("******************* ", epoch, " *******************")

            # Visit the samples in a fresh random order on every pass; a
            # trailing partial mini-batch is dropped by the floor division.
            indice = np.random.permutation(batche_num)
            for i in range(batche_num // batch_size):
                print('step 1')
                min_batch = indice[i * batch_size:(i + 1) * batch_size]
                print('step 2')
                cur_cost, _, cur_acc = sess.run(
                    [cost, train, acc],
                    feed_dict={images: batch[min_batch],
                               true_out: labels[min_batch],
                               train_mode: True})
                print("Iteration %d loss:\n%s" % (i, cur_cost))
                costs.append(str(cur_cost) + '\n')
                accs.append(str(cur_acc) + '\n')

            # with tf.device(tf.train.replica_device_setter(
            #         worker_device="/job:ps/task:"+str(index),
            #         cluster=cluster)):
            # Persist this pass's per-batch metrics and the current weights.
            append(costs, 'Data/cost19_{}C_D'.format(classes_num))
            append(accs, 'Data/acc19_{}C_D'.format(classes_num))
            vgg.save_npy(sess, 'Weights/VGG19_{}C_D.npy'.format(classes_num))

        # test classification
        # NOTE(review): this runs after should_stop() became true; a monitored
        # session presumably rejects further run() calls at that point, so this
        # check may need its own session -- confirm.
        prob = sess.run(vgg.prob, feed_dict={images: batch[:10], train_mode: False})
        picShow(batch[:10], labels[:10], classes, None, prob)
elif job == 'ps':
    # Parameter servers only host variables; block here serving requests.
    server.join()
else:
    print("error JOB")