Skip to content

Commit

Permalink
option to retrain old model
Browse files Browse the repository at this point in the history
  • Loading branch information
Naresh1318 committed Dec 10, 2017
1 parent 706aea5 commit 20f9ab9
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 2 deletions.
16 changes: 16 additions & 0 deletions generate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ def __int__(self):
self.input_frames = tf.placeholder(dtype=tf.float32, shape=[None, 84, 84, 4], name='input_frames')
self.target_frame = tf.placeholder(dtype=tf.float32, shape=[None, 84, 84, 1], name='target_frame')
self.action_performed = tf.placeholder(dtype=tf.int32, shape=[None, 4], name='action_performed')
self.n_epochs = 100
self.learning_rate = 1e-4
self.batch_size = 32
self.momentum = 0.9
self.logdir = './Results/prediction_model'

def model(self, x, action, reuse=False):
if reuse:
Expand Down Expand Up @@ -65,6 +67,20 @@ def train(self):
# TODO: Currently only shows latest input frame
tf.summary.image(name='Input_images', tensor=self.input_frames[:, :, :, 0])

saver = tf.train.Saver()

init = tf.global_variables_initializer()

with tf.Session() as sess:
sess.run(init)

file_writer = tf.summary.FileWriter(logdir=self.logdir, graph=sess.graph)

# TODO: Train on 1 step prediction objective later extend
# for e in range(self.n_epochs):
# n_batches =


def get_next_batch(self):
    """Stub: intended to produce the next training mini-batch.

    Not implemented yet — see the commented-out epoch/batch loop in
    train() ("# for e in range(self.n_epochs): / # n_batches =").
    Presumably it should yield (input_frames, action_performed,
    target_frame) arrays of size self.batch_size — TODO confirm once
    implemented.
    """
    # TODO(review): currently a no-op placeholder.
    pass

2 changes: 1 addition & 1 deletion mission_control_breakout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
########################################################################################################################
# Control
train_model = False
load_trained_model = False
load_trained_model = True
show_ui = True
show_action = False

Expand Down
2 changes: 2 additions & 0 deletions play_acrobat.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import ops
import sys
import itertools
from gym.wrappers import Monitor

# Setup the environment
env = gym.make('Acrobot-v1')
env = Monitor(env=env, directory="./Results/Videos/Acrobat", resume=True)

# Placeholders
# TODO: make the shape of X_input generalized
Expand Down
4 changes: 3 additions & 1 deletion play_breakout.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
import mission_control_breakout as mc
import ops
import matplotlib.pyplot as plt
from gym.wrappers import Monitor

# Setup the environment
env = gym.make('BreakoutDeterministic-v4')
env = Monitor(env=env, directory="./Results/Videos/Breakout", resume=True)

# Placeholders
X_input = tf.placeholder(dtype=tf.float32, shape=[None, 84, 84, 4], name='Observations')
Expand Down Expand Up @@ -167,7 +169,7 @@ def play(sess, agent, no_plays, log_dir=None, show_ui=False, show_action=False):
while not done:
if show_ui:
env.render()
if np.random.rand() < 0.01:
if np.random.rand() < 0.05:
action = env.action_space.sample()
else:
action = np.argmax(sess.run(agent, feed_dict={X_input: state}))
Expand Down
2 changes: 2 additions & 0 deletions play_cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import ops
import sys
import itertools
from gym.wrappers import Monitor

# Setup the environment
env = gym.make('CartPole-v1')
env = Monitor(env=env, directory="./Results/Videos/CartPole", resume=True)

# Placeholders
# TODO: make the shape of X_input generalized
Expand Down

0 comments on commit 20f9ab9

Please sign in to comment.