[Feature Request] Add Attention nets (GTrXL model in particular) #165
Comments
I meant the SB3 contrib repo. For GTrXL, are you willing to contribute that algorithm?
Sorry for the misunderstanding.
I'm not sure yet; I will try to implement it for my experiments first.
Also related: https://github.com/maohangyu/TIT_open_source
@RemiG3 hey, have you started implementing it? Maybe I can lend a hand :)
Yes, I have implemented it, but I have not tested it properly. I'm currently having some trouble with my custom environment that I'm trying to solve. @araffin, is it possible to create a new branch for this feature (to share the code)?
Yes, that's what forks and pull requests are for.
I came across Transformers-RL, which is quite modular and easy to tune; the only drawback is that it has only been implemented for Gaussian policies.
Hey, I finally opened PR #176 to share the code. It should work, but I'm not sure about its performance.
RemiG3, thank you for adding attention nets to contrib. What should the input shape look like, for example if I want to use the CartPole environment?
Thank you, @eric000888, for reporting this (feel free to share the code you tested, as you did in your first edits). I have updated the branch to fix a bug in the minibatch dimensions, so it should now work for …
EDIT: I also added assertions for these cases, as in the original PPO.
RemiG3,
`from sb3_contrib.ppo_attention.ppo_attention import AttentionPPO`
`VE = DummyVecEnv([lambda: gym.make("CartPole-v1")])`
`model = AttentionPPO(`
First I create a vectorized environment, then set up the model the same way as the LSTM-based RecurrentPPO, and then run `model.learn()`. Following the code, I saw that you concatenate the input tensor with the memory. However, SB3 first passes a single record as input, and after the first full loop the input becomes a batch of records. That throws an error, because the memory still has a batch size of one.
Thank you for the update, I will try it this weekend.
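The shape mismatch described above can be reproduced with a small NumPy sketch. It assumes (this is not taken from PR #176) that the memory has shape `(batch, mem_len, d_model)` and is concatenated with the embedded input along the time axis; `concat_memory` and all sizes are hypothetical. One possible remedy, similar in spirit to how RecurrentPPO stores LSTM states in its rollout buffer, is to store the memory that was active at every collected step and index it with the same minibatch indices as the observations.

```python
import numpy as np

def concat_memory(memory: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Prepend memory to the input along the time axis (axis=1).

    Assumed shapes: memory (batch, mem_len, d_model), x (batch, seq_len, d_model).
    np.concatenate raises ValueError if the batch dimensions differ.
    """
    return np.concatenate([memory, x], axis=1)

rng = np.random.default_rng(0)
n_samples, batch_size, mem_len, d_model = 256, 64, 4, 8  # hypothetical sizes

# Rollout data: one embedded observation per sample, plus the memory that
# was active when that observation was collected (stored alongside it).
obs_emb = rng.standard_normal((n_samples, 1, d_model))
stored_memory = rng.standard_normal((n_samples, mem_len, d_model))

# Collection with a single env: batch 1 matches batch 1.
step_out = concat_memory(stored_memory[:1], obs_emb[:1])

# Training: reusing one memory for a whole minibatch fails.
try:
    concat_memory(stored_memory[:1], obs_emb[:batch_size])
    mismatch_raised = False
except ValueError:
    mismatch_raised = True

# Gathering the stored memory for each minibatch index fixes the shapes.
idx = rng.permutation(n_samples)[:batch_size]
train_out = concat_memory(stored_memory[idx], obs_emb[idx])
print(step_out.shape, mismatch_raised, train_out.shape)  # (1, 5, 8) True (64, 5, 8)
```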
Another question: if you just use GTrXL as a feature extractor in the PPO model, will you get the same results? The LSTM-based RecurrentPPO has a flag to use the LSTM layer or not, which makes it similar to a feature-extractor layer.
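One factor that may bear on this question: an SB3 features extractor receives only the current observation (unless observations are stacked, e.g. with a frame-stacking wrapper), so it cannot tell apart two trajectories that end in the same observation, while a model that carries memory across steps can. The NumPy sketch below illustrates this with a plain tanh RNN standing in for any memory-carrying model (LSTM or GTrXL); all weights and function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
obs_dim, hidden = 4, 8
W = rng.standard_normal((hidden, obs_dim))  # observation weights (random, illustrative)
U = rng.standard_normal((hidden, hidden))   # recurrent weights (random, illustrative)

def feature_extractor(obs):
    # Memoryless: the output depends only on the observation it is given.
    return np.tanh(W @ obs)

def run_extractor(history):
    # Only the latest observation reaches a memoryless extractor.
    return feature_extractor(history[-1])

def run_recurrent(history):
    # A carried state h lets earlier observations influence the output.
    h = np.zeros(hidden)
    for obs in history:
        h = np.tanh(W @ obs + U @ h)
    return h

obs_a, obs_b = rng.standard_normal(obs_dim), rng.standard_normal(obs_dim)
history_1 = [obs_a, obs_b]  # two different histories...
history_2 = [obs_b, obs_b]  # ...ending in the same observation

print(np.allclose(run_extractor(history_1), run_extractor(history_2)))  # True
```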
Another thing: GTrXL demands more computation, and PPO is like aiming at a moving target. I found training a GTrXL PPO to be a daunting task, especially with multiple layers. But if you can compute the gradient over the whole trajectory, you may speed up learning: that means collecting all actions/observations first and then doing a single pass of backpropagation.
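One reason a single pass over the whole trajectory is feasible: attention with a causal mask computes, for every timestep at once, the same outputs you would get by running the model step by step during collection. The NumPy sketch below checks this for one attention head with no learned projections and no Transformer-XL memory segments (both simplifications); `causal_attention` and `stepwise_attention` are hypothetical names. In a real implementation, the batched pass is the one you would backpropagate through.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def causal_attention(q, k, v):
    """Whole trajectory in one pass. q, k, v: (T, d); step t sees steps <= t."""
    T, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    future = np.triu(np.ones((T, T)), k=1).astype(bool)  # True above the diagonal
    scores = np.where(future, -np.inf, scores)            # block attention to the future
    return softmax(scores) @ v

def stepwise_attention(q, k, v):
    """Same computation one timestep at a time, as during rollout collection."""
    T, d = q.shape
    out = np.empty_like(v)
    for t in range(T):
        scores = q[t] @ k[: t + 1].T / np.sqrt(d)
        out[t] = softmax(scores) @ v[: t + 1]
    return out

rng = np.random.default_rng(0)
T, d = 6, 4
q, k, v = (rng.standard_normal((T, d)) for _ in range(3))
full = causal_attention(q, k, v)
steps = stepwise_attention(q, k, v)
print(np.allclose(full, steps))  # True
```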
🚀 Feature Request
This feature request is a duplicate of one in stable-baselines3 (see DLR-RM/stable-baselines3#177).
The idea is to add the GTrXL model from the paper "Stabilizing Transformers for Reinforcement Learning" to the contrib repo, as done in RLlib: https://github.com/ray-project/ray/blob/master/rllib/models/torch/attention_net.py.
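For context on what the model adds: one of GTrXL's changes to Transformer-XL is replacing the residual connection around each attention and MLP sublayer with a gating layer, and the paper's best-performing variant is GRU-style gating: r = σ(W_r y + U_r x), z = σ(W_z y + U_z x − b_g), ĥ = tanh(W_g y + U_g (r ⊙ x)), g(x, y) = (1 − z) ⊙ x + z ⊙ ĥ, where x is the sublayer input and y its output. The NumPy sketch below is a forward pass only, with random weights; `GRUGate` and the default `gate_bias=2.0` are illustrative choices, not taken from RLlib's attention_net.py or from PR #176. A positive b_g pushes z toward 0 at initialization, so each layer starts out close to passing x through unchanged.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

class GRUGate:
    """GRU-type gate used in place of a residual connection (forward pass only).

    x: sublayer input (the residual stream), y: sublayer output; both (..., d).
    Row-vector convention: y @ W corresponds to the paper's W y (transposed W).
    """

    def __init__(self, d: int, gate_bias: float = 2.0, seed: int = 0):
        rng = np.random.default_rng(seed)
        # Random weights for illustration; a real model would learn them.
        self.Wr, self.Ur, self.Wz, self.Uz, self.Wg, self.Ug = (
            rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(6)
        )
        self.bg = gate_bias  # positive bias keeps z small, so the gate starts near identity

    def __call__(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        r = sigmoid(y @ self.Wr + x @ self.Ur)            # reset gate
        z = sigmoid(y @ self.Wz + x @ self.Uz - self.bg)  # update gate
        h = np.tanh(y @ self.Wg + (r * x) @ self.Ug)      # candidate activation
        return (1.0 - z) * x + z * h                      # gated replacement for x + y

rng = np.random.default_rng(1)
x = rng.standard_normal((3, 8))  # e.g. 3 timesteps, d_model = 8
y = rng.standard_normal((3, 8))  # output of an attention or MLP sublayer
out = GRUGate(8)(x, y)
```

With a very large `gate_bias`, z is effectively 0 and the layer returns x, which is the identity-map behavior the bias is meant to encourage early in training.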
@araffin has already mentioned that he created it and will make it public (comment). I wonder if this is still relevant?