-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdata_process.py
189 lines (102 loc) · 4.52 KB
/
data_process.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import numpy as np
import os
from myrllib.utils.myplot import simple_plot
import scipy.io as sio
import matplotlib.pyplot as plt
import scipy.stats as st
from sklearn.manifold import TSNE
###############################################################################
### Script configuration.
### DOMAIN can be 'navi_v1', 'navi_v2', 'navi_v3', 'hopper', 'cheetah', 'ant'
# Fixed seed so the random jitter drawn in task_clustering() is reproducible.
np.random.seed(950418)
# Experiment to analyse; it selects the sub-directories below and the
# per-domain plot settings used by the functions in this file.
DOMAIN = 'navi_v1'
# Directory holding the rews_*.npy files read by perforamnce_comparison().
p_output = 'output/%s'%DOMAIN
# Directory holding task_info.npy read by task_clustering().
p_model = 'saves/%s'%DOMAIN
###############################################################################
### visualize the clustering results in domains: navi_v1, hopper, cheetah, ant
def task_clustering():
    """Scatter-plot every task in ``task_info.npy``, styled by its cluster id.

    Loads ``<p_model>/task_info.npy`` and treats the last column as the
    cluster id and the remaining columns as the task parameters. For the
    'hopper', 'cheetah' and 'ant' domains an extra column of uniform random
    values in [0.4, 0.6) is appended to the task columns and used as the
    vertical coordinate of the scatter plot.

    Side effects: opens a new matplotlib figure, draws one marker per task,
    and prints the number of clusters and the per-cluster task counts.
    Cluster ids 1-6 have a marker/colour; any other id is counted but not
    drawn.

    Returns:
        The plotted array: the task columns (including the random column for
        the velocity domains) followed by the cluster-id column.
    """
    assert DOMAIN in ['navi_v1', 'hopper', 'cheetah', 'ant']

    velocity_domains = ['hopper', 'cheetah', 'ant']
    is_velocity = DOMAIN in velocity_domains

    raw = np.load(os.path.join(p_model, 'task_info.npy'))
    tasks = raw[:, :-1]
    if is_velocity:
        # Spread the points vertically with a random second coordinate.
        jitter = np.random.uniform(0.4, 0.6, size=tasks.shape)
        tasks = np.concatenate([tasks, jitter], axis=1)
    plt.figure(figsize=(6, 2) if is_velocity else (3, 3), dpi=200)

    labels = raw[:, -1].reshape(-1, 1)
    info = np.concatenate((tasks, labels), axis=1)

    clusters = int(max(info[:, -1]))
    points = np.zeros(clusters)
    print('num of clusters:', clusters)

    # cluster id -> (marker, colour)
    cluster_styles = {
        1: ('x', 'k'),
        2: ('+', 'r'),
        3: ('s', 'b'),
        4: ('^', 'g'),
        5: ('*', 'm'),
        6: ('o', 'c'),
    }
    marker_size = 50
    for row in info:
        tail = row[-3:]
        label = tail[-1]
        # Cluster ids are used as 1-based indices into the counter array.
        points[int(label) - 1] += 1
        style = cluster_styles.get(label)
        if style is not None:
            marker, colour = style
            plt.scatter(tail[0], tail[1], marker=marker, color=colour,
                        s=marker_size)

    if DOMAIN == 'navi_v1':
        for axis_name in ('x', 'y'):
            plt.grid(axis=axis_name, ls='--')
        tick_real = [-0.5, -0.25, 0, 0.25, 0.5]
        # Only the extreme ticks are labelled; the y axis omits -0.5.
        plt.yticks(tick_real, ['', '', '', '', 0.5], fontsize=12)
        plt.xticks(tick_real, [-0.5, '', '', '', 0.5], fontsize=12)
        plt.axis([-0.5, 0.5, -0.5, 0.5])
    elif is_velocity:
        plt.xlabel('Goal velocity', fontsize=16)
        plt.grid(axis='x', ls='--')
        plt.yticks([])
        # domain -> (x tick positions, [xmin, xmax, ymin, ymax])
        axis_setup = {
            'hopper': ([0, 0.25, 0.5, 0.75, 1], [0, 1, 0.3, 0.7]),
            'cheetah': ([0, 0.5, 1, 1.5, 2], [0, 2, 0.3, 0.7]),
            'ant': ([0, 0.1, 0.2, 0.3, 0.4, 0.5], [0, 0.5, 0.3, 0.7]),
        }
        x_ticks, limits = axis_setup[DOMAIN]
        plt.xticks(x_ticks, fontsize=12)
        plt.axis(limits)

    print(points)
    return info
# task_clustering() asserts that DOMAIN is one of the four domains it can draw.
# DOMAIN may also be 'navi_v2' / 'navi_v3', so skip the clustering plot for
# those instead of aborting the whole script on the assertion.
info = task_clustering() if DOMAIN in ['navi_v1', 'hopper', 'cheetah', 'ant'] else None
###############################################################################
def perforamnce_comparison():
    """Plot and print the mean return curve of each compared method.

    For each method, loads ``<p_output>/rews_<method>.npy``, keeps the first
    ``cutoff`` columns (100 for the navi domains, 200 for hopper, 500 for
    cheetah and ant) and averages over axis 0, giving one curve per method.
    NOTE(review): axis 1 is presumably the learning episode (it is the x axis
    of the plot) and axis 0 the repeated runs/tasks -- confirm against the
    code that writes these files.

    Side effects: prints the mean of every curve and opens a new matplotlib
    figure with all five curves and a legend.

    Returns:
        Tuple of the five curves in the order
        (CA, Robust, Adaptive, MAML, LLIRL).

    Raises:
        ValueError: if ``DOMAIN`` has no cutoff defined.
    """
    cutoff_by_domain = {
        'navi_v1': 100, 'navi_v2': 100, 'navi_v3': 100,
        'hopper': 200,
        'cheetah': 500, 'ant': 500,
    }
    # Fail early with a clear message instead of hitting an unbound `cutoff`
    # after the files have already been loaded.
    if DOMAIN not in cutoff_by_domain:
        raise ValueError('no episode cutoff defined for DOMAIN=%r' % (DOMAIN,))
    cutoff = cutoff_by_domain[DOMAIN]

    # (file tag, label used in the printed summary, label used in the legend)
    methods = [
        ('ca', 'CA', 'CA'),
        ('robust', 'Robust', 'Robust'),
        ('adapt', 'Adaptive', 'Adaptive'),
        ('maml', 'Maml', 'MAML'),
        ('llirl', 'LLIRL', 'LLIRL'),
    ]

    curves = []
    for tag, _, _ in methods:
        rews = np.load(os.path.join(p_output, 'rews_%s.npy' % tag))
        curves.append(rews[:, :cutoff].mean(axis=0))

    for (_, printed, _), curve in zip(methods, curves):
        print('Return %s:' % printed, curve.mean())

    # The x axis is sized from the LLIRL curve (the last entry).
    xx = np.arange(curves[-1].shape[0])
    plt.figure()
    plot_args = []
    for curve in curves:
        plot_args += [xx, curve]
    plt.plot(*plot_args)
    plt.legend([legend for _, _, legend in methods])
    plt.xlabel('Learning episode', fontsize=16)
    plt.ylabel('Return', fontsize=16)
    return tuple(curves)
# Run the comparison when the script executes; `data` holds the returned
# curves in the order (CA, Robust, Adaptive, MAML, LLIRL).
data = perforamnce_comparison()