sc1_train_hierarchical_loadunit.py
import torch
import gym
from gym_starcraft.simple_battle_env import SimpleBattleEnv, Unit_State
from copy import deepcopy
from itertools import count
from Model_hierarchical import *
from config import *
# Explicit imports for names used below; these may also be re-exported by the
# wildcard imports above.
import os
import numpy as np
import torch.optim as optim
def load_unit(ddpg_agent, saved_folder, episode):
    # Restore the unit-level actor/critic networks (and their targets) from the
    # checkpoints saved at the given episode, move them to the GPU if one is
    # configured, and rebuild their optimizers.
    ddpg_agent.unit_actor = torch.load(os.path.join(saved_folder, 'unit_actor_{}.mod'.format(episode)))
    ddpg_agent.unit_actor_target = torch.load(os.path.join(saved_folder, 'unit_actor_{}.mod'.format(episode)))
    ddpg_agent.unit_critic = torch.load(os.path.join(saved_folder, 'unit_critic_{}.mod'.format(episode)))
    ddpg_agent.unit_critic_target = torch.load(os.path.join(saved_folder, 'unit_critic_{}.mod'.format(episode)))
    if ddpg_agent.config.GPU >= 0:
        ddpg_agent.unit_actor.cuda(device=ddpg_agent.config.GPU)
        ddpg_agent.unit_actor_target.cuda(device=ddpg_agent.config.GPU)
        ddpg_agent.unit_critic.cuda(device=ddpg_agent.config.GPU)
        ddpg_agent.unit_critic_target.cuda(device=ddpg_agent.config.GPU)
    ddpg_agent.unit_critic_optimizer = optim.Adam(ddpg_agent.unit_critic.parameters(), lr=ddpg_agent.config.CRITIC_LR)
    ddpg_agent.unit_actor_optimizer = optim.Adam(ddpg_agent.unit_actor.parameters(), lr=ddpg_agent.config.ACTOR_LR)
    return ddpg_agent
config = DefaultConfig()
np.random.seed(config.RANDOM_SEED)
torch.manual_seed(config.RANDOM_SEED)
if config.GPU >= 0:
    torch.cuda.manual_seed(config.RANDOM_SEED)

# for debug
# from hyperboard import Agent
# HBagent = Agent(username='jlb', password='123', address='127.0.0.1', port=5002)
#
# hp = deepcopy(config.todict())
# hp['mode'] = 'test_reward'
# test_record = HBagent.register(hp, 'reward', overwrite=True)
# hp['mode'] = 'train_reward'
# train_r = HBagent.register(hp, 'reward', overwrite=True)

env = SimpleBattleEnv(config.ip, config.port, config.MYSELF_NUM, config.ENEMY_NUM, config.ACTION_DIM,
                      config.DISTANCE_FACTOR, config.POSITION_RANGE,
                      config.SCREEN_BOX, config.DIE_REWARD, config.HEALTH_REWARD_WEIGHT,
                      config.DONE_REWARD_WEIGHT, config.MY_HEALTH_WEIGHT, config.ENEMY_HEALTH_WEIGHT,
                      config.FRAME_SKIP, config.MAX_STEP)
env.seed(config.RANDOM_SEED)

ddpg_agent = DDPG(env, config=config)

# Resume training from previously saved unit-level checkpoints.
folder = 'HierarchicalNet(10Marines_vs_13Zerglings.scm)(01-07_14:58)hierarchical_3'
episode_num = 300
ddpg_agent = load_unit(ddpg_agent, folder, episode_num)
for episode in count(1):
    print('\n', episode, ddpg_agent.epsilon)
    obs = env.reset()
    state = ddpg_agent.extract_state(obs)
    cl_total, al_total, qe_total, qt_total = 0, 0, 0, 0
    rs = []
    for step in range(config.MAX_STEP):
        # Act with exploration noise (epsilon is decayed inside select_action),
        # store the transition, and train both the unit and commander networks.
        action, command = ddpg_agent.select_action(state, decay_e=True)
        next_obs, reward, done, info = env.step(action)
        rs.append(np.asarray(reward))
        next_state = ddpg_agent.extract_state(next_obs)
        ddpg_agent.append_memory(state, command, action, next_state, reward, not done)
        ddpg_agent.train_unit()
        ddpg_agent.train_commander()
        if done:
            # Episode finished: compute per-unit discounted returns over the
            # collected rewards and log the training statistics.
            qs = []
            q = np.zeros((config.MYSELF_NUM))
            total_reward = 0
            for r in rs[::-1]:
                q = r + config.GAMMA * q
                total_reward += r.sum() / config.MYSELF_NUM
                qs.append(q)
            qs = np.asarray(qs)
            q_mean = np.mean(qs)
            ddpg_agent.train_record.append((episode, total_reward))
            print('memory: {}/{}'.format(ddpg_agent.commander_memory.current_index, ddpg_agent.commander_memory.max_len))
            print('q_mean: ', q_mean)
            print('train_reward', total_reward)
            # HBagent.append(train_r, episode, total_reward)
            break
        state = next_state
    if episode % config.TEST_ITERVAL == 0:
        print('\ntest (no noise)\n')
        test_reward, _1, _2 = ddpg_agent.test(episode, config.TEST_NUM)
        # HBagent.append(test_record, episode, test_reward)
    if episode % config.SAVE_ITERVAL == 0:
        print('\nsave model\n')
        ddpg_agent.save(episode)