# video_train.py
# Forked from zkailinzhang/tensorflow_video_rnn.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
import tensorflow as tf

from data.video_input import DataInput
from model.bilstm_model import BiLSTM

FLAGS = tf.app.flags.FLAGS
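# NOTE: FLAGS.save_path is read below but never defined in this file, so it
# is presumably registered by the launcher script. A minimal sketch of the
# assumed flag definition (the flag name is taken from the usage below; the
# default and help string are assumptions):
#
#   tf.app.flags.DEFINE_string(
#       "save_path", None, "Directory for checkpoints and summaries.")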

def run_epoch(session, model, eval_op=None, verbose=False):
    """Runs the model on the given data for one epoch.

    Returns the exponentiated mean cost per step (reported as perplexity).
    """
    start_time = time.time()
    costs = 0.0
    iters = 0
    fetches = {
        "cost": model.cost,
        "accuracy": model.accuracy,
    }
    # Only run the training op when one is supplied; during evaluation the
    # graph is executed without updating any weights.
    if eval_op is not None:
        fetches["eval_op"] = eval_op

    for step in range(model.input.epoch_size):
        vals = session.run(fetches)
        cost = vals["cost"]
        accuracy = vals["accuracy"]

        costs += cost
        iters += model.input.num_steps

        if verbose and step % 10 == 9:
            print("accuracy: %.3f" % accuracy)
            print("%.3f -- perplexity: %.3f -- speed: %.0f vps" %
                  (step * 1.0 / model.input.epoch_size,
                   np.exp(costs / iters),
                   iters * model.input.batch_size / (time.time() - start_time)))

    return np.exp(costs / iters)
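
# run_epoch() doubles as an evaluation loop: omitting eval_op runs the graph
# without a training op. A hypothetical sketch, assuming a validation model
# built the same way as `model` in train() below:
#
#   valid_perplexity = run_epoch(session, valid_model, verbose=False)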

def train(config, data):
    """Video training procedure.

    Args:
        config: the configuration dict (keys used here: 'init_scale',
            'learning_rate', 'lr_decay', 'decay_begin_epoch', 'epoch').
        data: data class that implements the dataset.py interface.
    """
    with tf.Graph().as_default():
        initializer = tf.random_uniform_initializer(-config['init_scale'],
                                                    config['init_scale'])

        with tf.name_scope('Train'):
            with tf.variable_scope('Model', reuse=None, initializer=initializer):
                train_input = DataInput(config=config, data=data)
                model = BiLSTM(True, train_input, config)
            tf.summary.scalar("Training Loss", model.cost)
            tf.summary.scalar("Learning Rate", model.lr)

        # The Supervisor handles session creation, checkpointing, and
        # summary writing under FLAGS.save_path.
        sv = tf.train.Supervisor(logdir=FLAGS.save_path)
        with sv.managed_session() as session:
            for i in range(config['epoch']):
                # Decrease the learning rate according to the training epoch.
                lr_decay = config['lr_decay'] ** max(i - config['decay_begin_epoch'], 0.0)
                model.assign_lr(session, config['learning_rate'] * lr_decay)

                print("Epoch: %d Learning rate: %.3f" %
                      (i + 1, session.run(model.lr)))
                train_perplexity = run_epoch(session, model, eval_op=model.train_op,
                                             verbose=True)
                print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))

            if FLAGS.save_path:
                print("Saving model to %s." % FLAGS.save_path)
                sv.saver.save(session, FLAGS.save_path,
                              global_step=sv.global_step)
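
# --- Usage sketch (not part of the original file) ---
# The config keys below are exactly the ones train() reads; their values are
# illustrative assumptions, and DataInput/BiLSTM will likely require further
# keys (e.g. batch_size, num_steps) not shown in this file. `_example_main`
# and its `data` argument are hypothetical.
def _example_main(data):
    """Hypothetical driver showing how train() is expected to be called."""
    config = {
        'init_scale': 0.1,       # range of the uniform weight initializer
        'learning_rate': 1.0,    # base learning rate before decay
        'lr_decay': 0.5,         # multiplicative decay applied per epoch
        'decay_begin_epoch': 4,  # number of epochs before decay kicks in
        'epoch': 13,             # total number of training epochs
    }
    train(config, data)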