train on saved models
Naresh1318 committed Dec 9, 2017
1 parent b47dc5f commit 1236d06
Showing 3 changed files with 31 additions and 13 deletions.
5 changes: 3 additions & 2 deletions mission_control_breakout.py
@@ -28,8 +28,9 @@

########################################################################################################################
# Control
- train_model = True
- show_ui = False
+ train_model = False
+ load_trained_model = False
+ show_ui = True
show_action = False


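The new flag combination (train_model = False, load_trained_model = False, show_ui = True) puts this mission-control file into pure playback mode. A minimal, hypothetical sketch of how an entry point might branch on these three flags (the repo's real dispatch lives in play_breakout.py and is not part of this hunk; the stub functions here are stand-ins):

import mission_control_breakout as mc

def train(resume):
    print('training, resuming from a saved model: {}'.format(resume))

def play():
    print('playing, rendering the UI: {}'.format(mc.show_ui))

# Hypothetical dispatch on the control flags.
if mc.train_model:
    train(resume=mc.load_trained_model)
else:
    play()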
11 changes: 6 additions & 5 deletions mission_control_breakout_ram.py
@@ -1,6 +1,6 @@
########################################################################################################################
# Training
- learning_rate = 0.00025
+ learning_rate = 0.00015
batch_size = 32
observation_time = int(1e6) # 1e3
rand_observation_time = int(5e4) # 5e2
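The learning rate drops from 2.5e-4 to 1.5e-4, a common adjustment when training destabilizes or plateaus. A sketch of how the value would typically feed the optimizer, assuming an RMSProp optimizer as in the original DQN setup (the optimizer this repo actually uses is not shown in the diff; the loss below is a stand-in so the snippet builds):

import tensorflow as tf
import mission_control_breakout_ram as mc

# Stand-in parameter and loss; the repo's own training graph is not part of this diff.
w = tf.Variable(0.0, name='stand_in_weight')
target = tf.placeholder(tf.float32, [], name='target')
loss = tf.squared_difference(target, w)

# mc.learning_rate is the value set above (1.5e-4 after this commit).
optimizer = tf.train.RMSPropOptimizer(learning_rate=mc.learning_rate, decay=0.99, momentum=0.0, epsilon=1e-6)
train_op = optimizer.minimize(loss)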
@@ -16,14 +16,15 @@

########################################################################################################################
# Agent Model
- dense_1 = 400
- dense_2 = 180
- dense_3 = 100
+ dense_1 = 512
+ dense_2 = 512
+ dense_3 = 256


########################################################################################################################
# Control
- train_model = True
+ train_model = False
+ load_trained_model = True
show_ui = True
show_action = False

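The three hidden layers grow from 400/180/100 to 512/512/256 units. A minimal sketch of a fully connected Q-network sized by these values, assuming the RAM variant feeds the 128-byte Atari RAM vector straight into dense layers (the repo's actual graph-building code is not part of this diff):

import tensorflow as tf
import mission_control_breakout_ram as mc

def q_network(ram_input, n_actions=4):
    # Three ReLU hidden layers sized by the mission-control values, linear Q-value head.
    h1 = tf.layers.dense(ram_input, mc.dense_1, activation=tf.nn.relu, name='dense_1')
    h2 = tf.layers.dense(h1, mc.dense_2, activation=tf.nn.relu, name='dense_2')
    h3 = tf.layers.dense(h2, mc.dense_3, activation=tf.nn.relu, name='dense_3')
    return tf.layers.dense(h3, n_actions, name='q_values')

ram_state = tf.placeholder(tf.float32, shape=[None, 128], name='RAM_Observations')
q_values = q_network(ram_state)   # Breakout's action space has 4 discrete actions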
28 changes: 22 additions & 6 deletions play_breakout.py
@@ -13,7 +13,7 @@
import matplotlib.pyplot as plt

# Setup the environment
- env = gym.make('Breakout-v4')
+ env = gym.make('BreakoutDeterministic-v4')

# Placeholders
X_input = tf.placeholder(dtype=tf.float32, shape=[None, 84, 84, 4], name='Observations')
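Switching from 'Breakout-v4' to 'BreakoutDeterministic-v4' trades gym's stochastic frame skip (sampled uniformly from 2 to 4 frames each step) for a fixed skip of 4 frames, which makes rollouts more repeatable. A quick check, assuming the classic gym + atari-py Atari wrapper, which exposes the skip on the unwrapped env:

import gym

for env_id in ('Breakout-v4', 'BreakoutDeterministic-v4'):
    env = gym.make(env_id)
    print(env_id, env.unwrapped.frameskip)   # (2, 5) -> random skip in [2, 4]; 4 -> fixed skip
    env.close()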
@@ -79,10 +79,12 @@ def anneal_epsilon(step):
    return epi


- def collect_rand_observations(replay_memory):
+ def collect_rand_observations(replay_memory, sess=None, agent=None):
    """
    Collects mc.rand_observation_time number of random observations and stores them in deque
    :param replay_memory: deque, deque instance
+   :param agent: Tensor op, the agent architecture
+   :param sess: op, restored TensorFlow session used to run the agent
    :return: ndarray, stored as follows:
             (state, action, reward, next_states, done, life_lost)
    """
@@ -95,7 +97,11 @@ def collect_rand_observations(replay_memory):
    lives_left = 5
    if len(replay_memory) < mc.rand_observation_time:
        for i in range(int(mc.rand_observation_time)):
-           action = env.action_space.sample()
+           if sess is None:
+               action = env.action_space.sample()
+           else:
+               q_prediction = sess.run(agent, feed_dict={X_input: state})
+               action = np.argmax(q_prediction)
            next_state, reward, done, info = env.step(action)
            next_state = ops.convert_to_gray_n_resize(next_state)
            next_state = np.expand_dims(next_state, axis=2)
@@ -161,7 +167,7 @@ def play(sess, agent, no_plays, log_dir=None, show_ui=False, show_action=False):
        while not done:
            if show_ui:
                env.render()
-           if np.random.rand() < 0.1:
+           if np.random.rand() < 0.01:
                action = env.action_space.sample()
            else:
                action = np.argmax(sess.run(agent, feed_dict={X_input: state}))
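Dropping the play-time exploration rate from 0.1 to 0.01 means the evaluated agent almost always follows its greedy policy, while the remaining 1% of random actions helps it escape the repetitive loops a fully greedy Atari agent can fall into. The rule itself, as a generic standalone sketch (not the repo's code):

import numpy as np

def epsilon_greedy(q_values, epsilon, n_actions):
    # With probability epsilon take a random action, otherwise the highest-value one.
    if np.random.rand() < epsilon:
        return np.random.randint(n_actions)
    return int(np.argmax(q_values))

action = epsilon_greedy(q_values=[0.1, 0.9, 0.2, 0.4], epsilon=0.01, n_actions=4)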
@@ -247,13 +253,23 @@ def train(train_model=True):
    # Get the initial epsilon
    prob_rand = mc.prob_random

    # TODO: Change this ASAP
    # Add epsilon to Tensorboard
    tf.summary.scalar('epsilon', tensor=prob_rand)
    summary_op = tf.summary.merge_all()

    # Replay memory
    replay_memory = deque()
-   replay_memory = collect_rand_observations(replay_memory)  # Get the initial 50k random observations

+   if mc.load_trained_model:
+       saved_models = os.listdir(mc.logdir)
+       latest_saved_model = sorted(saved_models)[-1]
+       saver.restore(sess, tf.train.latest_checkpoint(mc.logdir + latest_saved_model + "/saved_models/"))
+       with open(mc.logdir + "saved_models/checkout", 'r') as checkout_file:
+           line_1 = checkout_file.readline()
+           step = int(line_1[30:-2])
+       replay_memory = collect_rand_observations(replay_memory, sess, agent)
+   else:
+       replay_memory = collect_rand_observations(replay_memory)  # Get the initial 50k random observations

    game_rewards = []

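The restore block above recovers the last global step by slicing a fixed character range out of the first line of a "checkout" file (line_1[30:-2]), which is brittle if that file's format ever changes. A more robust sketch that swaps in TensorFlow's own checkpoint naming, assuming checkpoints are written with saver.save(..., global_step=step) so the step is embedded in the checkpoint path (a suggested alternative, not the repo's code):

import tensorflow as tf

def restore_latest(sess, saver, checkpoint_dir):
    # Restore the newest checkpoint and recover its global step from the file name.
    ckpt_path = tf.train.latest_checkpoint(checkpoint_dir)
    if ckpt_path is None:
        return 0   # nothing saved yet, start training from step 0
    saver.restore(sess, ckpt_path)
    # Paths look like '<checkpoint_dir>/model.ckpt-12345' when saved with global_step.
    return int(ckpt_path.rsplit('-', 1)[-1])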
