forked from ChrisCummins/CompilerGym
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjust_keep_going_env.py
74 lines (60 loc) · 2.24 KB
/
just_keep_going_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from compiler_gym.wrappers import CompilerEnvWrapper
# Module-level logger, named after this module. The wrapper below uses it to
# report the errors that it swallows as warnings.
logger = logging.getLogger(__name__)
# TODO(github.com/facebookresearch/CompilerGym/issues/469): Once step() and
# reset() no longer raise exceptions, then this wrapper class can be removed.
class JustKeepGoingEnv(CompilerEnvWrapper):
    """This wrapper class prevents the step() and close() methods from raising
    an exception.

    Errors raised by the wrapped environment are logged as warnings instead of
    being propagated:

    * step() closes the environment and returns a terminal "null" step.
    * reset() closes the environment and retries a bounded number of times;
      only an error raised by the final attempt reaches the caller.
    * close() logs the error and returns normally.

    Just keep swimming ...
            |\\    o
            | \\    o
        |\\ /  .\\ o
        | |       (
        |/\\     /
            |  /
            |/

    Usage:

        >>> env = compiler_gym.make("llvm-v0")
        >>> env = JustKeepGoingEnv(env)
        # enjoy ...
    """

    def step(self, *args, **kwargs):
        """Forward to the wrapped env's step(), converting errors into a
        terminal step.

        All arguments are passed through unchanged. If the wrapped step()
        raises, the error is logged, the environment is closed, and a
        4-tuple is returned instead:

        * the observation space spec's default value, or None if
          ``self.env.observation_space`` is falsy;
        * the reward space's ``reward_on_error()`` for the current episode
          reward as a float, or None if ``self.env.reward_space`` is falsy;
        * True, marking the episode as finished;
        * a dict holding the error message under ``"error_details"``.
        """
        try:
            return self.env.step(*args, **kwargs)
        except Exception as e:  # pylint: disable=broad-except
            logger.warning("step() error: %s", e)

            # Return "null" observation / reward.
            default_observation = (
                self.env.observation_space_spec.default_value
                if self.env.observation_space
                else None
            )
            default_reward = (
                float(
                    self.env.reward_space_spec.reward_on_error(
                        self.env.episode_reward
                    )
                )
                if self.env.reward_space
                else None
            )

            # Shut the failed environment down. This uses the non-raising
            # close() defined below, so it cannot mask the values computed
            # above.
            self.close()

            return (
                default_observation,
                default_reward,
                True,
                {"error_details": str(e)},
            )

    def reset(self, *args, **kwargs):
        """Reset the environment, retrying a bounded number of times on error.

        All arguments are passed through unchanged to the parent reset().
        Up to five failed attempts are logged and followed by close() before
        trying again. A sixth and final attempt is made without a handler, so
        a persistent failure propagates to the caller rather than looping
        forever.

        Returns:
            Whatever the parent class's reset() returns.

        Raises:
            Exception: Whatever the final reset attempt raises.
        """
        for _ in range(5):
            try:
                return super().reset(*args, **kwargs)
            except Exception as e:  # pylint: disable=broad-except
                logger.warning("reset() error, retrying: %s", e)
                # Close the environment before the next attempt. Do NOT call
                # self.reset() from here: recursing on every failure bypasses
                # the retry limit and never terminates for a persistent error.
                self.close()

        # No more retries.
        return super().reset(*args, **kwargs)

    def close(self):
        """Close the wrapped environment, logging any error instead of
        raising it."""
        try:
            self.env.close()
        except Exception as e:  # pylint: disable=broad-except
            logger.warning("Ignoring close() error: %s", e)