-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial DQN and Linear approximation agents in test environment.
Create test environment and implement an agent performing Q-learning with linear approximation of Q function.
- Loading branch information
0 parents
commit 23b169b
Showing
23 changed files
with
2,060 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
.idea/ | ||
**/__pycache__/ | ||
results/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
name: dqn_cuda | ||
channels: | ||
- defaults | ||
- conda-forge | ||
- pytorch | ||
dependencies: | ||
- ffmpeg=4.2.2 | ||
- moviepy=1.0.3 | ||
- notebook=6.5.4 | ||
- pip=23.0.1 | ||
- python=3.8.10 | ||
- pyyaml=6.0 | ||
- scipy=1.10.1 | ||
- tensorboard=2.10.0 | ||
- tk=8.6.12 | ||
- pip: | ||
- ale-py==0.8.1 | ||
- autorom[accept-rom-license]==0.6.1 | ||
- cmake==3.26.3; sys_platform != "win32" | ||
- gym[atari]==0.26.2 | ||
- matplotlib==3.7.1 | ||
- numpy==1.24.3 | ||
- pyglet==2.0.5 | ||
- torch==2.0.0+cu117 | ||
- wget==3.2; sys_platform != "win32" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import sys | ||
import argparse | ||
import gym | ||
from q_learning.utils import read_config | ||
from q_learning.preprocess import greyscale | ||
from q_learning.environment import PreproWrapper, MaxPoolSkipEnv, EnvTest | ||
from q_learning.network import DQNLinear, DQNDeepMind, LinearExploration, LinearSchedule | ||
|
||
""" | ||
This script lets us run a deep Q-network or a linear approximation agent according to a custom config file. | ||
(Configurations are specified in the config/ folder.) | ||
Results, log and recording of the agent are stored in the results folder. | ||
We can monitor the progress of the agent with Tensorboard: | ||
To launch tensorboard (default port is 6006): | ||
>tensorboard --logdir=results/ --host 0.0.0.0 | ||
""" | ||
|
||
|
||
def run():
    """Parse the command line, load the config and train the selected agent.

    The config file named by ``--config_filename`` is loaded with
    ``read_config``. ``config["env"]["env_name"]`` selects the environment
    ("test_environment" or "ALE/Pong-v5") and ``config["model"]`` selects the
    agent ("dqn" -> ``DQNDeepMind``, "linear" -> ``DQNLinear``).

    Exits via ``sys.exit`` with an explanatory message when either the
    environment name or the model name is not one of the supported values.
    """
    parser = argparse.ArgumentParser(
        description="A program to run DQN training",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--config_filename",
        help="The name of the config file in the config/ directory to be used for model training.",
        default="test_dqn_linear.yml",
    )
    args = parser.parse_args()

    config = read_config(args.config_filename)

    env_name = config["env"]["env_name"]
    if env_name not in ("test_environment", "ALE/Pong-v5"):
        sys.exit(
            "Incorrectly specified environment, config['env']['env_name'] should either be "
            "'ALE/Pong-v5' or 'test_environment'."
        )

    # Validate the model up front so a typo fails before any environment
    # (in particular the Atari emulator) is constructed.
    model_classes = {"dqn": DQNDeepMind, "linear": DQNLinear}
    model_name = config["model"]
    model_cls = model_classes.get(model_name)
    if model_cls is None:
        sys.exit(
            "Incorrectly specified model, config['model'] should either be 'dqn' or 'linear'."
        )

    hyper_params = config["hyper_params"]

    if env_name == "test_environment":
        # The test environment uses a different observation shape per model.
        test_shapes = {"dqn": (80, 80, 3), "linear": (5, 5, 1)}
        env = EnvTest(test_shapes[model_name])
    else:
        # create env
        env = gym.make(
            env_name,
            frameskip=(2, 5),
            full_action_space=False,
            render_mode=config["env"]["render_mode"],
        )
        env = MaxPoolSkipEnv(env, skip=hyper_params["skip_frame"])
        env = PreproWrapper(
            env,
            prepro=greyscale,
            shape=(80, 80, 1),
            overwrite_render=config["env"]["overwrite_render"],
        )

    # exploration strategy
    exp_schedule = LinearExploration(
        env,
        hyper_params["eps_begin"],
        hyper_params["eps_end"],
        hyper_params["eps_nsteps"],
    )

    # learning rate schedule
    lr_schedule = LinearSchedule(
        hyper_params["lr_begin"],
        hyper_params["lr_end"],
        hyper_params["lr_nsteps"],
    )

    # train model
    model = model_cls(env, config)
    model.run(exp_schedule, lr_schedule)
|
||
|
||
# Only start training when executed as a script, not when imported.
if __name__ == "__main__":
    run()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
model: "dqn" | ||
|
||
env: | ||
env_name: "test_environment" | ||
overwrite_render: True | ||
record: False | ||
high: 255. | ||
|
||
model_training: | ||
num_episodes_test: 20 | ||
grad_clip: True | ||
clip_val: 10 | ||
saving_freq: 5000 | ||
log_freq: 50 | ||
eval_freq: 1000 | ||
soft_epsilon: 0 | ||
device: "gpu" # cpu/gpu/mps | ||
compile: False | ||
compile_mode: "default" | ||
|
||
hyper_params: | ||
nsteps_train: 10000 | ||
batch_size: 32 | ||
buffer_size: 1000 | ||
target_update_freq: 500 | ||
gamma: 0.99 | ||
learning_freq: 4 | ||
state_history: 4 | ||
lr_begin: 0.005 | ||
lr_end: 0.001 | ||
lr_nsteps: 5000 | ||
eps_begin: 1 | ||
eps_end: 0.01 | ||
eps_nsteps: 8000 | ||
learning_start: 200 | ||
|
||
output: | ||
output_path: &output_path "results/dqn_deepmind/" | ||
model_output: !join [*output_path, "model.weights.pt"] | ||
log_path: !join [*output_path, "log.txt"] | ||
plot_output: !join [*output_path, "scores.png"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
model: "linear" | ||
|
||
env: | ||
env_name: "test_environment" | ||
overwrite_render: True | ||
record: False # must be set to false! | ||
high: 255. | ||
|
||
model_training: | ||
num_episodes_test: 20 | ||
grad_clip: True | ||
clip_val: 10 | ||
saving_freq: 5000 | ||
log_freq: 50 | ||
eval_freq: 1000 | ||
soft_epsilon: 0 | ||
device: "cpu" # cpu/gpu/mps | ||
compile: False | ||
compile_mode: "default" | ||
|
||
hyper_params: | ||
nsteps_train: 10000 | ||
batch_size: 16 | ||
buffer_size: 1000 | ||
target_update_freq: 500 | ||
gamma: 0.99 | ||
learning_freq: 4 | ||
state_history: 4 | ||
lr_begin: 0.005 | ||
lr_end: 0.001 | ||
lr_nsteps: 5000 | ||
eps_begin: 1 | ||
eps_end: 0.01 | ||
eps_nsteps: 5000 | ||
learning_start: 200 | ||
|
||
output: | ||
output_path: &output_path "results/linear/" | ||
model_output: !join [*output_path, "model.weights.pt"] | ||
log_path: !join [*output_path, "log.txt"] | ||
plot_output: !join [*output_path, "scores.png"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .test_env import EnvTest | ||
from .wrappers import MaxPoolSkipEnv, PreproWrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import numpy as np | ||
|
||
class ActionSpace:
    """Minimal discrete action space with actions labelled ``0 .. n - 1``."""

    def __init__(self, n):
        """Store the number of available actions.

        Args:
            n: Number of discrete actions.
        """
        self.n = n

    def sample(self):
        """Return a random action index drawn from ``[0, n)`` with NumPy's global RNG."""
        return np.random.randint(0, self.n)
|
||
|
||
class ObservationSpace:
    """Four fixed random observations, one per state.

    The observation for state ``i`` is an array of the requested shape filled
    with ``uint16`` integers drawn from ``[100 * i, 100 * i + 50)``, so the
    value ranges of different states never overlap.
    """

    def __init__(self, shape):
        """Draw the four observations.

        Args:
            shape: Shape of every observation array.
        """
        self.shape = shape
        # Drawn in order state_0 .. state_3 so that seeding NumPy's global RNG
        # reproduces the same observations as before.
        self.states = [
            np.random.randint(low, low + 50, shape, dtype=np.uint16)
            for low in (0, 100, 200, 300)
        ]
        self.state_0, self.state_1, self.state_2, self.state_3 = self.states
|
||
|
||
class EnvTest:
    """
    Lightweight test environment.
    Attribution: Adapted from Igor Gitman, CMU / Karan Goel Modified

    The environment has 4 states and 5 actions. Actions 0-3 move to the state
    with the same index; action 4 keeps the current state. The reward is the
    entry of ``rewards`` for the state reached, multiplied by -10 when the
    previous step ended in state 2. An episode is reported as done once 5
    steps have been taken.
    """

    def __init__(self, shape=(84, 84, 3)):
        """Create the environment.

        Args:
            shape: Shape of the observation arrays handed to ``ObservationSpace``.
        """
        # 4 states
        self.rewards = [0.2, -0.1, 0.0, -0.3]
        self.cur_state = 0
        self.num_iters = 0
        # True when the previous step ended in state 2.
        self.was_in_second = False
        self.action_space = ActionSpace(5)
        self.observation_space = ObservationSpace(shape)

    def reset(self):
        """Return to state 0, clear the step counter and return the initial observation."""
        self.cur_state = 0
        self.num_iters = 0
        self.was_in_second = False
        return self.observation_space.states[self.cur_state]

    def step(self, action):
        """Apply ``action`` and advance the environment by one step.

        Args:
            action: Integer in ``[0, 4]``; 0-3 select the next state, 4 stays put.

        Returns:
            A 5-tuple ``(observation, reward, done, False, info)`` where
            ``done`` is ``True`` once at least 5 steps were taken and ``info``
            is ``{"ale.lives": 0}``. The layout looks like the gym 0.26 step
            tuple (terminated / truncated) -- confirm against callers.
        """
        assert 0 <= action <= 4
        self.num_iters += 1
        if action < 4:
            self.cur_state = action
        reward = self.rewards[self.cur_state]
        # Arriving anywhere right after state 2 flips and scales the reward.
        if self.was_in_second:
            reward *= -10
        self.was_in_second = self.cur_state == 2
        return (
            self.observation_space.states[self.cur_state],
            reward,
            self.num_iters >= 5,
            False,
            {"ale.lives": 0},
        )

    def render(self):
        """Print the index of the current state."""
        print(self.cur_state)
Oops, something went wrong.