-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathlearner.py
128 lines (104 loc) · 5.26 KB
/
learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import multiprocessing
import pickle
import subprocess
from argparse import ArgumentParser
from itertools import count
from multiprocessing import Process
import horovod.tensorflow.keras as hvd
import tensorflow as tf
import zmq
from pyarrow import deserialize
from tensorflow.keras.backend import set_session
from common import init_components, load_yaml_config, save_yaml_config, create_experiment_dir
from core.mem_pool import MemPoolManager, MultiprocessingMemPool
from utils.cmdline import parse_cmdline_kwargs
# --- Distributed-training setup (must run at import time, before any TF session use) ---
# Horovod: initialize Horovod.
hvd.init()
# Horovod: pin GPU to be used to process local rank (one GPU per process)
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
set_session(tf.Session(config=config))
# Broadcast rank-0 variables to all workers at start of training so replicas agree.
# NOTE(review): `callbacks` is defined here but not visibly passed to any fit/learn
# call in this file — presumably consumed by the agent elsewhere; confirm.
callbacks = [hvd.callbacks.BroadcastGlobalVariablesCallback(0)]
# --- Command-line interface ---
parser = ArgumentParser()
parser.add_argument('--alg', type=str, default='ppo', help='The RL algorithm')
parser.add_argument('--env', type=str, default='CartPole-v1', help='The game environment')
# Parsed as float (so '2e5' works on the command line), converted to int in main().
parser.add_argument('--num_steps', type=float, default=2e5, help='The number of total training steps')
parser.add_argument('--data_port', type=int, default=5000, help='Learner server port to receive training data')
parser.add_argument('--param_port', type=int, default=5001, help='Learner server to publish model parameters')
parser.add_argument('--model', type=str, default='acmlp', help='Training model')
parser.add_argument('--pool_size', type=int, default=4000, help='The max length of data pool')
# NOTE(review): help text says the value "can be fractional", but type=int rejects
# fractions — either the help or the type is wrong; confirm intended semantics.
parser.add_argument('--training_freq', type=int, default=1,
help='How many receptions of new data are between each training, '
'which can be fractional to represent more than one training per reception')
parser.add_argument('--keep_training', action='store_true', help="No matter whether new data is received recently, "
"keep training as long as the data is enough and"
"ignore '--training_freq'")
parser.add_argument('--batch_size', type=int, default=4000, help='The batch size for training')
parser.add_argument('--exp_path', type=str, default=None, help='Directory to save logging data and config file')
parser.add_argument('--config', type=str, default=None, help='The YAML configuration file')
parser.add_argument('--record_throughput_interval', type=int, default=10,
help='The time interval between each throughput record')
def main():
    """Run the learner: receive experience from actors, train, publish weights.

    Side effects: binds a ZeroMQ PUB socket on --param_port, spawns two child
    processes (data receiver and throughput logger), creates an experiment
    directory, and loops forever training the agent.
    """
    # Parse input parameters; unknown args are forwarded to component init.
    args, unknown_args = parser.parse_known_args()
    args.num_steps = int(args.num_steps)  # --num_steps is parsed as float (e.g. 2e5)
    unknown_args = parse_cmdline_kwargs(unknown_args)
    # Load config file (overrides/augments CLI args for the 'learner' section)
    load_yaml_config(args, 'learner')
    # Expose socket to actor(s): PUB socket broadcasting model weights.
    context = zmq.Context()
    weights_socket = context.socket(zmq.PUB)
    weights_socket.bind(f'tcp://*:{args.param_port}')
    env, agent = init_components(args, unknown_args)
    # Save configuration file into the freshly created experiment directory.
    # NOTE(review): assumes args.exp_path is path-like after create_experiment_dir
    # (the `/` operator is used) — confirm it sets a pathlib.Path.
    create_experiment_dir(args, 'LEARNER-')
    save_yaml_config(args.exp_path / 'config.yaml', args, 'learner', agent)
    # Record commit hash for reproducibility.
    # NOTE(review): the str(...) wrapper is redundant (decode already returns str).
    with open(args.exp_path / 'hash', 'w') as f:
        f.write(str(subprocess.run('git rev-parse HEAD'.split(), stdout=subprocess.PIPE).stdout.decode('utf-8')))
    # Variables to control the frequency of training: the receiver process
    # increments num_receptions under the condition; the training loop waits on it.
    receiving_condition = multiprocessing.Condition()
    num_receptions = multiprocessing.Value('i', 0)
    # Start memory pool in another process (manager-hosted so both processes share it).
    manager = MemPoolManager()
    manager.start()
    mem_pool = manager.MemPool(capacity=args.pool_size)
    Process(target=recv_data,
            args=(args.data_port, mem_pool, receiving_condition, num_receptions, args.keep_training)).start()
    # Print throughput statistics
    Process(target=MultiprocessingMemPool.record_throughput, args=(mem_pool, args.record_throughput_interval)).start()
    # NOTE(review): count(1) never terminates — args.num_steps is passed to
    # update_training but nothing breaks the loop; confirm shutdown is external.
    for step in count(1):
        # Do some updates (e.g. schedules) based on current step / total steps.
        agent.update_training(step, args.num_steps)
        if len(mem_pool) >= args.batch_size:
            if args.keep_training:
                # Train on every pass regardless of fresh data.
                agent.learn(mem_pool.sample(size=args.batch_size))
            else:
                # Wait until enough new data receptions have occurred, then sample.
                with receiving_condition:
                    while num_receptions.value < args.training_freq:
                        receiving_condition.wait()
                    data = mem_pool.sample(size=args.batch_size)
                    num_receptions.value -= args.training_freq
                # Training (outside the lock so the receiver isn't blocked).
                agent.learn(data)
            # Sync weights to actor — only rank 0 publishes to avoid duplicates.
            if hvd.rank() == 0:
                weights_socket.send(pickle.dumps(agent.get_weights()))
def recv_data(data_port, mem_pool, receiving_condition, num_receptions, keep_training):
    """Receive serialized training data from actors and push it into the pool.

    Binds a ZeroMQ REP socket on `data_port`, acknowledges each message with
    b'200', and stores the deserialized batch in `mem_pool`. Unless
    `keep_training` is set, each reception also bumps `num_receptions` and
    notifies the training loop waiting on `receiving_condition`.
    """
    ctx = zmq.Context()
    sock = ctx.socket(zmq.REP)
    sock.bind(f'tcp://*:{data_port}')
    while True:
        # noinspection PyTypeChecker
        batch: dict = deserialize(sock.recv())
        sock.send(b'200')  # acknowledge receipt to the actor
        if keep_training:
            # Frequency bookkeeping is ignored in keep_training mode.
            mem_pool.push(batch)
            continue
        with receiving_condition:
            mem_pool.push(batch)
            num_receptions.value += 1
            receiving_condition.notify()
# Script entry point.
if __name__ == '__main__':
    main()