HER Wrappers #340

Draft
wants to merge 22 commits into master
Changes from 16 commits
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,5 +1,5 @@
[settings]
known_third_party = cv2,gym,matplotlib,numpy,pandas,pytest,scipy,setuptools,torch
known_third_party = cv2,gym,matplotlib,numpy,pandas,pytest,scipy,setuptools,toml,torch
multi_line_output=3
include_trailing_comma=True
force_grid_wrap=0
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -6,7 +6,7 @@ repos:
args: [--exclude=^((examples|docs)/.*)$]

- repo: https://github.com/timothycrosley/isort
rev: 4.3.2
rev: 5.4.2
hooks:
- id: isort

New file (72 additions)
@@ -0,0 +1,72 @@
Using Shared Parameters in Actor Critic Agents in GenRL
=======================================================

Actor-critic agents use two networks: an actor network, which selects the action to take in the current state, and a
critic network, which estimates the value of the state the agent is currently in. There are two common ways to
implement this actor-critic architecture.

The first method uses independent actor and critic networks:

.. code-block:: none

state
/ \
<actor network> <critic network>
/ \
action value

The second method uses a set of shared parameters (a decoder) to extract a feature vector from the state. The actor
and critic networks then act on this feature vector to select an action and estimate the value:

.. code-block:: none

state
|
<decoder>
/ \
<actor network> <critic network>
/ \
action value

GenRL supports incorporating this decoder network in all of the actor-critic agents through a ``shared_layers``
parameter. ``shared_layers`` takes the sizes of the MLP layers to be used, or ``None`` if no decoder network is to be
used.

As an example, in A2C:

.. code-block:: python

    # The imports
    from genrl.agents import A2C
    from genrl.environments import VectorEnv
    from genrl.trainers import OnPolicyTrainer

    # Initializing the environment
    env = VectorEnv("CartPole-v0", 1)

    # Initializing the agent to be used
    algo = A2C(
        "mlp",
        env,
        policy_layers=(128,),
        value_layers=(128,),
        shared_layers=(32, 64),
        rollout_size=128,
    )

    # Finally, initializing the trainer and training
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()

The above example uses an MLP with layer sizes (32, 64) as the decoder, and can be visualised as follows:

.. code-block:: none

state
|
<32>
|
<64>
/ \
<128> <128>
/ \
action value
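
For intuition, the shared decoder above can be written out directly in PyTorch. The following is only an illustrative
sketch of the architecture in the diagram; class and attribute names here are not GenRL's actual implementation.

.. code-block:: python

    import torch
    import torch.nn as nn

    # Illustrative sketch of the diagram above, not GenRL's internal class.
    class SharedActorCritic(nn.Module):
        def __init__(self, state_dim=4, action_dim=2):
            super().__init__()
            # shared_layers=(32, 64): the decoder consumed by both heads
            self.decoder = nn.Sequential(
                nn.Linear(state_dim, 32), nn.ReLU(),
                nn.Linear(32, 64), nn.ReLU(),
            )
            # policy_layers=(128,) and value_layers=(128,)
            self.actor = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, action_dim))
            self.critic = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 1))

        def forward(self, state):
            features = self.decoder(state)  # state -> shared feature vector
            return self.actor(features), self.critic(features)  # -> action logits, state value

    logits, value = SharedActorCritic()(torch.randn(1, 4))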
14 changes: 9 additions & 5 deletions genrl/agents/deep/a2c/a2c.py
@@ -66,27 +66,31 @@ def _create_model(self) -> None:
self.env, self.network
)
if isinstance(self.network, str):
self.ac = get_model("ac", self.network)(
arch_type = self.network
if self.shared_layers is not None:
arch_type += "s"
self.ac = get_model("ac", arch_type)(
state_dim,
action_dim,
shared_layers=self.shared_layers,
policy_layers=self.policy_layers,
value_layers=self.value_layers,
val_type="V",
discrete=discrete,
action_lim=action_lim,
).to(self.device)

else:
self.ac = self.network.to(self.device)

# action_dim = self.network.action_dim

if self.noise is not None:
self.noise = self.noise(
torch.zeros(action_dim), self.noise_std * torch.ones(action_dim)
)

self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy)
self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value)
actor_params, critic_params = self.ac.get_params()
self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy)
self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value)

def select_action(
self, state: torch.Tensor, deterministic: bool = False
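The optimizers above are now built from ``self.ac.get_params()`` instead of reaching into ``self.ac.actor`` and
``self.ac.critic`` directly, so the same agent code works whether or not a shared decoder is present. The
implementation of ``get_params()`` is not shown in this hunk; a minimal sketch of the idea, with an assumed
``self.shared`` attribute, could look like:

.. code-block:: python

    # Sketch only: the network decides how to split its own parameters
    # between the policy and value optimizers.
    def get_params(self):
        actor_params = list(self.actor.parameters())
        critic_params = list(self.critic.parameters())
        if getattr(self, "shared", None) is not None:
            # Design choice: the shared decoder's parameters must be handed to at
            # least one optimizer; here they are grouped with the critic's.
            critic_params += list(self.shared.parameters())
        return actor_params, critic_params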
2 changes: 2 additions & 0 deletions genrl/agents/deep/base/base.py
@@ -34,6 +34,7 @@ def __init__(
create_model: bool = True,
batch_size: int = 64,
gamma: float = 0.99,
shared_layers=None,
policy_layers: Tuple = (64, 64),
value_layers: Tuple = (64, 64),
lr_policy: float = 0.0001,
@@ -45,6 +46,7 @@
self.create_model = create_model
self.batch_size = batch_size
self.gamma = gamma
self.shared_layers = shared_layers
self.policy_layers = policy_layers
self.rewards = []
self.value_layers = value_layers
6 changes: 3 additions & 3 deletions genrl/agents/deep/base/offpolicy.py
@@ -1,11 +1,12 @@
import collections
from typing import List
from typing import List, Union

import torch
from torch.nn import functional as F

from genrl.agents.deep.base import BaseAgent
from genrl.core import (
HERWrapper,
PrioritizedBuffer,
PrioritizedReplayBufferSamples,
ReplayBuffer,
@@ -98,7 +99,7 @@ def sample_from_buffer(self, beta: float = None):
states, actions, rewards, next_states, dones = self._reshape_batch(batch)

# Convert every experience to a Named Tuple. Either Replay or Prioritized Replay samples.
if isinstance(self.replay_buffer, ReplayBuffer):
if isinstance(self.replay_buffer, (ReplayBuffer, HERWrapper)):
batch = ReplayBufferSamples(*[states, actions, rewards, next_states, dones])
elif isinstance(self.replay_buffer, PrioritizedBuffer):
indices, weights = batch[5], batch[6]
@@ -231,7 +232,6 @@ def get_target_q_values(
next_q_target_values = self.ac_target.get_value(
torch.cat([next_states, next_target_actions], dim=-1)
)

target_q_values = rewards + self.gamma * (1 - dones) * next_q_target_values

return target_q_values
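Adding ``HERWrapper`` to the ``isinstance`` check means an HER-wrapped buffer is sampled exactly like a plain
``ReplayBuffer``. The wrapper's own API is not visible in this hunk; the sketch below only illustrates the
hindsight-relabelling idea behind HER, with made-up names and a simplified transition format.

.. code-block:: python

    import random

    # Toy sketch of hindsight experience replay, not the PR's HERWrapper.
    class NaiveHERWrapper:
        def __init__(self, buffer, reward_fn, n_sampled_goals=4):
            self.buffer = buffer                  # underlying replay buffer
            self.reward_fn = reward_fn            # recomputes reward for a substituted goal
            self.n_sampled_goals = n_sampled_goals

        def store_episode(self, episode):
            # episode: list of (obs, achieved_goal, desired_goal, action, next_obs, done)
            for t, (obs, ag, dg, action, next_obs, done) in enumerate(episode):
                self.buffer.push((obs, dg, action, self.reward_fn(ag, dg), next_obs, done))
                # "future" strategy: relabel with goals achieved later in the same episode
                for _ in range(self.n_sampled_goals):
                    new_goal = random.choice(episode[t:])[1]
                    new_reward = self.reward_fn(ag, new_goal)
                    self.buffer.push((obs, new_goal, action, new_reward, next_obs, done))

        def sample(self, batch_size):
            # Sampling is delegated, so downstream code can treat this like a ReplayBuffer,
            # which is what the isinstance check above relies on.
            return self.buffer.sample(batch_size)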
11 changes: 8 additions & 3 deletions genrl/agents/deep/ddpg/ddpg.py
@@ -63,9 +63,13 @@ def _create_model(self) -> None:
)

if isinstance(self.network, str):
self.ac = get_model("ac", self.network)(
arch_type = self.network
if self.shared_layers is not None:
arch_type += "s"
self.ac = get_model("ac", arch_type)(
state_dim,
action_dim,
self.shared_layers,
self.policy_layers,
self.value_layers,
"Qsa",
@@ -74,10 +78,11 @@
else:
self.ac = self.network

actor_params, critic_params = self.ac.get_params()
self.ac_target = deepcopy(self.ac).to(self.device)

self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy)
self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value)
self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy)
self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value)

def update_params(self, update_interval: int) -> None:
"""Update parameters of the model
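With ``shared_layers`` now threaded through positionally, DDPG can be configured with a shared decoder in the same way
as the A2C example in the new docs. A hedged usage sketch follows; the environment name, layer sizes, and trainer
choice are assumptions, not taken from this PR.

.. code-block:: python

    from genrl.agents import DDPG
    from genrl.environments import VectorEnv
    from genrl.trainers import OffPolicyTrainer

    env = VectorEnv("Pendulum-v0", 1)      # continuous-control task for DDPG
    agent = DDPG(
        "mlp",
        env,
        shared_layers=(32, 32),            # decoder shared by actor and critic
        policy_layers=(64,),
        value_layers=(64,),
    )
    trainer = OffPolicyTrainer(agent, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()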
5 changes: 5 additions & 0 deletions genrl/agents/deep/dqn/base.py
@@ -153,6 +153,9 @@ def get_q_values(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Ten
q_values (:obj:`torch.Tensor`): Q values for the given states and actions
"""
q_values = self.model(states)
if len(q_values.shape) == 2:
q_values = q_values.unsqueeze(1)
actions = actions.unsqueeze(1)
q_values = q_values.gather(2, actions)
return q_values

@@ -171,6 +174,8 @@ def get_target_q_values(
target_q_values (:obj:`torch.Tensor`): Target Q values for the DQN
"""
# Next Q-values according to target model
if len(next_states.shape) == 2:
next_states = next_states.unsqueeze(1)
next_q_target_values = self.target_model(next_states)
# Maximum of next q_target values
max_next_q_target_values = next_q_target_values.max(2)[0]
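The new ``unsqueeze`` branch handles batches that arrive without an environment dimension (as HER-style sampling can
produce), so that ``gather`` along dim 2 still works. A small shape check illustrating the fix, with example sizes
chosen purely for illustration:

.. code-block:: python

    import torch

    q_values = torch.randn(32, 4)            # (batch, n_actions): no n_envs dimension
    actions = torch.randint(0, 4, (32, 1))   # (batch, 1)

    # Insert the missing middle dimension so gather(2, ...) is valid
    q_values = q_values.unsqueeze(1)         # (batch, 1, n_actions)
    actions = actions.unsqueeze(1)           # (batch, 1, 1)

    chosen = q_values.gather(2, actions)     # (batch, 1, 1)
    assert chosen.shape == (32, 1, 1)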
11 changes: 8 additions & 3 deletions genrl/agents/deep/ppo1/ppo1.py
@@ -66,9 +66,13 @@ def _create_model(self):
self.env, self.network
)
if isinstance(self.network, str):
self.ac = get_model("ac", self.network)(
arch = self.network
if self.shared_layers is not None:
arch += "s"
self.ac = get_model("ac", arch)(
state_dim,
action_dim,
shared_layers=self.shared_layers,
policy_layers=self.policy_layers,
value_layers=self.value_layers,
val_type="V",
@@ -79,8 +83,9 @@
else:
self.ac = self.network.to(self.device)

self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), lr=self.lr_policy)
self.optimizer_value = opt.Adam(self.ac.critic.parameters(), lr=self.lr_value)
actor_params, critic_params = self.ac.get_params()
self.optimizer_policy = opt.Adam(actor_params, lr=self.lr_policy)
self.optimizer_value = opt.Adam(critic_params, lr=self.lr_value)

def select_action(
self, state: torch.Tensor, deterministic: bool = False
16 changes: 7 additions & 9 deletions genrl/agents/deep/sac/sac.py
@@ -75,8 +75,10 @@ def _create_model(self, **kwargs) -> None:
state_dim, action_dim, discrete, _ = get_env_properties(
self.env, self.network
)

self.ac = get_model("ac", self.network + "12")(
arch = self.network + "12"
if self.shared_layers is not None:
arch += "s"
self.ac = get_model("ac", arch)(
state_dim,
action_dim,
policy_layers=self.policy_layers,
@@ -91,13 +93,9 @@
self.model = self.network

self.ac_target = deepcopy(self.ac)

self.critic_params = list(self.ac.critic1.parameters()) + list(
self.ac.critic2.parameters()
)

self.optimizer_value = opt.Adam(self.critic_params, self.lr_value)
self.optimizer_policy = opt.Adam(self.ac.actor.parameters(), self.lr_policy)
actor_params, critic_params = self.ac.get_params()
self.optimizer_value = opt.Adam(critic_params, self.lr_value)
self.optimizer_policy = opt.Adam(actor_params, self.lr_policy)

if self.entropy_tuning:
self.target_entropy = -torch.prod(
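For the twin-critic ("12") variants used by SAC and TD3, ``get_params()`` presumably concatenates the two critics'
parameters, which is what the deleted lines did by hand. A sketch of that split, with illustrative attribute names:

.. code-block:: python

    import itertools

    # Sketch only: twin-critic parameter split behind get_params().
    def get_params(self):
        actor_params = self.actor.parameters()
        critic_params = itertools.chain(
            self.critic1.parameters(), self.critic2.parameters()
        )
        return actor_params, critic_params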
18 changes: 8 additions & 10 deletions genrl/agents/deep/td3/td3.py
@@ -67,10 +67,13 @@ def _create_model(self) -> None:
)

if isinstance(self.network, str):
# Below, the "12" corresponds to the Single Actor, Double Critic network architecture
self.ac = get_model("ac", self.network + "12")(
arch = self.network + "12"
if self.shared_layers is not None:
arch += "s"
self.ac = get_model("ac", arch)(
state_dim,
action_dim,
shared_layers=self.shared_layers,
policy_layers=self.policy_layers,
value_layers=self.value_layers,
val_type="Qsa",
@@ -85,14 +88,9 @@
)

self.ac_target = deepcopy(self.ac)

self.critic_params = list(self.ac.critic1.parameters()) + list(
self.ac.critic2.parameters()
)
self.optimizer_value = torch.optim.Adam(self.critic_params, lr=self.lr_value)
self.optimizer_policy = torch.optim.Adam(
self.ac.actor.parameters(), lr=self.lr_policy
)
actor_params, critic_params = self.ac.get_params()
self.optimizer_value = torch.optim.Adam(critic_params, lr=self.lr_value)
self.optimizer_policy = torch.optim.Adam(actor_params, lr=self.lr_policy)

def update_params(self, update_interval: int) -> None:
"""Update parameters of the model
1 change: 1 addition & 0 deletions genrl/core/__init__.py
@@ -1,6 +1,7 @@
from genrl.core.actor_critic import MlpActorCritic, get_actor_critic_from_name # noqa
from genrl.core.bandit import Bandit, BanditAgent
from genrl.core.base import BaseActorCritic # noqa
from genrl.core.buffers import HERWrapper # noqa
from genrl.core.buffers import PrioritizedBuffer # noqa
from genrl.core.buffers import PrioritizedReplayBufferSamples # noqa
from genrl.core.buffers import ReplayBuffer # noqa