Skip to content

Commit

Permalink
Final refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
aPovidlo committed May 21, 2024
1 parent 09570e6 commit 78697b0
Show file tree
Hide file tree
Showing 28 changed files with 758 additions and 564,786 deletions.
192 changes: 0 additions & 192 deletions rl_core/agent/agent.py

This file was deleted.

3 changes: 2 additions & 1 deletion rl_core/agent/decision_transformer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import io

import numpy as np
import torch
from torch import nn
from torchinfo import summary

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class EmbeddingLayer(nn.Module):
def __init__(self, input_dim, embed_dim):
super().__init__()
Expand All @@ -18,6 +18,7 @@ def forward(self, x, pos_embedding):


class DecisionTransformer(nn.Module):
""" https://arxiv.org/abs/2106.01345 """
metadata = {'name': 'DecisionTransformer'}

def __init__(self, state_dim, action_dim, max_length, embed_dim, num_heads, num_layers, dim_feedforward=2048, device=DEVICE):
Expand Down
2 changes: 1 addition & 1 deletion rl_core/agent/dqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def load(self, path: str):


class DQN:
""" https://arxiv.org/abs/1312.5602 """
metadata = {'name': 'DQN'}

def __init__(self, state_dim, action_dim, hidden_dim=512, gamma=0.01, lr=1e-4, batch_size=64, eps_decrease=1e-6, eps_min=1e-3, device='cuda'):
Expand Down Expand Up @@ -121,4 +122,3 @@ def create_log_report(self, log_dir):
file.write('- PI MODEL -\n')
q_function = str(summary(self.q_function, (1, self.state_dim), verbose=0))
file.write(f'{q_function}')

13 changes: 8 additions & 5 deletions rl_core/agent/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,18 @@ def entropy(self):

return -reduce(p_log_p, "b a -> b", "sum", b=self.batch, a=self.nb_action)


class PPO(nn.Module):
""" https://arxiv.org/abs/1707.06347 """
metadata = {'name': 'PPO'}

def __init__(self,
state_dim: int, action_dim: int, hidden_dim: int = 256,
gamma: float = 0.99, epsilon: float = 0.2, tau: float = 0.25,
batch_size: int = 10, epoch_n: int = 3,
pi_lr: float = 3e-5, v_lr: float = 1e-2, device: str = 'cpu'
):
state_dim: int, action_dim: int, hidden_dim: int = 512,
gamma: float = 0.99, epsilon: float = 0.2, tau: float = 1,
batch_size: int = 32, epoch_n: int = 10,
pi_lr: float = 1e-4, v_lr: float = 1e-4,
device: str = 'cpu'
):
super().__init__()

self.state_dim = state_dim
Expand Down
Loading

0 comments on commit 78697b0

Please sign in to comment.