-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathrunner.py
60 lines (48 loc) · 2.1 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import ray
from model import PolicyNet, QNet
from worker import Worker
from parameter import *
class Runner:
    """Run single episodes with local copies of the policy and Q networks.

    A runner owns one ``PolicyNet`` and one ``QNet`` on the configured
    device. ``job`` loads externally supplied weights into both networks and
    then hands them to a ``Worker`` that plays the episode.
    """

    def __init__(self, meta_agent_id):
        """Create the local networks and move them to the selected device.

        Args:
            meta_agent_id: Identifier of this runner. It is passed to every
                ``Worker`` and echoed back in the info dict built by ``job``.
        """
        self.meta_agent_id = meta_agent_id
        # USE_GPU (from the parameter module) chooses CUDA over CPU.
        self.device = torch.device('cuda') if USE_GPU else torch.device('cpu')
        self.local_network = PolicyNet(INPUT_DIM, EMBEDDING_DIM)
        self.local_q_net = QNet(INPUT_DIM, EMBEDDING_DIM)
        self.local_network.to(self.device)
        self.local_q_net.to(self.device)

    def get_weights(self):
        """Return the state dict of the local policy network."""
        return self.local_network.state_dict()

    def set_policy_net_weights(self, weights):
        """Load ``weights`` (a state dict) into the local policy network."""
        self.local_network.load_state_dict(weights)

    def set_q_net_weights(self, weights1):
        """Load ``weights1`` (a state dict) into the local Q network."""
        self.local_q_net.load_state_dict(weights1)

    @staticmethod
    def _should_save_image(episode_number, gap):
        """Return True when ``episode_number`` is a multiple of ``gap``.

        A falsy ``gap`` (``0`` or ``None``) means image saving is switched
        off; without this guard the modulo would raise ZeroDivisionError.
        """
        return bool(gap) and episode_number % gap == 0

    def do_job(self, episode_number):
        """Play one episode with the current local networks.

        Args:
            episode_number: Index of the episode; also decides whether this
                episode saves images (every ``SAVE_IMG_GAP`` episodes).

        Returns:
            A ``(episode_buffer, perf_metrics)`` tuple taken from the worker
            after it has finished.
        """
        save_img = self._should_save_image(episode_number, SAVE_IMG_GAP)
        worker = Worker(
            self.meta_agent_id,
            self.local_network,
            self.local_q_net,
            episode_number,
            device=self.device,
            save_image=save_img,
            greedy=False,
        )
        worker.work(episode_number)
        return worker.episode_buffer, worker.perf_metrics

    def job(self, weights_set, episode_number):
        """Load the given weights, play one episode and report the outcome.

        Args:
            weights_set: Indexable pair; item 0 is the policy-network state
                dict and item 1 is the Q-network state dict.
            episode_number: Index of the episode to run.

        Returns:
            A ``(job_results, metrics, info)`` tuple, where ``info`` holds
            this runner's id and the episode number.
        """
        print("starting episode {} on metaAgent {}".format(episode_number, self.meta_agent_id))
        # set the local weights to the global weight values from the master network
        self.set_policy_net_weights(weights_set[0])
        self.set_q_net_weights(weights_set[1])

        job_results, metrics = self.do_job(episode_number)

        info = {
            "id": self.meta_agent_id,
            "episode_number": episode_number,
        }
        return job_results, metrics, info
# GPU quantity requested for each actor: NUM_GPU divided by NUM_META_AGENT.
_GPU_SHARE_PER_RUNNER = NUM_GPU / NUM_META_AGENT


@ray.remote(num_cpus=1, num_gpus=_GPU_SHARE_PER_RUNNER)
class RLRunner(Runner):
    """Ray actor version of ``Runner``; adds no behavior of its own."""

    def __init__(self, meta_agent_id):
        """Forward ``meta_agent_id`` unchanged to ``Runner.__init__``."""
        super().__init__(meta_agent_id=meta_agent_id)
if __name__ == '__main__':
    # Smoke test: start Ray, create one remote runner with id 0 and run
    # episode 1 without loading any weights first, then print its metrics.
    ray.init()
    actor = RLRunner.remote(0)
    result_ref = actor.do_job.remote(1)
    _, perf_metrics = ray.get(result_ref)
    print(perf_metrics)