-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil_eval.py
93 lines (75 loc) · 3.04 KB
/
util_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
from collections import defaultdict, deque
class BaseRawObsDictGenerator:
    """Base interface for objects that build a raw observation dict.

    Both ``get_raw_obs_dict`` and ``load`` raise ``NotImplementedError`` here
    and are meant to be overridden by subclasses.
    """

    def __init__(self, *args, **kwargs):
        # Most recently produced observation dict; stays None until a
        # subclass assigns one.
        self.last_obs_dict = None

    def get_raw_obs_dict(self, state_info):
        """Build an observation dict from ``state_info``.

        Args:
            state_info (dict): A dictionary of robot state + images

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def load(self, *args, **kwargs):
        """Store one entry into an observation dict.

        Accepts arbitrary arguments so that calling the hook with a
        subclass-style signature (``obs_dict, key, value, check_valid=...``)
        fails with ``NotImplementedError`` rather than ``TypeError``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
class RobotStateRawObsDictGenerator(BaseRawObsDictGenerator):
    """Build observation dicts from the latest robot and gripper state.

    Readings that are (near) all-zero are treated as missing and replaced by
    the corresponding entry of the previously returned observation dict, when
    one exists.

    Args:
        allow_zero_gripper (bool): Keyword-only. When False (default, the
            historical behaviour) a ``gripper_states`` reading whose absolute
            sum is <= 1e-6 is replaced by the previous value, even though it
            is loaded with ``check_valid=False``. When True, such a reading is
            kept as-is.
            NOTE(review): the original comment says the gripper width can
            legitimately become zero, which suggests True is the intended
            behaviour -- confirm before changing the default.
    """

    # Keys that are subject to the "all zero means missing" fallback.
    _STATE_KEYS = ("ee_states", "joint_states", "gripper_states")
    # A reading whose absolute sum is at or below this is considered missing.
    _MISSING_EPS = 1e-6

    def __init__(self, *args, allow_zero_gripper=False, **kwargs):
        super().__init__(*args, **kwargs)
        # Whether a zero gripper reading is kept instead of being replaced.
        self.allow_zero_gripper = allow_zero_gripper

    def load(self, obs_dict, key, value, check_valid=True):
        """Write ``value`` into ``obs_dict[key]``, falling back on bad data.

        This is to check if the data is correct or not. Sometimes the data
        will be all zero depending on the networking conditions.

        Args:
            obs_dict (dict): Dict that receives the entry (mutated in place).
            key (str): Entry name.
            value: Reading to store; its absolute sum is compared with 0.
            check_valid (bool): When True, an exactly all-zero ``value`` for
                one of the known state keys is replaced by the entry from the
                previous observation dict, if there is one.
        """
        if (
            check_valid
            and self.last_obs_dict is not None
            and key in self._STATE_KEYS
            and np.sum(np.abs(value)) == 0.0
        ):
            value = self.last_obs_dict[key]
        obs_dict[key] = value

    def get_raw_obs_dict(self, state_info):
        """Return a dict with ``ee_states``, ``joint_states``, ``gripper_states``.

        Args:
            state_info (dict): Must contain ``"last_state"`` (an object with
                ``O_T_EE`` and ``q`` attributes) and ``"last_gripper_state"``
                (an object with a ``width`` attribute).

        Returns:
            dict: The new observation dict; it is also stored in
            ``self.last_obs_dict`` for use as the fallback on the next call.
        """
        robot_state = state_info["last_state"]
        gripper_state = state_info["last_gripper_state"]

        obs_dict = {}
        self.load(obs_dict, "ee_states", np.array(robot_state.O_T_EE))
        self.load(obs_dict, "joint_states", np.array(robot_state.q))
        # Gripper width will probably become zero, so skip the exact-zero
        # check when loading it.
        self.load(
            obs_dict,
            "gripper_states",
            np.array([gripper_state.width]),
            check_valid=False,
        )

        # Second pass with a small tolerance; only possible once a previous
        # observation exists to fall back on.
        if self.last_obs_dict is not None:
            for state in self._STATE_KEYS:
                if state == "gripper_states" and self.allow_zero_gripper:
                    continue
                if np.sum(np.abs(obs_dict[state])) <= self._MISSING_EPS:
                    print(f"{state} missing!!!!")
                    obs_dict[state] = self.last_obs_dict[state]

        self.last_obs_dict = obs_dict
        return obs_dict
class FrameStackForTrans:
    """Keep the last ``num_frames`` observations per key and return them stacked.

    Each key has its own ``deque`` with ``maxlen=num_frames``. Every stored
    frame gets a new leading axis, and the stacked output concatenates the
    frames of a key along that axis (oldest first).
    """

    def __init__(self, num_frames):
        # Number of frames kept (and stacked) per observation key.
        self.num_frames = num_frames
        # Maps observation key -> deque of frames, each with a leading axis.
        self.obs_history = {}

    @staticmethod
    def _to_frame(value):
        """Return ``value`` as an array with a new leading axis of length 1."""
        # np.asarray leaves ndarrays untouched and lets lists / Python
        # scalars be stacked as well.
        return np.asarray(value)[None]

    def _stacked(self):
        """Concatenate the stored frames of every key along axis 0."""
        return {
            key: np.concatenate(frames, axis=0)
            for key, frames in self.obs_history.items()
        }

    def reset(self, init_obs):
        """Start a new history from the first observation of an episode.

        Args:
            init_obs (dict): Initial observation at the start of the episode.
                Every key is stored.

        Returns:
            dict: Stacked observations, built by repeating the first
            observation ``num_frames`` times.
        """
        self.obs_history = {}
        for key, value in init_obs.items():
            frame = self._to_frame(value)
            self.obs_history[key] = deque(
                [frame] * self.num_frames,
                maxlen=self.num_frames,
            )
        return self._stacked()

    def add_new_obs(self, new_obs):
        """Append the observation of the current timestep and restack.

        Keys containing ``'timesteps'`` or ``'actions'`` are skipped.

        Args:
            new_obs (dict): New observation at the current timestep.

        Returns:
            dict: Stacked observations for every key in the history.

        Raises:
            KeyError: If a non-skipped key has no history yet (for example
                it was not part of the observation passed to ``reset``).
        """
        for key, value in new_obs.items():
            if 'timesteps' in key or 'actions' in key:
                continue
            self.obs_history[key].append(self._to_frame(value))
        return self._stacked()