LSTM version #3
Comments
Not immediately, but it shouldn't be hard to implement it in TF.
Is it possible to implement an LSTM in this GA3C architecture?
It should be straightforward, as the state for Atari games is already defined as 4 frames stacked together (see section 4.1 of the original DQN paper: https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf), and that is what GA3C uses. If you supply those frames serially instead, an LSTM version of GA3C will work.
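To make the distinction concrete, here is a minimal NumPy sketch of the DQN-style state that stacks the last 4 frames, contrasted with the single-frame input an LSTM variant would receive at each step. The `FrameStack` helper and the 84×84 frame size are assumptions for illustration, not GA3C code.

```python
from collections import deque

import numpy as np


class FrameStack:
    """Hypothetical helper: keeps the last `k` frames and returns them stacked as one state."""

    def __init__(self, k=4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # At the start of an episode, fill the buffer with copies of the first frame.
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.state()

    def push(self, frame):
        # The oldest frame drops out automatically because of maxlen.
        self.frames.append(frame)
        return self.state()

    def state(self):
        # Shape (H, W, k): the feed-forward network sees k frames at once.
        return np.stack(list(self.frames), axis=-1)


frames = [np.full((84, 84), t, dtype=np.float32) for t in range(6)]
stack = FrameStack(k=4)
s = stack.reset(frames[0])
for f in frames[1:]:
    s = stack.push(f)
print(s.shape)            # (84, 84, 4)
print(s[0, 0].tolist())   # [2.0, 3.0, 4.0, 5.0]

# An LSTM variant would instead receive one frame per step, shape (84, 84, 1),
# and rely on its recurrent state to remember earlier frames.
serial_input = frames[-1][..., np.newaxis]
print(serial_input.shape)  # (84, 84, 1)
```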
Implementing the LSTM version without many code changes depends on how long the training sequences need to be. If the sequences are TMAX frames long (which I think is the case), the current architecture works, since Trainers receive sequences of TMAX frames. But if the training data needs to be any longer (i.e. multiple TMAX segments merged together), it becomes a bit more complicated.
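A minimal sketch, assuming context longer than TMAX is wanted: one common workaround is to cut the episode into TMAX-sized segments and store, with each segment, the recurrent state the agent had at its first step, so training can resume the LSTM from there. `split_into_segments` and its arguments are hypothetical, and the strings stand in for real (h, c) arrays.

```python
def split_into_segments(trajectory, t_max, initial_states):
    """Hypothetical helper: cut one episode into chunks of at most t_max steps.

    `initial_states[i]` is the recurrent state the agent had *before* step i,
    so each chunk is paired with the state needed to resume the LSTM there.
    """
    segments = []
    for start in range(0, len(trajectory), t_max):
        chunk = trajectory[start:start + t_max]
        segments.append((initial_states[start], chunk))
    return segments


episode = list(range(10))                # 10 steps of (dummy) experience
states = [f"h{i}" for i in range(10)]    # stand-ins for the LSTM state before each step
for state, chunk in split_into_segments(episode, t_max=4, initial_states=states):
    print(state, chunk)
# h0 [0, 1, 2, 3]
# h4 [4, 5, 6, 7]
# h8 [8, 9]
```

Note that the last segment is shorter than TMAX, which is exactly the case that needs padding or variable-length handling.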
In the case of an LSTM, shouldn't the batch be organized in (N, T, C, H, W) format?
@etienne87 you are correct, but please look here. What the Trainer receives is in (N, T, C, H, W) format, but it merges the T dimension into the batch dimension to get data in (N*T, C, H, W) format. In a recurrent model these concatenations are unnecessary.
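The difference can be sketched with NumPy reshapes (sizes are assumed for illustration): folding T into the batch axis is what a feed-forward model needs, while a recurrent model keeps the (N, T, ...) layout, and per-frame encoder features can be unfolded back into sequences.

```python
import numpy as np

N, T, C, H, W = 2, 5, 4, 84, 84              # assumed sizes for illustration
batch = np.random.rand(N, T, C, H, W).astype(np.float32)

# Feed-forward model: the time axis is folded into the batch axis,
# so every frame is treated as an independent sample.
merged = batch.reshape(N * T, C, H, W)
print(merged.shape)                           # (10, 4, 84, 84)

# Recurrent model: keep (N, T, ...) so the LSTM can run along T for each agent.
# After a per-frame encoder produces (N * T, D) features, they can be unfolded
# back into one sequence per agent.
D = 256
features = np.random.rand(N * T, D).astype(np.float32)
sequences = features.reshape(N, T, D)
print(sequences.shape)                        # (2, 5, 256)

# Row n * T + t of the merged array is step t of agent n.
assert np.array_equal(merged[1 * T + 3], batch[1, 3])
```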
@mbz, thanks for pointing this out. Now I am super confused by this part of the code! Can you take a look at #6? I don't see how these concatenations work at all! I would suggest modifying ThreadTrainer.py to:
An LSTM would require a reset_state function that addresses a specific row of the batch, right?
Sorry for the pseudocode, I'm not an expert with TF.
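One way to picture the `reset_state` idea, as a hypothetical NumPy sketch: if the LSTM state is kept as one row per agent, resetting an agent at the start of its episode means zeroing just that row of h and c. The sizes and the helper are assumptions for illustration.

```python
import numpy as np

num_agents, hidden_size = 4, 3                                # assumed sizes
h = np.ones((num_agents, hidden_size), dtype=np.float32)     # LSTM hidden state, one row per agent
c = np.ones((num_agents, hidden_size), dtype=np.float32)     # LSTM cell state, one row per agent


def reset_state(h, c, agent_ids):
    """Hypothetical helper: zero the (h, c) rows of the agents whose episode just started."""
    h[agent_ids] = 0.0
    c[agent_ids] = 0.0


reset_state(h, c, [1, 3])
print(h.sum(axis=1).tolist())   # [3.0, 0.0, 3.0, 0.0]
```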
Another point of confusion for me (because I have little experience with TF): it seems we need 2 graphs, one for prediction (taking a dynamic_rnn) and one (maybe taking a static tf.rnn?) for backprop (if feeding (N, T, C, H, W)). Or is there a way to use a gradient applier like in miyosuda's implementation?
@etienne87 I'm not sure I understand your first question about reset_state correctly. Can you please provide more details? About having separate graphs: there are different ways of implementing the same logic in TF. We are not using separate graphs simply because it's not necessary. Can you please elaborate on why you think having two graphs is necessary?
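As an illustration of why a single graph can suffice, this NumPy sketch (a plain LSTM cell with assumed sizes and random weights, not GA3C's network) shows that stepping the cell once per prediction while carrying (h, c), and unrolling the same cell over a (N, T, D) sequence for training, produce identical outputs because both paths share the same weights. In TF this would roughly correspond to reusing one set of cell variables in both code paths.

```python
import numpy as np

rng = np.random.default_rng(0)
D_in, D_h = 8, 16                                      # assumed input and hidden sizes
W = rng.standard_normal((D_in + D_h, 4 * D_h)) * 0.1   # shared LSTM weights (gates i, f, o, g)
b = np.zeros(4 * D_h)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_step(x, h, c):
    """One LSTM step for a batch: x is (N, D_in), h and c are (N, D_h)."""
    z = np.concatenate([x, h], axis=1) @ W + b
    i, f, o, g = np.split(z, 4, axis=1)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    h = sigmoid(o) * np.tanh(c)
    return h, c


def lstm_unroll(xs, h, c):
    """Training-time view: run the same cell over a (N, T, D_in) sequence."""
    outputs = []
    for t in range(xs.shape[1]):
        h, c = lstm_step(xs[:, t], h, c)
        outputs.append(h)
    return np.stack(outputs, axis=1), (h, c)


N, T = 3, 5
xs = rng.standard_normal((N, T, D_in))
h0 = np.zeros((N, D_h))
c0 = np.zeros((N, D_h))

# Prediction-time view: one step at a time, carrying (h, c) between calls.
h, c = h0, c0
step_outputs = []
for t in range(T):
    h, c = lstm_step(xs[:, t], h, c)
    step_outputs.append(h)

unrolled, _ = lstm_unroll(xs, h0, c0)
print(np.allclose(np.stack(step_outputs, axis=1), unrolled))  # True
```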
@mbz OK! What I mean is: in classic A3C, it seems we can just backprop at the end of an episode (T_MAX) by re-using the already computed predictions. Here, on the other hand, it seems we need to recompute the predictions from the samples and actions. In short: X should be (N, H, W, C) in the predictor thread and (N, T, H, W, C) in the train function? Maybe I misunderstood something about TF's internal mechanism? Also, about the reset: at the beginning of each episode you probably want to reset the c and h of your LSTM to zero. So, as @ppwwyyxx is suggesting, should the LSTM state be saved inside each ProcessAgent?
@mbz I have implemented A3C-LSTM with long sequence lengths. You don't have to send the whole sequence into the graph. What I did was maintain the current LSTM hidden state of every game simulator on the Python side, and at every step feed the new frame together with that simulator's hidden state to the graph.
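A minimal sketch of this Python-side approach, under stated assumptions: `HiddenStateStore` is a hypothetical container with one (h, c) row per simulator, and `fake_graph_step` is a stand-in for the real network call. The predictor gathers the states of the agents in its batch, runs one step, writes the new states back, and zeroes an agent's row when its episode ends.

```python
import numpy as np


class HiddenStateStore:
    """Hypothetical Python-side store: one (h, c) pair per game simulator."""

    def __init__(self, num_agents, hidden_size):
        self.h = np.zeros((num_agents, hidden_size), dtype=np.float32)
        self.c = np.zeros((num_agents, hidden_size), dtype=np.float32)

    def gather(self, agent_ids):
        # States to feed to the graph together with the new frames (fancy indexing copies).
        return self.h[agent_ids], self.c[agent_ids]

    def scatter(self, agent_ids, h, c):
        # New states returned by the graph, written back per simulator.
        self.h[agent_ids] = h
        self.c[agent_ids] = c

    def reset(self, agent_id):
        # Called when a simulator starts a new episode.
        self.h[agent_id] = 0.0
        self.c[agent_id] = 0.0


def fake_graph_step(frames, h, c):
    """Stand-in for the network call: returns updated (h, c) for each row."""
    c = c + frames.mean(axis=1, keepdims=True)
    h = np.tanh(c)
    return h, c


store = HiddenStateStore(num_agents=4, hidden_size=2)
agent_ids = np.array([0, 2])                     # agents that sent a prediction request
frames = np.ones((2, 3), dtype=np.float32)       # one (flattened) frame per requesting agent

h, c = store.gather(agent_ids)
h, c = fake_graph_step(frames, h, c)
store.scatter(agent_ids, h, c)
print(store.c[:, 0].tolist())   # [1.0, 0.0, 1.0, 0.0]

store.reset(2)                                   # agent 2 finished its episode
print(store.c[:, 0].tolist())   # [1.0, 0.0, 0.0, 0.0]
```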
@ppwwyyxx I have built an LSTM and stored the hidden states.
@markovyao
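```python
import tensorflow as tf  # TF 1.x API (tf.scatter_update)

# Fragment from inside a model class: `tc`, `initial_rnn_states` (one row per agent),
# `rnn_output_states_array` and the `self._input_*` placeholders come from the surrounding code.
if tc.is_training:
    # 1 for agents whose episode continues, 0 for finished ones: finished agents get a zero state.
    need_reset_states = tf.reshape(tf.ones_like(self._input_is_over) - self._input_is_over, (-1, 1))
    op_updates = [
        tf.scatter_update(initial_rnn_states[idx], self._input_agent_indexs,
                          rnn_output_states_array[idx] * tf.cast(need_reset_states, rnn_output_states_array[idx].dtype))
        for idx in range(len(rnn_output_states_array))]
else:
    # In predict mode, is_over refers to the last state, so reset and update ops are exported separately.
    batch_size = tf.shape(self._input_agent_indexs)[0]
    op_resets = []
    op_updates = []
    for idx in range(len(initial_rnn_states)):
        shape_states = tf.shape(initial_rnn_states[idx])
        # Reset: write zeros into the rows of the agents in this batch.
        op = tf.scatter_update(initial_rnn_states[idx], self._input_agent_indexs,
                               tf.zeros((batch_size, shape_states[1]), dtype=initial_rnn_states[idx].dtype))
        op_resets.append(op)
        # Update: write the new LSTM state into the same rows.
        op = tf.scatter_update(initial_rnn_states[idx], self._input_agent_indexs, rnn_output_states_array[idx])
        op_updates.append(op)
```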
This implementation is useful if you have many LSTM networks or are frequently modifying the model during development, since it only exports the update/reset ops outside the model.
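For readers less familiar with `tf.scatter_update`, the training branch above can be mirrored in NumPy (array names follow the snippet; shapes and values are made up for illustration). Writing `updates` into the rows selected by the indices is, for unique indices, plain fancy-index assignment, and multiplying by `1 - is_over` stores zeros for agents whose episode ended.

```python
import numpy as np

# NumPy stand-ins for the TF tensors in the snippet above (assumed shapes).
num_agents, hidden_size = 5, 3
initial_rnn_state = np.full((num_agents, hidden_size), 9.0, dtype=np.float32)  # stored state per agent
agent_indexs = np.array([4, 1])                    # rows this batch belongs to
is_over = np.array([0.0, 1.0], dtype=np.float32)   # 1.0 if that agent's episode just ended
rnn_output_state = np.ones((2, hidden_size), dtype=np.float32)

# 1.0 where the new state should be kept, 0.0 where it should be cleared.
need_reset_states = (np.ones_like(is_over) - is_over).reshape(-1, 1)
# For unique indices, scatter_update(ref, indices, updates) behaves like ref[indices] = updates.
initial_rnn_state[agent_indexs] = rnn_output_state * need_reset_states

print(initial_rnn_state[:, 0].tolist())  # [9.0, 0.0, 9.0, 9.0, 1.0]
```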
@ricky1203: could you perhaps provide an example or a link in context?
@etienne87 Note: with the hidden states stored in the model, an agent should predict and train on a single model (GPU device) for the whole episode.
Coming back to this problem with a slightly better understanding of variable-length RNNs. In ThreadTrainer::run:
In NetworkVP::_create_graph:
In NetworkVP::predict_p_and_v:
In NetworkVP::train:
I think the only thing I am missing is how to "unpack" the sequence of encoded states in
Anyway, here is a first implementation that works fine as long as you don't have too many incomplete experiences (of length < Config.TIME_MAX). Here I "solved" the issue by padding sequences in ThreadTrainer.py. To be optimal, we would need to dynamically batch the data after the feed-forward encoder (before the LSTM), in order to feed a (N, TIME_MAX, 256) tensor to … I will now test on Pong and merge it with the GAE branch. If someone wants to help me understand how to improve this, you are welcome! :-)
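A minimal sketch of the padding step, assuming each experience is already a (length, D) array of features. `pad_experiences` is a hypothetical helper that zero-pads every sequence to TIME_MAX and also returns a mask and the true lengths. The lengths are the kind of input that the `sequence_length` argument of `tf.nn.dynamic_rnn` takes, and the mask is what keeps padded steps out of the loss.

```python
import numpy as np


def pad_experiences(sequences, time_max):
    """Hypothetical helper: stack variable-length sequences into (N, time_max, D).

    Returns the padded batch, a (N, time_max) mask (1.0 for real steps,
    0.0 for padding) and the true lengths of each sequence.
    """
    n = len(sequences)
    d = sequences[0].shape[1]
    batch = np.zeros((n, time_max, d), dtype=np.float32)
    mask = np.zeros((n, time_max), dtype=np.float32)
    lengths = np.zeros(n, dtype=np.int32)
    for i, seq in enumerate(sequences):
        t = min(len(seq), time_max)
        batch[i, :t] = seq[:t]
        mask[i, :t] = 1.0
        lengths[i] = t
    return batch, mask, lengths


# Two rollouts: one full-length, one cut short because the episode ended.
seqs = [np.ones((5, 256), dtype=np.float32), np.ones((2, 256), dtype=np.float32)]
batch, mask, lengths = pad_experiences(seqs, time_max=5)
print(batch.shape)        # (2, 5, 256)
print(mask.tolist())      # [[1.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 0.0, 0.0, 0.0]]
print(lengths.tolist())   # [5, 2]
```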
Hmm, actually there was still an error in my code: I forgot to mask the loss for the padded inputs! I propose a first fix here. Apparently this now works better (at least for CartPole-v0). In Config.py:
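Masking the loss can be sketched as follows, assuming a (N, TIME_MAX) mask with 1.0 for real steps and 0.0 for padding (`masked_mean` is a hypothetical helper). Averaging over the number of real steps rather than over N × TIME_MAX is one design choice; summing and dividing by N would be another.

```python
import numpy as np


def masked_mean(per_step_loss, mask):
    """Average a (N, T) loss over real steps only; padded steps contribute nothing."""
    return float((per_step_loss * mask).sum() / np.maximum(mask.sum(), 1.0))


per_step_loss = np.array([[1.0, 1.0, 1.0],
                          [2.0, 5.0, 5.0]])   # the 5.0 entries come from padding
mask = np.array([[1.0, 1.0, 1.0],
                 [1.0, 0.0, 0.0]])

print(masked_mean(per_step_loss, mask))   # 1.25  (= (1 + 1 + 1 + 2) / 4)
print(float(per_step_loss.mean()))        # 2.5   (the unmasked mean is skewed by the padding)
```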
Out of interest, can I ask why you've removed this page? What were your findings regarding the performance impact of adding an LSTM? Edit: Found your model here: https://github.com/etienne87/GA3C, thanks!
This is great work. Is there any plan to develop an LSTM version?