Skip to content

Commit

Permalink
Perhaps better naming
Browse files Browse the repository at this point in the history
Rename networks.
  • Loading branch information
jupyter31 committed Jul 7, 2024
1 parent 23b169b commit 99aea4f
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 21 deletions.
12 changes: 6 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from q_learning.utils import read_config
from q_learning.preprocess import greyscale
from q_learning.environment import PreproWrapper, MaxPoolSkipEnv, EnvTest
from q_learning.network import DQNLinear, DQNDeepMind, LinearExploration, LinearSchedule
from q_learning.network import Linear, DQN, LinearExploration, LinearSchedule

"""
This script lets us run either a deep Q-network or a linear approximation, as selected by a custom config file.
Expand All @@ -26,7 +26,7 @@ def run():
parser.add_argument(
"--config_filename",
help="The name of the config file in the config/ directory to be used for model training.",
default="test_dqn_linear.yml",
default="test_linear.yml",
)

args = parser.parse_args()
Expand Down Expand Up @@ -54,7 +54,7 @@ def run():
)

# train model
model = DQNDeepMind(env, config)
model = DQN(env, config)
model.run(exp_schedule, lr_schedule)

elif config["model"] == "linear":
Expand All @@ -76,7 +76,7 @@ def run():
)

# train model
model = DQNLinear(env, config)
model = Linear(env, config)
model.run(exp_schedule, lr_schedule)

else:
Expand Down Expand Up @@ -115,11 +115,11 @@ def run():
)

if config["model"] == "dqn":
model = DQNDeepMind(env, config)
model = DQN(env, config)
model.run(exp_schedule, lr_schedule)

elif config["model"] == "linear":
model = DQNLinear(env, config)
model = Linear(env, config)
model.run(exp_schedule, lr_schedule)

else:
Expand Down
File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions q_learning/network/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .dqn_linear_approximation import DQNLinear
from .dqn_deepmind import DQNDeepMind
from .linear import Linear
from .dqn import DQN
from .schedule import LinearExploration, LinearSchedule
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch
import torch.nn as nn
from .dqn_abstract import DQN
from .dqn_abstract import AbstractDQN


class DQNDeepMind(DQN):
class DQN(AbstractDQN):
"""
Implementation of DeepMind's Nature paper:
https://storage.googleapis.com/deepmind-data/assets/papers/DeepMindNature14236Paper.pdf
Expand Down
2 changes: 1 addition & 1 deletion q_learning/network/dqn_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from q_learning.network.qn import QN


class DQN(QN, ABC):
class AbstractDQN(QN, ABC):
""" Abstract class for Deep Q Network """
def __init__(self, env, config, logger=None):
self.q_network = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import torch
import torch.nn as nn

from .dqn_abstract import DQN
from .dqn_abstract import AbstractDQN


class DQNLinear(DQN):
class Linear(AbstractDQN):
"""
We represent the Q function as a linear approximation Q_\theta(s,a) = \theta^T \delta(s,a),
where [\delta(s,a)]_{s',a'} = 1 iff s' = s and a' = a.
Expand Down
16 changes: 8 additions & 8 deletions test/test_q_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch.nn as nn

from q_learning.environment import EnvTest
from q_learning.network import DQNDeepMind, DQNLinear, LinearSchedule, LinearExploration
from q_learning.network import DQN, Linear, LinearSchedule, LinearExploration
from utils import read_config


Expand Down Expand Up @@ -34,11 +34,11 @@ def test_eps_out_of_range(self):
class TestLinearDQN(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.linear_config = read_config('test_dqn_linear.yml')
self.linear_config = read_config('test_linear.yml')

def test_config(self):
env = EnvTest((5, 5, 1))
model = DQNLinear(env, self.linear_config)
model = Linear(env, self.linear_config)
state_shape = list(env.observation_space.shape)
img_height, img_width, n_channels = state_shape
num_actions = env.action_space.n
Expand Down Expand Up @@ -71,7 +71,7 @@ def test_config(self):

def test_loss(self):
env = EnvTest((5, 5, 1))
model = DQNLinear(env, self.linear_config)
model = Linear(env, self.linear_config)
state_shape = list(env.observation_space.shape)
img_height, img_width, n_channels = state_shape
num_actions = env.action_space.n
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_loss(self):

def test_optimizer(self):
env = EnvTest((5, 5, 1))
model = DQNLinear(env, self.linear_config)
model = Linear(env, self.linear_config)
state_shape = list(env.observation_space.shape)
img_height, img_width, n_channels = state_shape
num_actions = env.action_space.n
Expand Down Expand Up @@ -139,18 +139,18 @@ def test_run(self):
)

# train model
model = DQNLinear(env, self.linear_config)
model = Linear(env, self.linear_config)
model.run(exp_schedule, lr_schedule)


class TestDeepMindDQN(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dqn_deepmind_config = read_config('test_dqn_deepmind.yml')
self.dqn_deepmind_config = read_config('test_dqn.yml')

def test_input_output_shapes(self):
env = EnvTest((80, 80, 1))
model = DQNDeepMind(env, self.dqn_deepmind_config)
model = DQN(env, self.dqn_deepmind_config)

state_shape = list(env.observation_space.shape)
img_height, img_width, n_channels = state_shape
Expand Down

0 comments on commit 99aea4f

Please sign in to comment.