-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrl_utils.py
35 lines (30 loc) · 999 Bytes
/
rl_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import random
import numpy as np
import torch
import json
import matplotlib.pyplot as plt
from pathlib import Path
def all_seed(env, seed):
    """Seed every random number generator used during training.

    Python's ``random`` module, NumPy's global generator and PyTorch are
    seeded first. Every CUDA device is also seeded when CUDA is available.
    Finally the environment is seeded, but only if it has a ``seed``
    attribute.

    Args:
        env: Environment object. Objects without a ``seed`` attribute are
            skipped silently.
        seed: Seed value forwarded unchanged to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Seed the environment too, but only when it supports a seed() call.
    env_supports_seed = hasattr(env, 'seed')
    if env_supports_seed:
        env.seed(seed)
def save_args(args, path):
    """Write run arguments to ``<path>/args.json`` as indented JSON.

    The payload is serialized before the file is opened. A non-serializable
    argument therefore raises without leaving a truncated ``args.json``
    behind or overwriting a previous good one. ``path`` is created,
    including parents, if it does not exist yet.

    Args:
        args: Either an object whose attributes are the arguments, such as an
            ``argparse.Namespace`` read through ``vars()``, or a mapping
            such as a plain dict config.
        path: Output directory (str or Path).

    Raises:
        TypeError: If an argument value is not JSON-serializable, or if
            ``args`` is neither a mapping nor has a ``__dict__``.
    """
    from collections.abc import Mapping

    # Mappings are copied directly. Other objects expose their fields
    # through __dict__.
    params = dict(args) if isinstance(args, Mapping) else vars(args)
    text = json.dumps(params, indent=4)

    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "args.json").write_text(text, encoding="utf-8")
def save_results(results, tag, path):
    """Write ``results`` to ``<path>/<tag>_results.json`` as indented JSON.

    The payload is serialized before the file is opened, so a serialization
    error cannot leave a truncated JSON file behind or clobber a previous
    good one. ``path`` is created, including parents, if it does not exist
    yet.

    Any value exposing a ``tolist()`` method is converted to plain Python
    data before encoding. Examples are NumPy scalars like ``np.float32`` or
    ``np.int64``, NumPy arrays, and PyTorch tensors. Anything else JSON
    cannot encode still raises ``TypeError``.

    Args:
        results: Object to store. NOTE(review): presumably a dict of
            per-episode metric lists; confirm against callers.
        tag: File-name prefix, e.g. "train".
        path: Output directory (str or Path).

    Raises:
        TypeError: If ``results`` contains a value that cannot be encoded.
    """
    def _to_builtin(obj):
        # json only calls this hook for objects it cannot encode natively.
        to_list = getattr(obj, "tolist", None)
        if callable(to_list):
            return to_list()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    text = json.dumps(results, indent=4, default=_to_builtin)

    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / f"{tag}_results.json").write_text(text, encoding="utf-8")
def plot_rewards(rewards, arg_dict, path, tag="train"):
    """Plot a reward curve and save it as ``<path>/<tag>_rewards_plot.png``.

    Rewards are plotted against their index, which is labelled "Episode".
    ``path`` is created, including parents, if it does not exist yet.

    Args:
        rewards: Sequence of reward values, one per episode.
        arg_dict: Currently unused; kept for interface compatibility.
        path: Output directory (str or Path).
        tag: Prefix used for the file name and, capitalized, in the title.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(rewards, label="Rewards")
        ax.set_xlabel("Episode")
        ax.set_ylabel("Reward")
        ax.set_title(f"{tag.capitalize()} Rewards Over Time")
        ax.legend()
        fig.savefig(out_dir / f"{tag}_rewards_plot.png")
    finally:
        # Close this specific figure even when plotting or saving fails, so
        # repeated calls do not accumulate open pyplot figures.
        plt.close(fig)