Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
FulChou committed Aug 6, 2021
1 parent 4ec7e12 commit 457f754
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 21 deletions.
16 changes: 8 additions & 8 deletions first.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# Single environment instance built from env_name (defined earlier in the file).
env = gym.make(env_name)
# train_envs = gym.make('CartPole-v0')
# test_envs = gym.make('CartPole-v0')

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # GPU index 1 when CUDA is available, otherwise CPU
print(device)
# NOTE(review): this unconditionally replaces the device selected (and printed)
# above, so everything that reads `device` afterwards runs on the CPU.
# Confirm whether this is a temporary debugging override.
device = 'cpu'
# Vectorized environments: DummyVectorEnv receives a list of factories, each of
# which creates a fresh env_name environment (10 for train_envs, 100 for test_envs).
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_name) for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_name) for _ in range(100)])

Expand All @@ -28,7 +30,7 @@ def __init__(self, state_shape, action_shape):

def forward(self, obs, state=None, info=None):
    """Compute logits for a batch of observations.

    Args:
        obs: Batch of observations, either a ``torch.Tensor`` or anything
            ``torch.tensor`` accepts (e.g. a NumPy array or nested list).
            The first dimension is treated as the batch dimension.
        state: Returned unchanged alongside the logits.
        info: Unused; accepted only for call compatibility. Defaults to
            ``None`` instead of a shared mutable ``{}``.

    Returns:
        A ``(logits, state)`` tuple where ``logits`` is the output of
        ``self.model`` applied to the observations flattened to
        ``(batch, -1)``.
    """
    # Convert exactly once. Tensors are moved/cast as well, so inputs that
    # live on another device or carry another dtype do not fail in the model.
    if isinstance(obs, torch.Tensor):
        obs = obs.to(device=device, dtype=torch.float)
    else:
        obs = torch.tensor(obs, device=device, dtype=torch.float)
    batch_size = obs.shape[0]
    # reshape (unlike view) also accepts non-contiguous inputs.
    logits = self.model(obs.reshape(batch_size, -1))
    return logits, state
Expand All @@ -46,16 +48,16 @@ def __init__(self, state_shape, action_shape):

def forward(self, obs, state=None, info=None):
    """Map a batch of observations to logits.

    Args:
        obs: Observations as a ``torch.Tensor`` or any input accepted by
            ``torch.tensor`` (NumPy array, nested list, ...). Dimension 0
            is the batch dimension.
        state: Passed through and returned unchanged.
        info: Ignored; kept so existing callers can still pass it. The
            default is ``None`` rather than a shared mutable ``{}``.

    Returns:
        ``(logits, state)`` where ``logits`` is ``self.model`` evaluated on
        the observations reshaped to ``(batch, -1)``.
    """
    # One conversion only; existing tensors are also moved to `device` and
    # cast to float so mismatched inputs do not fail inside the model.
    if isinstance(obs, torch.Tensor):
        obs = obs.to(device=device, dtype=torch.float)
    else:
        obs = torch.tensor(obs, device=device, dtype=torch.float)
    batch_size = obs.shape[0]
    # reshape handles non-contiguous tensors, which view would reject.
    logits = self.model(obs.reshape(batch_size, -1))
    return logits, state


# Use the space's shape when it is non-empty, otherwise fall back to `.n`.
# The example values in the trailing comments are those of CartPole-v0.
state_shape = env.observation_space.shape or env.observation_space.n # e.g. (4,) for CartPole-v0
action_shape = env.action_space.shape or env.action_space.n # e.g. 2 for CartPole-v0
# NOTE(review): net and net_student are each assigned twice; only the second
# pair, moved to `device`, is still referenced afterwards.
net = TeacherNet(state_shape, action_shape)
net_student = StudentNet(state_shape, action_shape)
net = TeacherNet(state_shape, action_shape).to(device)
net_student = StudentNet(state_shape, action_shape).to(device)

# One Adam optimizer per network, both with learning rate 1e-3.
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
optim_student = torch.optim.Adam(net_student.parameters(), lr=1e-3)
Expand Down Expand Up @@ -85,12 +87,10 @@ def update_student():
sample_size = 10
if len(train_collector.buffer) > sample_size:
batch, indice = train_collector.buffer.sample(sample_size)


# input = Batch(obs=Batch(obs=obs,mask=mask))
teacher = teacher_policy.forward(batch)
student = student_policy.forward(batch)
stds = torch.from_numpy(np.array([1e-6] * len(teacher.logits[0])))
stds = torch.tensor([1e-6] * len(teacher.logits[0]), device=device,dtype=torch.float)
stds = torch.stack([stds for _ in range(len(teacher.logits))])
loss = get_kl([teacher.logits, stds], [student.logits, stds])
student_policy.optim.zero_grad()
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
27 changes: 14 additions & 13 deletions td3.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
# Author: vincent
# code time: 2021/8/5 8:12 PM
# Author: vincent
# code time: 2021/7/26 8:05 PM
import gym
import tianshou as ts
import torch, numpy as np
from tianshou.data import Batch
from torch import nn
from utils import get_kl
# NOTE(review): this CartPole-v0 instance is replaced a few lines below by
# gym.make(env_name) and is never used in between.
env = gym.make('CartPole-v0')

env_name = 'Breakout-v0'
# env = gym.make('CartPole-v0')
env = gym.make(env_name)
# train_envs = gym.make('CartPole-v0')
# test_envs = gym.make('CartPole-v0')
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") # GPU index 1 when CUDA is available, otherwise CPU
print(device, 'env reward:', env.spec.reward_threshold)

# NOTE(review): train_envs / test_envs are each assigned twice. The CartPole-v0
# vector envs built first are discarded in favour of the env_name ones below.
train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])
# Vectorized environments built from env_name: 10 factories for train_envs,
# 100 for test_envs.
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_name) for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_name) for _ in range(100)])

class TeacherNet(nn.Module):
def __init__(self, state_shape, action_shape):
Expand All @@ -27,7 +30,7 @@ def __init__(self, state_shape, action_shape):

def forward(self, obs, state=None, info=None):
    """Compute logits for a batch of observations.

    Args:
        obs: Batch of observations, either a ``torch.Tensor`` or anything
            ``torch.tensor`` accepts (e.g. a NumPy array or nested list).
            The first dimension is treated as the batch dimension.
        state: Returned unchanged alongside the logits.
        info: Unused; accepted only for call compatibility. Defaults to
            ``None`` instead of a shared mutable ``{}``.

    Returns:
        A ``(logits, state)`` tuple where ``logits`` is the output of
        ``self.model`` applied to the observations flattened to
        ``(batch, -1)``.
    """
    # Convert exactly once. Tensors are moved/cast as well, so inputs that
    # live on another device or carry another dtype do not fail in the model.
    if isinstance(obs, torch.Tensor):
        obs = obs.to(device=device, dtype=torch.float)
    else:
        obs = torch.tensor(obs, device=device, dtype=torch.float)
    batch_size = obs.shape[0]
    # reshape (unlike view) also accepts non-contiguous inputs.
    logits = self.model(obs.reshape(batch_size, -1))
    return logits, state
Expand All @@ -45,16 +48,16 @@ def __init__(self, state_shape, action_shape):

def forward(self, obs, state=None, info=None):
    """Map a batch of observations to logits.

    Args:
        obs: Observations as a ``torch.Tensor`` or any input accepted by
            ``torch.tensor`` (NumPy array, nested list, ...). Dimension 0
            is the batch dimension.
        state: Passed through and returned unchanged.
        info: Ignored; kept so existing callers can still pass it. The
            default is ``None`` rather than a shared mutable ``{}``.

    Returns:
        ``(logits, state)`` where ``logits`` is ``self.model`` evaluated on
        the observations reshaped to ``(batch, -1)``.
    """
    # One conversion only; existing tensors are also moved to `device` and
    # cast to float so mismatched inputs do not fail inside the model.
    if isinstance(obs, torch.Tensor):
        obs = obs.to(device=device, dtype=torch.float)
    else:
        obs = torch.tensor(obs, device=device, dtype=torch.float)
    batch_size = obs.shape[0]
    # reshape handles non-contiguous tensors, which view would reject.
    logits = self.model(obs.reshape(batch_size, -1))
    return logits, state


# Use the space's shape when it is non-empty, otherwise fall back to `.n`.
# NOTE(review): the example values in the trailing comments are CartPole-v0's;
# env is built from env_name ('Breakout-v0') here, so they are likely stale.
state_shape = env.observation_space.shape or env.observation_space.n # (4,) for CartPole-v0
action_shape = env.action_space.shape or env.action_space.n # 2 for CartPole-v0
# NOTE(review): net and net_student are each assigned twice; only the second
# pair, moved to `device`, is still referenced afterwards.
net = TeacherNet(state_shape, action_shape)
net_student = StudentNet(state_shape, action_shape)
net = TeacherNet(state_shape, action_shape).to(device)
net_student = StudentNet(state_shape, action_shape).to(device)

# One Adam optimizer per network, both with learning rate 1e-3.
optim = torch.optim.Adam(net.parameters(), lr=1e-3)
optim_student = torch.optim.Adam(net_student.parameters(), lr=1e-3)
Expand Down Expand Up @@ -84,12 +87,10 @@ def update_student():
sample_size = 10
if len(train_collector.buffer) > sample_size:
batch, indice = train_collector.buffer.sample(sample_size)


# input = Batch(obs=Batch(obs=obs,mask=mask))
teacher = teacher_policy.forward(batch)
student = student_policy.forward(batch)
stds = torch.from_numpy(np.array([1e-6] * len(teacher.logits[0])))
stds = torch.tensor([1e-6] * len(teacher.logits[0]), device=device,dtype=torch.float)
stds = torch.stack([stds for _ in range(len(teacher.logits))])
loss = get_kl([teacher.logits, stds], [student.logits, stds])
student_policy.optim.zero_grad()
Expand All @@ -108,7 +109,7 @@ def update_student():
train_fn=train_fn,
update_student_fn=update_student,
test_fn=lambda epoch, env_step: teacher_policy.set_eps(0.05),
stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold)
)
# Print the "duration" entry of `result` (assigned earlier, outside this
# excerpt), then dump the whole result object.
print(f'Finished training! Use {result["duration"]}')
print(result)

Expand Down

0 comments on commit 457f754

Please sign in to comment.