# bcq.py -- forked from ryanxhr/Discrete_IVR
import argparse
import d3rlpy
import wandb
import numpy as np
from d3rlpy.algos import IQL, DiscreteCQL, DiscreteSAC, DiscreteBCQ
from d3rlpy.datasets import get_atari
from d3rlpy.metrics.scorer import evaluate_on_environment
from d3rlpy.metrics.scorer import td_error_scorer
from d3rlpy.metrics.scorer import discounted_sum_of_advantage_scorer
from d3rlpy.metrics.scorer import average_value_estimation_scorer
from sklearn.model_selection import train_test_split
import os
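# Authenticate with Weights & Biases before running this script, e.g. set the
# WANDB_API_KEY environment variable in your shell or run `wandb login`.
# Never commit the API key itself to the repository.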
def main(args):
    """Train a DiscreteBCQ agent on a d3rlpy Atari dataset.

    Opens a Weights & Biases run, loads the dataset and environment named by
    ``args.dataset``, splits the dataset into training and evaluation parts,
    and fits the algorithm for 200 epochs with an environment scorer.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``dataset``, ``ratio``, ``seed``, ``q_func`` and ``gpu``.
            ``ratio`` is the fraction of the dataset kept for training; the
            remaining ``1 - ratio`` is passed to ``fit`` as evaluation
            episodes.
    """
    # NOTE(review): W&B credentials are presumably expected to be configured
    # in the environment before this call -- confirm for your setup.
    run_name = f'BCQ_{args.dataset}_{args.ratio}'
    wandb.init(project='atari', entity='louis_t0', name=run_name)

    dataset, env = get_atari(args.dataset)

    # The seed is applied before the split and before the algorithm is built.
    d3rlpy.seed(args.seed)

    held_out_fraction = 1 - args.ratio
    train_episodes, test_episodes = train_test_split(
        dataset,
        test_size=held_out_fraction,
    )

    algo_options = {
        'n_frames': 4,  # frame stacking
        'q_func_factory': args.q_func,
        'scaler': 'pixel',
        'use_gpu': args.gpu,
    }
    bcq = DiscreteBCQ(**algo_options)

    # Only the environment rollout scorer is active. The other scorers
    # imported at the top of the file could be registered here as well:
    #   'td_error': td_error_scorer
    #   'discounted_advantage': discounted_sum_of_advantage_scorer
    #   'value_scale': average_value_estimation_scorer
    scorers = {
        'environment': evaluate_on_environment(env, epsilon=0.05),
    }

    bcq.fit(
        train_episodes,
        eval_episodes=test_episodes,
        n_epochs=200,
        scorers=scorers,
    )
if __name__ == '__main__':
    # Script entry point: declare the command-line options, parse them, and
    # hand the resulting namespace to main().
    option_specs = (
        ('--dataset', {'type': str, 'default': 'breakout-mixed-v0'}),
        ('--seed', {'type': int, 'default': 0}),
        ('--ratio', {'type': float, 'default': 0.1}),
        ('--q-func', {
            'type': str,
            'default': 'mean',
            'choices': ['mean', 'qr', 'iqn', 'fqf'],
        }),
        ('--gpu', {'default': 0, 'type': int}),
    )

    parser = argparse.ArgumentParser()
    # Options are registered in the order listed above.
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()
    main(args)