diff --git a/first.py b/first.py
index 5ce0358..d826de3 100644
--- a/first.py
+++ b/first.py
@@ -12,7 +12,9 @@
 env = gym.make(env_name)
 # train_envs = gym.make('CartPole-v0')
 # test_envs = gym.make('CartPole-v0')
-
+device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")  # using GPU
+print(device)
+device = 'cpu'  # overrides the CUDA device above: force CPU
 
 train_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_name) for _ in range(10)])
 test_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_name) for _ in range(100)])
@@ -28,7 +30,7 @@ def __init__(self, state_shape, action_shape):
 
     def forward(self, obs, state=None, info={}):
         if not isinstance(obs, torch.Tensor):
-            obs = torch.tensor(obs, dtype=torch.float)
+            obs = torch.tensor(obs, device=device, dtype=torch.float)
         batch = obs.shape[0]
         logits = self.model(obs.view(batch, -1))
         return logits, state
@@ -46,7 +48,7 @@ def __init__(self, state_shape, action_shape):
 
     def forward(self, obs, state=None, info={}):
         if not isinstance(obs, torch.Tensor):
-            obs = torch.tensor(obs, dtype=torch.float)
+            obs = torch.tensor(obs, device=device, dtype=torch.float)
         batch = obs.shape[0]
         logits = self.model(obs.view(batch, -1))
         return logits, state
@@ -54,8 +56,8 @@ def forward(self, obs, state=None, info={}):
 
 state_shape = env.observation_space.shape or env.observation_space.n  # (4,)
 action_shape = env.action_space.shape or env.action_space.n  # 2
-net = TeacherNet(state_shape, action_shape)
-net_student = StudentNet(state_shape, action_shape)
+net = TeacherNet(state_shape, action_shape).to(device)
+net_student = StudentNet(state_shape, action_shape).to(device)
 
 optim = torch.optim.Adam(net.parameters(), lr=1e-3)
 optim_student = torch.optim.Adam(net_student.parameters(), lr=1e-3)
@@ -85,12 +87,10 @@ def update_student():
 
     sample_size = 10
     if len(train_collector.buffer) > sample_size:
         batch, indice = train_collector.buffer.sample(sample_size)
-
-        # input = Batch(obs=Batch(obs=obs,mask=mask))
         teacher = teacher_policy.forward(batch)
         student = student_policy.forward(batch)
-        stds = torch.from_numpy(np.array([1e-6] * len(teacher.logits[0])))
+        stds = torch.tensor([1e-6] * len(teacher.logits[0]), device=device, dtype=torch.float)
         stds = torch.stack([stds for _ in range(len(teacher.logits))])
         loss = get_kl([teacher.logits, stds], [student.logits, stds])
         student_policy.optim.zero_grad()
diff --git a/log/dqn/events.out.tfevents.1628170692.WIN-659P95VJIVL.22004.0 b/log/dqn/events.out.tfevents.1628170692.WIN-659P95VJIVL.22004.0
new file mode 100644
index 0000000..c20d088
Binary files /dev/null and b/log/dqn/events.out.tfevents.1628170692.WIN-659P95VJIVL.22004.0 differ
diff --git a/log/dqn/events.out.tfevents.1628170712.WIN-659P95VJIVL.28372.0 b/log/dqn/events.out.tfevents.1628170712.WIN-659P95VJIVL.28372.0
new file mode 100644
index 0000000..7f4236c
Binary files /dev/null and b/log/dqn/events.out.tfevents.1628170712.WIN-659P95VJIVL.28372.0 differ
diff --git a/log/dqn/events.out.tfevents.1628170993.WIN-659P95VJIVL.4984.0 b/log/dqn/events.out.tfevents.1628170993.WIN-659P95VJIVL.4984.0
new file mode 100644
index 0000000..61e445a
Binary files /dev/null and b/log/dqn/events.out.tfevents.1628170993.WIN-659P95VJIVL.4984.0 differ
diff --git a/log/dqn/events.out.tfevents.1628171102.WIN-659P95VJIVL.15064.0 b/log/dqn/events.out.tfevents.1628171102.WIN-659P95VJIVL.15064.0
new file mode 100644
index 0000000..b2822e9
Binary files /dev/null and b/log/dqn/events.out.tfevents.1628171102.WIN-659P95VJIVL.15064.0 differ
diff --git a/td3.py b/td3.py
index f7f1ab0..3b00230 100644
--- a/td3.py
+++ b/td3.py
@@ -1,6 +1,4 @@
 # Author: vincent
-# code time: 2021/8/5 8:12 PM
-# Author: vincent
 # code time: 2021/7/26 8:05 PM
 import gym
 import tianshou as ts
@@ -8,12 +6,17 @@
 from tianshou.data import Batch
 from torch import nn
 from utils import get_kl
-env = gym.make('CartPole-v0')
+
+env_name = 'Breakout-v0'
+# env = gym.make('CartPole-v0')
+env = gym.make(env_name)
 # train_envs = gym.make('CartPole-v0')
 # test_envs = gym.make('CartPole-v0')
+device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")  # using GPU
+print(device, 'env reward:', env.spec.reward_threshold)
 
-train_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(10)])
-test_envs = ts.env.DummyVectorEnv([lambda: gym.make('CartPole-v0') for _ in range(100)])
+train_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_name) for _ in range(10)])
+test_envs = ts.env.DummyVectorEnv([lambda: gym.make(env_name) for _ in range(100)])
 
 class TeacherNet(nn.Module):
     def __init__(self, state_shape, action_shape):
@@ -27,7 +30,7 @@ def __init__(self, state_shape, action_shape):
 
     def forward(self, obs, state=None, info={}):
         if not isinstance(obs, torch.Tensor):
-            obs = torch.tensor(obs, dtype=torch.float)
+            obs = torch.tensor(obs, device=device, dtype=torch.float)
         batch = obs.shape[0]
         logits = self.model(obs.view(batch, -1))
         return logits, state
@@ -45,7 +48,7 @@ def __init__(self, state_shape, action_shape):
 
     def forward(self, obs, state=None, info={}):
         if not isinstance(obs, torch.Tensor):
-            obs = torch.tensor(obs, dtype=torch.float)
+            obs = torch.tensor(obs, device=device, dtype=torch.float)
         batch = obs.shape[0]
         logits = self.model(obs.view(batch, -1))
         return logits, state
@@ -53,8 +56,8 @@ def forward(self, obs, state=None, info={}):
 
 state_shape = env.observation_space.shape or env.observation_space.n  # (4,)
 action_shape = env.action_space.shape or env.action_space.n  # 2
-net = TeacherNet(state_shape, action_shape)
-net_student = StudentNet(state_shape, action_shape)
+net = TeacherNet(state_shape, action_shape).to(device)
+net_student = StudentNet(state_shape, action_shape).to(device)
 
 optim = torch.optim.Adam(net.parameters(), lr=1e-3)
 optim_student = torch.optim.Adam(net_student.parameters(), lr=1e-3)
@@ -84,12 +87,10 @@ def update_student():
 
     sample_size = 10
    if len(train_collector.buffer) > sample_size:
        batch, indice = train_collector.buffer.sample(sample_size)
-
-        # input = Batch(obs=Batch(obs=obs,mask=mask))
         teacher = teacher_policy.forward(batch)
         student = student_policy.forward(batch)
-        stds = torch.from_numpy(np.array([1e-6] * len(teacher.logits[0])))
+        stds = torch.tensor([1e-6] * len(teacher.logits[0]), device=device, dtype=torch.float)
         stds = torch.stack([stds for _ in range(len(teacher.logits))])
         loss = get_kl([teacher.logits, stds], [student.logits, stds])
         student_policy.optim.zero_grad()
@@ -108,7 +109,7 @@ def update_student():
 
     train_fn=train_fn, update_student_fn=update_student,
     test_fn=lambda epoch, env_step: teacher_policy.set_eps(0.05),
-    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold)
+    )
 
 print(f'Finished training! Use {result["duration"]}')
 print(result)
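
Note: utils.get_kl is imported above but not included in this diff. Judging only from the call site, get_kl([teacher.logits, stds], [student.logits, stds]), it presumably computes the KL divergence between two diagonal Gaussians given as [mean, std] pairs. A minimal sketch under that assumption follows; the actual implementation in utils.py may differ.

    import torch

    def get_kl(p, q):
        # p and q are assumed [mean, std] pairs describing diagonal Gaussians
        # over the action logits, each tensor shaped (batch, action_dim).
        # Per-dimension KL(N(mu_p, std_p^2) || N(mu_q, std_q^2)):
        #   log(std_q / std_p) + (std_p^2 + (mu_p - mu_q)^2) / (2 * std_q^2) - 0.5
        mean_p, std_p = p
        mean_q, std_q = q
        kl = (torch.log(std_q / std_p)
              + (std_p.pow(2) + (mean_p - mean_q).pow(2)) / (2.0 * std_q.pow(2))
              - 0.5)
        return kl.sum(dim=-1).mean()  # sum over action dims, mean over the batch

With both stds pinned at 1e-6 as in update_student(), the log term cancels and the loss reduces to a heavily scaled squared error between teacher and student logits.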