-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrunner.py
73 lines (58 loc) · 2.56 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import numpy as np
import ray
import os
from attention import AttentionNet
from worker import Worker
from parameters import *
from env.task_env import TaskEnv
class Runner(object):
    """Per-meta-agent actor that owns local policy/baseline networks and
    runs simulation episodes on Worker objects."""

    def __init__(self, metaAgentID):
        self.metaAgentID = metaAgentID
        # Pick the compute device once; both local networks live on it.
        self.device = torch.device('cuda') if TrainParams.USE_GPU else torch.device('cpu')
        self.localNetwork = AttentionNet(TrainParams.AGENT_INPUT_DIM,
                                         TrainParams.TASK_INPUT_DIM,
                                         TrainParams.EMBEDDING_DIM)
        self.localNetwork.to(self.device)
        self.localBaseline = AttentionNet(TrainParams.AGENT_INPUT_DIM,
                                          TrainParams.TASK_INPUT_DIM,
                                          TrainParams.EMBEDDING_DIM)
        self.localBaseline.to(self.device)

    def get_weights(self):
        """Return the local policy network's state dict."""
        return self.localNetwork.state_dict()

    def set_weights(self, weights):
        """Load *weights* into the local policy network."""
        self.localNetwork.load_state_dict(weights)

    def set_baseline_weights(self, weights):
        """Load *weights* into the local baseline network."""
        self.localBaseline.load_state_dict(weights)

    def training(self, global_weights, baseline_weights, curr_episode, env_params):
        """Run one training episode and return (experience buffer, perf metrics, info dict)."""
        print("starting episode {} on metaAgent {}".format(curr_episode, self.metaAgentID))
        # Sync the local copies with the master network before rolling out.
        self.set_weights(global_weights)
        self.set_baseline_weights(baseline_weights)
        # Save rendered episode images only every SAVE_IMG_GAP episodes, if enabled.
        save_img = bool(SaverParams.SAVE_IMG and curr_episode % SaverParams.SAVE_IMG_GAP == 0)
        episode_worker = Worker(self.metaAgentID, self.localNetwork, self.localBaseline,
                                curr_episode, self.device, save_img, None, env_params)
        episode_worker.work(curr_episode)
        info = {
            "id": self.metaAgentID,
            "episode_number": curr_episode,
        }
        return episode_worker.experience, episode_worker.perf_metrics, info

    def testing(self, seed=None):
        """Run one baseline evaluation rollout; return (reward, seed, metaAgentID)."""
        eval_worker = Worker(self.metaAgentID, self.localNetwork, self.localBaseline,
                             0, self.device, False, seed)
        return eval_worker.baseline_test(), seed, self.metaAgentID
# Ray actor wrapper: each RLRunner is a remote process holding one Runner.
# GPU share per actor is NUM_GPU split evenly across NUM_META_AGENT actors.
@ray.remote(num_cpus=1, num_gpus=TrainParams.NUM_GPU / TrainParams.NUM_META_AGENT)
class RLRunner(Runner):
    """Remote (Ray) version of Runner; inherits all behavior unchanged."""
    def __init__(self, metaAgentID):
        super().__init__(metaAgentID)
if __name__ == '__main__':
    # Smoke test: spin up one remote runner and execute a single rollout.
    ray.init()
    runner = RLRunner.remote(0)
    # BUG FIX: Runner defines no `singleThreadedJob` method, so the original
    # call raised AttributeError. Use the existing `testing` remote method,
    # which returns (reward, seed, metaAgentID) — out[1] below is the seed.
    job_id = runner.testing.remote(1)
    out = ray.get(job_id)
    print(out[1])