diff --git a/compiler_gym/BUILD b/compiler_gym/BUILD index 469189bc7..69964de61 100644 --- a/compiler_gym/BUILD +++ b/compiler_gym/BUILD @@ -12,7 +12,6 @@ py_library( srcs = ["__init__.py"], visibility = ["//visibility:public"], deps = [ - ":random_replay", ":random_search", ":validate", "//compiler_gym/bin", @@ -38,17 +37,6 @@ py_library( ], ) -py_library( - name = "random_replay", - srcs = ["random_replay.py"], - visibility = ["//visibility:public"], - deps = [ - ":random_search", - "//compiler_gym/envs", - "//compiler_gym/util", - ], -) - py_library( name = "random_search", srcs = ["random_search.py"], diff --git a/compiler_gym/bin/manual_env.py b/compiler_gym/bin/manual_env.py index d90a689ed..980e9892d 100644 --- a/compiler_gym/bin/manual_env.py +++ b/compiler_gym/bin/manual_env.py @@ -716,7 +716,7 @@ def do_set_default_reward(self, arg): def do_commandline(self, arg): """Show the command line equivalent of the actions taken so far""" - print("$", self.env.commandline(), flush=True) + print("$", self.env.action_space.to_string(self.env.actions), flush=True) def do_stack(self, arg): """Show the environments on the stack. The current environment is the first shown.""" diff --git a/compiler_gym/datasets/dataset.py b/compiler_gym/datasets/dataset.py index d65f4a1e8..c3375178c 100644 --- a/compiler_gym/datasets/dataset.py +++ b/compiler_gym/datasets/dataset.py @@ -22,11 +22,6 @@ logger = logging.getLogger(__name__) -# NOTE(cummins): This is only required to prevent a name conflict with the now -# deprecated Dataset.logger attribute. This can be removed once the logger -# attribute is removed, scheduled for release 0.2.3. -_logger = logger - _DATASET_VERSION_PATTERN = r"[a-zA-z0-9-_]+-v(?P[0-9]+)" _DATASET_VERSION_RE = re.compile(_DATASET_VERSION_PATTERN) @@ -131,21 +126,6 @@ def __init__( def __repr__(self): return self.name - @property - @mark_deprecated( - version="0.2.1", - reason=( - "The `Dataset.logger` attribute is deprecated. All Dataset " - "instances share a logger named compiler_gym.datasets" - ), - ) - def logger(self) -> logging.Logger: - """The logger for this dataset. - - :type: logging.Logger - """ - return _logger - @property def name(self) -> str: """The name of the dataset. diff --git a/compiler_gym/datasets/uri.py b/compiler_gym/datasets/uri.py index 3c661fedb..c1aad1418 100644 --- a/compiler_gym/datasets/uri.py +++ b/compiler_gym/datasets/uri.py @@ -3,49 +3,11 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """This module contains utility code for working with URIs.""" -import re from typing import Dict, List, Union from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse -from deprecated.sphinx import deprecated from pydantic import BaseModel -# === BEGIN DEPRECATED DECLARATIONS === -# -# The following regular expression definitions have been deprecated and will be -# removed in a future release! Please update your code to use the new -# BenchmarkUri class defined in this file. - -# Regular expression that matches the full two-part URI prefix of a dataset: -# {{scheme}}://{{dataset}} -# -# An optional trailing slash is permitted. -# -# Example matches: "benchmark://foo-v0", "generator://bar-v0/". -DATASET_NAME_PATTERN = r"(?P(?P[a-zA-z0-9-_]+)://(?P[a-zA-z0-9-_]+-v(?P[0-9]+)))/?" -DATASET_NAME_RE = re.compile(DATASET_NAME_PATTERN) - -# Regular expression that matches the full three-part format of a benchmark URI: -# {{sceme}}://{{dataset}}/{{id}} -# -# Example matches: "benchmark://foo-v0/foo" or "generator://bar-v1/foo/bar.txt". -BENCHMARK_URI_PATTERN = r"(?P(?P[a-zA-z0-9-_]+)://(?P[a-zA-z0-9-_]+-v(?P[0-9]+)))/(?P.+)$" -BENCHMARK_URI_RE = re.compile(BENCHMARK_URI_PATTERN) - -# === END DEPRECATED DECLARATIONS === - - -@deprecated( - version="0.2.2", - reason=("Use compiler_gym.datasets.BenchmarkUri.canonicalize()"), -) -def resolve_uri_protocol(uri: str) -> str: - """Require that the URI has a scheme by applying a default "benchmark" - scheme if none is set.""" - if "://" not in uri: - return f"benchmark://{uri}" - return uri - class BenchmarkUri(BaseModel): """A URI used to identify a benchmark, and optionally a set of parameters diff --git a/compiler_gym/envs/__init__.py b/compiler_gym/envs/__init__.py index f8b8829df..48cc50b20 100644 --- a/compiler_gym/envs/__init__.py +++ b/compiler_gym/envs/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import compiler_gym.envs.loop_tool # noqa from compiler_gym import config from compiler_gym.envs.compiler_env import CompilerEnv from compiler_gym.envs.gcc import GccEnv @@ -11,14 +12,12 @@ if config.enable_mlir_env: from compiler_gym.envs.mlir.mlir_env import MlirEnv # noqa: F401 -from compiler_gym.envs.loop_tool.loop_tool_env import LoopToolEnv from compiler_gym.util.registration import COMPILER_GYM_ENVS __all__ = [ "COMPILER_GYM_ENVS", "CompilerEnv", "GccEnv", - "LoopToolEnv", ] if config.enable_llvm_env: diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index e26d288d8..49910e645 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -12,7 +12,7 @@ from compiler_gym.compiler_env_state import CompilerEnvState from compiler_gym.datasets import Benchmark, BenchmarkUri, Dataset -from compiler_gym.spaces import Reward +from compiler_gym.spaces import ActionSpace, Reward from compiler_gym.util.gym_type_hints import ( ActionType, ObservationType, @@ -182,14 +182,6 @@ def episode_reward(self, episode_reward: Optional[float]): def actions(self) -> List[ActionType]: raise NotImplementedError("abstract method") - @property - @abstractmethod - @deprecated( - version="0.2.1", - ) - def logger(self): - raise NotImplementedError("abstract method") - @property @abstractmethod def version(self) -> str: @@ -210,7 +202,7 @@ def state(self) -> CompilerEnvState: @property @abstractmethod - def action_space(self) -> Space: + def action_space(self) -> ActionSpace: """The current action space. :getter: Get the current action space. @@ -227,7 +219,7 @@ def action_space(self, action_space: Optional[str]): @property @abstractmethod - def action_spaces(self) -> List[str]: + def action_spaces(self) -> List[ActionSpace]: """A list of supported action space names.""" raise NotImplementedError("abstract method") @@ -481,7 +473,9 @@ def render( """ raise NotImplementedError("abstract method") - @abstractmethod + @deprecated( + version="0.2.5", reason="Use env.action_space.to_string(env.actions) instead" + ) def commandline(self) -> str: """Interface for :class:`CompilerEnv ` subclasses to provide an equivalent commandline invocation to the @@ -494,7 +488,9 @@ def commandline(self) -> str: """ raise NotImplementedError("abstract method") - @abstractmethod + @deprecated( + version="0.2.5", reason='Use env.action_space.from_string("...") instead' + ) def commandline_to_actions(self, commandline: str) -> List[ActionType]: """Interface for :class:`CompilerEnv ` subclasses to convert from a commandline invocation to a sequence of diff --git a/compiler_gym/envs/gcc/BUILD b/compiler_gym/envs/gcc/BUILD index 02c10f7ed..df8c760f5 100644 --- a/compiler_gym/envs/gcc/BUILD +++ b/compiler_gym/envs/gcc/BUILD @@ -19,7 +19,7 @@ py_library( deps = [ "//compiler_gym/envs/gcc/datasets", "//compiler_gym/errors", - "//compiler_gym/service:client_service_compiler_env", + "//compiler_gym/service:in_process_client_compiler_env", "//compiler_gym/service/runtime", # Implicit dependency of service. "//compiler_gym/util", ], diff --git a/compiler_gym/envs/gcc/CMakeLists.txt b/compiler_gym/envs/gcc/CMakeLists.txt index ecac8500f..3a035288f 100644 --- a/compiler_gym/envs/gcc/CMakeLists.txt +++ b/compiler_gym/envs/gcc/CMakeLists.txt @@ -16,7 +16,7 @@ cg_py_library( DATA compiler_gym::envs::gcc::service::service DEPS - compiler_gym::service::client_service_compiler_env + compiler_gym::service::in_process_client_compiler_env compiler_gym::envs::gcc::datasets::datasets compiler_gym::errors::errors compiler_gym::service::runtime::runtime diff --git a/compiler_gym/envs/gcc/gcc_env.py b/compiler_gym/envs/gcc/gcc_env.py index fa77837a1..5a9e66aff 100644 --- a/compiler_gym/envs/gcc/gcc_env.py +++ b/compiler_gym/envs/gcc/gcc_env.py @@ -9,12 +9,17 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union +from deprecated.sphinx import deprecated + from compiler_gym.datasets import Benchmark from compiler_gym.envs.gcc.datasets import get_gcc_datasets from compiler_gym.envs.gcc.gcc import Gcc, GccSpec from compiler_gym.envs.gcc.gcc_rewards import AsmSizeReward, ObjSizeReward +from compiler_gym.envs.gcc.service.gcc_service import make_gcc_compilation_session from compiler_gym.service import ConnectionOpts -from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv +from compiler_gym.service.in_process_client_compiler_env import ( + InProcessClientCompilerEnv, +) from compiler_gym.spaces import Reward from compiler_gym.util.decorators import memoized_property from compiler_gym.util.gym_type_hints import ObservationType, OptionalArgumentValue @@ -24,7 +29,7 @@ DEFAULT_GCC: str = "docker:gcc:11.2.0" -class GccEnv(ClientServiceCompilerEnv): +class GccEnv(InProcessClientCompilerEnv): """A specialized ClientServiceCompilerEnv for GCC. This class exposes the optimization space of GCC's command line flags @@ -79,6 +84,7 @@ def __init__( super().__init__( *args, **kwargs, + make_session=make_gcc_compilation_session(), benchmark=benchmark, datasets=get_gcc_datasets( gcc_bin=gcc_bin, site_data_base=datasets_site_path @@ -109,6 +115,10 @@ def reset( self.send_param("timeout", str(self._timeout)) return observation + @deprecated( + version="0.2.1", + reason="Use `env.observation.command_line()` instead", + ) def commandline(self) -> str: """Return a string representing the command line options. diff --git a/compiler_gym/envs/llvm/BUILD b/compiler_gym/envs/llvm/BUILD index cca35adba..73bd6c4c4 100644 --- a/compiler_gym/envs/llvm/BUILD +++ b/compiler_gym/envs/llvm/BUILD @@ -16,6 +16,7 @@ py_library( ":benchmark_from_command_line", ":compute_observation", ":llvm_benchmark", + ":llvm_command_line", ":llvm_env", "//compiler_gym/util", ], @@ -54,12 +55,22 @@ py_library( ], ) +py_library( + name = "llvm_command_line", + srcs = ["llvm_command_line.py"], + deps = [ + "//compiler_gym/spaces", + "//compiler_gym/util", + ], +) + py_library( name = "llvm_env", srcs = ["llvm_env.py"], deps = [ ":benchmark_from_command_line", ":llvm_benchmark", + ":llvm_command_line", ":llvm_rewards", "//compiler_gym/datasets", "//compiler_gym/envs/llvm/datasets", diff --git a/compiler_gym/envs/llvm/CMakeLists.txt b/compiler_gym/envs/llvm/CMakeLists.txt index eba85208a..8efc2ee9e 100644 --- a/compiler_gym/envs/llvm/CMakeLists.txt +++ b/compiler_gym/envs/llvm/CMakeLists.txt @@ -20,6 +20,7 @@ cg_py_library( ::benchmark_from_command_line ::compute_observation ::llvm_benchmark + ::llvm_command_line ::llvm_env compiler_gym::util::util PUBLIC @@ -56,6 +57,16 @@ cg_py_library( PUBLIC ) +cg_py_library( + NAME + llvm_command_line + SRCS + llvm_command_line.py + DEPS + compiler_gym::spaces::spaces + compiler_gym::util::util +) + cg_py_library( NAME llvm_env @@ -63,6 +74,7 @@ cg_py_library( "llvm_env.py" DEPS ::llvm_benchmark + ::llvm_command_line ::llvm_rewards compiler_gym::datasets::datasets compiler_gym::errors::errors diff --git a/compiler_gym/envs/llvm/__init__.py b/compiler_gym/envs/llvm/__init__.py index 535c524cc..a15ec39b2 100644 --- a/compiler_gym/envs/llvm/__init__.py +++ b/compiler_gym/envs/llvm/__init__.py @@ -13,6 +13,7 @@ get_system_library_flags, make_benchmark, ) +from compiler_gym.envs.llvm.llvm_command_line import LlvmCommandLine from compiler_gym.envs.llvm.llvm_env import LlvmEnv # TODO(github.com/facebookresearch/CompilerGym/issues/506): Tidy up. @@ -24,6 +25,7 @@ __all__ = [ "BenchmarkFromCommandLine", + "LlvmCommandLine", "ClangInvocation", "compute_observation", "get_system_library_flags", diff --git a/compiler_gym/envs/llvm/llvm_command_line.py b/compiler_gym/envs/llvm/llvm_command_line.py new file mode 100644 index 000000000..9449e1af7 --- /dev/null +++ b/compiler_gym/envs/llvm/llvm_command_line.py @@ -0,0 +1,40 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +from compiler_gym.spaces import ActionSpace +from compiler_gym.util.gym_type_hints import ActionType + + +class LlvmCommandLine(ActionSpace): + """An action space for LLVM that supports serializing / deserializing to + opt command line. + """ + + def to_string(self, actions: List[ActionType]) -> str: + """Returns an LLVM :code:`opt` command line invocation for the given actions. + + :param actions: A list of actions to serialize. + + :returns: A command line string. + """ + return f"opt {self.wrapped.to_string(actions)} input.bc -o output.bc" + + def from_string(self, string: str) -> List[ActionType]: + """Returns a list of actions from the given command line. + + :param commandline: A command line invocation. + + :return: A list of actions. + + :raises ValueError: In case the command line string is malformed. + """ + if string.startswith("opt "): + string = string[len("opt ") :] + + if string.endswith(" input.bc -o output.bc"): + string = string[: -len(" input.bc -o output.bc")] + + return self.wrapped.from_string(string) diff --git a/compiler_gym/envs/llvm/llvm_env.py b/compiler_gym/envs/llvm/llvm_env.py index bb09be681..ac83fdb33 100644 --- a/compiler_gym/envs/llvm/llvm_env.py +++ b/compiler_gym/envs/llvm/llvm_env.py @@ -10,7 +10,7 @@ import subprocess from pathlib import Path from tempfile import TemporaryDirectory -from typing import Iterable, List, Optional, Union, cast +from typing import Any, Callable, Iterable, List, Optional, Union import numpy as np @@ -22,14 +22,19 @@ get_system_library_flags, make_benchmark, ) +from compiler_gym.envs.llvm.llvm_command_line import LlvmCommandLine from compiler_gym.envs.llvm.llvm_rewards import ( BaselineImprovementNormalizedReward, CostFunctionReward, NormalizedReward, ) from compiler_gym.errors import BenchmarkInitError -from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv -from compiler_gym.spaces import Box, Commandline +from compiler_gym.service.client_service_compiler_env import ( + ClientServiceCompilerEnv, + ServiceMessageConverters, +) +from compiler_gym.service.proto.py_converters import make_message_default_converter +from compiler_gym.spaces import Box from compiler_gym.spaces import Dict as DictSpace from compiler_gym.spaces import Scalar, Sequence from compiler_gym.third_party.autophase import AUTOPHASE_FEATURE_NAMES @@ -65,6 +70,10 @@ def _get_llvm_datasets(site_data_base: Optional[Path] = None) -> Iterable[Datase return get_llvm_datasets(site_data_base=site_data_base) +def make_llvm_action_space_converter() -> Callable[[Any], LlvmCommandLine]: + return lambda msg: LlvmCommandLine(space=make_message_default_converter()(msg)) + + class LlvmEnv(ClientServiceCompilerEnv): """A specialized ClientServiceCompilerEnv for LLVM. @@ -73,9 +82,9 @@ class LlvmEnv(ClientServiceCompilerEnv): functionality. Specifically, the actions use the :class:`CommandlineFlag ` space, which is a type of :code:`Discrete` space that provides additional documentation about each - action, and the :meth:`LlvmEnv.commandline() - ` method can be used to produce an - equivalent LLVM opt invocation for the current environment state. + action, and the :meth:`env.action_space.to_string(...) + ` method can be used to + produce an equivalent LLVM opt invocation for the given actions. """ def __init__( @@ -95,6 +104,9 @@ def __init__( # Set a default benchmark for use. benchmark=benchmark or "cbench-v1/qsort", datasets=_get_llvm_datasets(site_data_base=datasets_site_path), + service_message_converters=ServiceMessageConverters( + action_space_converter=make_llvm_action_space_converter() + ), rewards=[ CostFunctionReward( name="IrInstructionCount", @@ -469,41 +481,6 @@ def make_benchmark( timeout=timeout, ) - def commandline( # pylint: disable=arguments-differ - self, textformat: bool = False - ) -> str: - """Returns an LLVM :code:`opt` command line invocation for the current - environment state. - - :param textformat: Whether to generate a command line that processes - text-format LLVM-IR or bitcode (the default). - :returns: A command line string. - """ - command = cast(Commandline, self.action_space).commandline(self.actions) - if textformat: - return f"opt {command} input.ll -S -o output.ll" - else: - return f"opt {command} input.bc -o output.bc" - - def commandline_to_actions(self, commandline: str) -> List[int]: - """Returns a list of actions from the given command line. - - :param commandline: A command line invocation, as generated by - :meth:`env.commandline() `. - :return: A list of actions. - :raises ValueError: In case the command line string is malformed. - """ - # Strip the decorative elements that LlvmEnv.commandline() adds. - if not commandline.startswith("opt "): - raise ValueError(f"Invalid commandline: `{commandline}`") - if commandline.endswith(" input.ll -S -o output.ll"): - commandline = commandline[len("opt ") : -len(" input.ll -S -o output.ll")] - elif commandline.endswith(" input.bc -o output.bc"): - commandline = commandline[len("opt ") : -len(" input.bc -o output.bc")] - else: - raise ValueError(f"Invalid commandline: `{commandline}`") - return self.action_space.from_commandline(commandline) - @property def ir(self) -> str: """Print the LLVM-IR of the program in its current state. diff --git a/compiler_gym/envs/loop_tool/BUILD b/compiler_gym/envs/loop_tool/BUILD index 61ad772e9..7b99c62b8 100644 --- a/compiler_gym/envs/loop_tool/BUILD +++ b/compiler_gym/envs/loop_tool/BUILD @@ -8,11 +8,10 @@ py_library( name = "loop_tool", srcs = [ "__init__.py", - "loop_tool_env.py", ], - data = ["//compiler_gym/envs/loop_tool/service"], visibility = ["//visibility:public"], deps = [ + "//compiler_gym/envs/loop_tool/service", "//compiler_gym/service", "//compiler_gym/service:client_service_compiler_env", "//compiler_gym/service/proto", diff --git a/compiler_gym/envs/loop_tool/CMakeLists.txt b/compiler_gym/envs/loop_tool/CMakeLists.txt index f04f82442..19884a221 100644 --- a/compiler_gym/envs/loop_tool/CMakeLists.txt +++ b/compiler_gym/envs/loop_tool/CMakeLists.txt @@ -9,10 +9,10 @@ cg_py_library( NAME loop_tool SRCS "__init__.py" - "loop_tool_env.py" - DATA compiler_gym::envs::loop_tool::service::service + DEPS - compiler_gym::service::client_service_compiler_env + compiler_gym::envs::loop_tool::service::service + compiler_gym::service::in_process_client_compiler_env compiler_gym::service::service compiler_gym::service::proto::proto compiler_gym::service::runtime::runtime diff --git a/compiler_gym/envs/loop_tool/__init__.py b/compiler_gym/envs/loop_tool/__init__.py index 69532fcf4..0c73b0a66 100644 --- a/compiler_gym/envs/loop_tool/__init__.py +++ b/compiler_gym/envs/loop_tool/__init__.py @@ -3,18 +3,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. """Register the loop_tool environment and reward.""" -from pathlib import Path from typing import Iterable from compiler_gym.datasets import Benchmark, Dataset, benchmark from compiler_gym.datasets.uri import BenchmarkUri +from compiler_gym.envs.loop_tool.service.loop_tool_compilation_session import ( + LoopToolCompilationSession, +) from compiler_gym.spaces import Reward from compiler_gym.util.registration import register -from compiler_gym.util.runfiles_path import runfiles_path - -LOOP_TOOL_SERVICE_BINARY: Path = runfiles_path( - "compiler_gym/envs/loop_tool/service/compiler_gym-loop_tool-service" -) class FLOPSReward(Reward): @@ -83,12 +80,12 @@ def benchmark_from_parsed_uri(self, uri: BenchmarkUri) -> Benchmark: register( id="loop_tool-v0", - entry_point="compiler_gym.envs.loop_tool.loop_tool_env:LoopToolEnv", + entry_point="compiler_gym.service.in_process_client_compiler_env:InProcessClientCompilerEnv", kwargs={ + "session_type": LoopToolCompilationSession, "datasets": [LoopToolCPUDataset(), LoopToolCUDADataset()], "observation_space": "action_state", "reward_space": "flops", "rewards": [FLOPSReward()], - "service": LOOP_TOOL_SERVICE_BINARY, }, ) diff --git a/compiler_gym/envs/loop_tool/loop_tool_env.py b/compiler_gym/envs/loop_tool/loop_tool_env.py deleted file mode 100644 index e21d04208..000000000 --- a/compiler_gym/envs/loop_tool/loop_tool_env.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv - - -class LoopToolEnv(ClientServiceCompilerEnv): - def commandline(self): - return ",".join(str(x) for x in self.actions) diff --git a/compiler_gym/envs/loop_tool/service/BUILD b/compiler_gym/envs/loop_tool/service/BUILD index b26213966..c76431681 100644 --- a/compiler_gym/envs/loop_tool/service/BUILD +++ b/compiler_gym/envs/loop_tool/service/BUILD @@ -3,10 +3,9 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -filegroup( +py_library( name = "service", srcs = [ - "compiler_gym-loop_tool-service", "loop_tool_compilation_session.py", ], visibility = ["//compiler_gym/envs/loop_tool:__subpackages__"], diff --git a/compiler_gym/envs/loop_tool/service/CMakeLists.txt b/compiler_gym/envs/loop_tool/service/CMakeLists.txt index aba317c95..20370f7c9 100644 --- a/compiler_gym/envs/loop_tool/service/CMakeLists.txt +++ b/compiler_gym/envs/loop_tool/service/CMakeLists.txt @@ -5,9 +5,8 @@ cg_add_all_subdirs() -cg_filegroup( +cg_py_library( NAME service FILES - "${CMAKE_CURRENT_LIST_DIR}/compiler_gym-loop_tool-service" "${CMAKE_CURRENT_LIST_DIR}/loop_tool_compilation_session.py" ) diff --git a/compiler_gym/envs/loop_tool/service/compiler_gym-loop_tool-service b/compiler_gym/envs/loop_tool/service/compiler_gym-loop_tool-service deleted file mode 100755 index 50fafab40..000000000 --- a/compiler_gym/envs/loop_tool/service/compiler_gym-loop_tool-service +++ /dev/null @@ -1,13 +0,0 @@ -#! /usr/bin/env python3 -# -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from compiler_gym.envs.loop_tool.service.loop_tool_compilation_session import ( - LoopToolCompilationSession, -) -from compiler_gym.service.runtime import create_and_run_compiler_gym_service - -if __name__ == "__main__": - create_and_run_compiler_gym_service(LoopToolCompilationSession) diff --git a/compiler_gym/random_replay.py b/compiler_gym/random_replay.py deleted file mode 100644 index 81be67ba2..000000000 --- a/compiler_gym/random_replay.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -"""Replay the sequence of actions that produced the best reward.""" -from pathlib import Path -from typing import List - -from deprecated import deprecated - -from compiler_gym.envs.compiler_env import CompilerEnv -from compiler_gym.random_search import replay_actions as replay_actions_ -from compiler_gym.random_search import ( - replay_actions_from_logs as replay_actions_from_logs_, -) - - -@deprecated(version="0.2.1", reason="Use env.step(action) instead") -def replay_actions(env: CompilerEnv, action_names: List[str], outdir: Path): - return replay_actions_(env, action_names, outdir) - - -@deprecated( - version="0.2.1", - reason="Use compiler_gym.random_search.replay_actions_from_logs() instead", -) -def replay_actions_from_logs(env: CompilerEnv, logdir: Path, benchmark=None) -> None: - return replay_actions_from_logs_(env, logdir, benchmark) diff --git a/compiler_gym/random_search.py b/compiler_gym/random_search.py index 903c41bee..a74146f93 100644 --- a/compiler_gym/random_search.py +++ b/compiler_gym/random_search.py @@ -135,7 +135,7 @@ def run_one_episode(self, env: CompilerEnv) -> bool: self.best_returns = total_returns self.best_actions = actions.copy() try: - self.best_commandline = env.commandline() + self.best_commandline = env.action_space.to_string(env.actions) except NotImplementedError: self.best_commandline = "" self.best_found_at_time = time() diff --git a/compiler_gym/service/BUILD b/compiler_gym/service/BUILD index 194882571..905ec5784 100644 --- a/compiler_gym/service/BUILD +++ b/compiler_gym/service/BUILD @@ -76,6 +76,25 @@ py_library( ], ) +py_library( + name = "in_process_client_compiler_env", + srcs = ["in_process_client_compiler_env.py"], + visibility = ["//compiler_gym:__subpackages__"], + deps = [ + ":compilation_session", + ":connection", + "//compiler_gym:compiler_env_state", + "//compiler_gym:validation_result", + "//compiler_gym/datasets", + "//compiler_gym/envs:compiler_env", + "//compiler_gym/errors", + "//compiler_gym/service/proto", + "//compiler_gym/spaces", + "//compiler_gym/util", + "//compiler_gym/views", + ], +) + py_library( name = "service_cache", srcs = ["service_cache.py"], diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index 4f48b80b8..771f4575b 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -15,7 +15,6 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np -from deprecated.sphinx import deprecated from gym.spaces import Space from compiler_gym.compiler_env_state import CompilerEnvState @@ -31,7 +30,8 @@ ValidationError, ) from compiler_gym.service import CompilerGymServiceConnection, ConnectionOpts -from compiler_gym.service.proto import ActionSpace, AddBenchmarkRequest +from compiler_gym.service.proto import ActionSpace as ActionSpaceProto +from compiler_gym.service.proto import AddBenchmarkRequest from compiler_gym.service.proto import Benchmark as BenchmarkProto from compiler_gym.service.proto import ( EndSessionReply, @@ -49,7 +49,8 @@ StepRequest, py_converters, ) -from compiler_gym.spaces import DefaultRewardFromObservation, NamedDiscrete, Reward +from compiler_gym.spaces import ActionSpace, DefaultRewardFromObservation, Reward +from compiler_gym.util.decorators import memoized_property from compiler_gym.util.gym_type_hints import ( ActionType, ObservationType, @@ -64,18 +65,13 @@ logger = logging.getLogger(__name__) -# NOTE(cummins): This is only required to prevent a name conflict with the now -# deprecated ClientServiceCompilerEnv.logger attribute. This can be removed once the logger -# attribute is removed, scheduled for release 0.2.3. -_logger = logger - def _wrapped_step( - service: CompilerGymServiceConnection, request: StepRequest + service: CompilerGymServiceConnection, request: StepRequest, timeout: float ) -> StepReply: """Call the Step() RPC endpoint.""" try: - return service(service.stub.Step, request) + return service(service.stub.Step, request, timeout=timeout) except FileNotFoundError as e: if str(e).startswith("Session not found"): raise SessionNotFound(str(e)) @@ -95,17 +91,21 @@ class ServiceMessageConverters: :code:`compiler_gym.service.proto.Event`. """ - action_space_converter: Callable[[ActionSpace], Space] + action_space_converter: Callable[[ActionSpaceProto], ActionSpace] action_converter: Callable[[ActionType], Event] def __init__( self, - action_space_converter: Optional[Callable[[ActionSpace], Space]] = None, + action_space_converter: Optional[ + Callable[[ActionSpaceProto], ActionSpace] + ] = None, action_converter: Optional[Callable[[Any], Event]] = None, ): """Constructor.""" self.action_space_converter = ( - py_converters.make_message_default_converter() + py_converters.make_action_space_wrapper( + py_converters.make_message_default_converter() + ) if action_space_converter is None else action_space_converter ) @@ -138,8 +138,6 @@ def __init__( service_message_converters: ServiceMessageConverters = None, connection_settings: Optional[ConnectionOpts] = None, service_connection: Optional[CompilerGymServiceConnection] = None, - logger: Optional[logging.Logger] = None, - timeout: Optional[float] = 300, ): """Construct and initialize a CompilerGym environment. @@ -195,25 +193,12 @@ def __init__( :param service_connection: An existing compiler gym service connection to use. - :param timeout: The maximum number of seconds to wait for an RPC method - call to succeed. Accepts a float value. The default is 300 seconds. - :raises FileNotFoundError: If service is a path to a file that is not found. :raises TimeoutError: If the compiler service fails to initialize within the parameters provided in :code:`connection_settings`. """ - # NOTE(cummins): Logger argument deprecated and scheduled to be removed - # in release 0.2.3. - if logger: - warnings.warn( - "The `logger` argument is deprecated on ClientServiceCompilerEnv.__init__() " - "and will be removed in a future release. All ClientServiceCompilerEnv " - "instances share a logger named compiler_gym.service.client_service_compiler_env", - DeprecationWarning, - ) - self.metadata = {"render.modes": ["human", "ansi"]} # A compiler service supports multiple simultaneous environments. This @@ -231,8 +216,6 @@ def __init__( self.action_space_name = action_space - self._timeout = timeout - # If no reward space is specified, generate some from numeric observation spaces rewards = rewards or [ DefaultRewardFromObservation(obs.name) @@ -298,10 +281,7 @@ def __init__( # Register any derived observation spaces now so that the observation # space can be set below. for derived_observation_space in derived_observation_spaces or []: - self.observation.add_derived_space_internal(**derived_observation_space) - - # Lazily evaluated version strings. - self._versions: Optional[GetVersionReply] = None + self.observation.add_derived_space(**derived_observation_space) self.action_space: Optional[Space] = None self.observation_space: Optional[Space] = None @@ -364,25 +344,10 @@ def episode_reward(self, episode_reward: Optional[float]): def actions(self) -> List[ActionType]: return self._actions - @property - @deprecated( - version="0.2.1", - reason=( - "The `ClientServiceCompilerEnv.logger` attribute is deprecated. All ClientServiceCompilerEnv " - "instances share a logger named compiler_gym.service.client_service_compiler_env" - ), - ) - def logger(self): - return _logger - - @property + @memoized_property def versions(self) -> GetVersionReply: """Get the version numbers from the compiler service.""" - if self._versions is None: - self._versions = self.service( - self.service.stub.GetVersion, GetVersionRequest() - ) - return self._versions + return self.service(self.service.stub.GetVersion, GetVersionRequest()) @property def version(self) -> str: @@ -394,20 +359,6 @@ def compiler_version(self) -> str: """The version string of the underlying compiler that this service supports.""" return self.versions.compiler_version - def commandline(self) -> str: - """Calling this method on a :class:`ClientServiceCompilerEnv - ` instance raises - :code:`NotImplementedError`. - """ - raise NotImplementedError("abstract method") - - def commandline_to_actions(self, commandline: str) -> List[ActionType]: - """Calling this method on a :class:`ClientServiceCompilerEnv - ` instance raises - :code:`NotImplementedError`. - """ - raise NotImplementedError("abstract method") - @property def episode_walltime(self) -> float: return time() - self.episode_start_time @@ -418,22 +369,22 @@ def state(self) -> CompilerEnvState: benchmark=str(self.benchmark) if self.benchmark else None, reward=self.episode_reward, walltime=self.episode_walltime, - commandline=self.commandline(), + commandline=self.action_space.to_string(self.actions), ) @property - def action_space(self) -> Space: + def action_space(self) -> ActionSpace: return self._action_space @action_space.setter - def action_space(self, action_space: Optional[str]): + def action_space(self, action_space: Optional[str]) -> None: self.action_space_name = action_space index = ( [a.name for a in self.action_spaces].index(action_space) if self.action_space_name else 0 ) - self._action_space: NamedDiscrete = self.action_spaces[index] + self._action_space: ActionSpace = self.action_spaces[index] @property def action_spaces(self) -> List[str]: @@ -548,7 +499,6 @@ def _init_kwargs(self) -> Dict[str, Any]: "benchmark": self.benchmark, "connection_settings": self._connection_settings, "service": self._service_endpoint, - "timeout": self._timeout, } def fork(self) -> "ClientServiceCompilerEnv": @@ -860,9 +810,6 @@ def raw_step( :param rewards: A list of reward spaces to compute rewards from. These are evaluated after the actions are applied. - :param timeout: The maximum number of seconds to wait for an RPC method - call to succeed. Accepts a float value. The default is 300 seconds. - :return: A tuple of observations, rewards, done, and info. Observations and rewards are lists. @@ -894,8 +841,6 @@ def raw_step( for i, observation_space in enumerate(observations_to_compute) } - self._timeout = timeout - # Record the actions. self._actions += actions @@ -910,7 +855,7 @@ def raw_step( ], ) try: - reply = _wrapped_step(self.service, request) + reply = _wrapped_step(self.service, request, timeout) except ( ServiceError, ServiceTransportError, @@ -1034,7 +979,12 @@ def step( category=DeprecationWarning, ) reward_spaces = rewards - return self.multistep([action], observation_spaces, reward_spaces) + return self.multistep( + actions=[action], + observation_spaces=observation_spaces, + reward_spaces=reward_spaces, + timeout=timeout, + ) def multistep( self, @@ -1169,7 +1119,7 @@ def apply(self, state: CompilerEnvState) -> None: # noqa f"to environment for benchmark '{self.benchmark}'" ) - actions = self.commandline_to_actions(state.commandline) + actions = self.action_space.from_string(state.commandline) done = False for action in actions: _, _, done, info = self.step(action) diff --git a/compiler_gym/service/compilation_session.py b/compiler_gym/service/compilation_session.py index de59aeaa1..15c8a3c07 100644 --- a/compiler_gym/service/compilation_session.py +++ b/compiler_gym/service/compilation_session.py @@ -55,9 +55,9 @@ def __init__( :param benchmark: The benchmark to use. """ - del action_space # Subclasses must use this. - del benchmark # Subclasses must use this. self.working_dir = working_dir + self.action_space = action_space + self.benchmark = benchmark def apply_action(self, action: Action) -> Tuple[bool, Optional[ActionSpace], bool]: """Apply an action. diff --git a/compiler_gym/service/in_process_client_compiler_env.py b/compiler_gym/service/in_process_client_compiler_env.py new file mode 100644 index 000000000..13a368d8b --- /dev/null +++ b/compiler_gym/service/in_process_client_compiler_env.py @@ -0,0 +1,972 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Contains an implementation of the :class:`CompilerEnv` +interface as a gRPC client service.""" +import logging +import numbers +import random +import shutil +import warnings +from copy import deepcopy +from datetime import datetime +from math import isclose +from pathlib import Path +from time import time +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union + +import numpy as np +from gym.spaces import Space + +from compiler_gym.compiler_env_state import CompilerEnvState +from compiler_gym.datasets import Benchmark, Dataset, Datasets +from compiler_gym.datasets.uri import BenchmarkUri +from compiler_gym.envs.compiler_env import CompilerEnv +from compiler_gym.errors import ValidationError +from compiler_gym.service import CompilationSession +from compiler_gym.service.proto import ActionSpace as ActionSpaceProto +from compiler_gym.service.proto import Benchmark as BenchmarkProto +from compiler_gym.service.proto import Event, GetVersionReply +from compiler_gym.service.proto import NamedDiscreteSpace as NamedDiscreteSpaceProto +from compiler_gym.service.proto import ObservationSpace as ObservationSpaceProto +from compiler_gym.service.proto import Space as SpaceProto +from compiler_gym.service.proto import py_converters +from compiler_gym.spaces import ( + ActionSpace, + DefaultRewardFromObservation, + NamedDiscrete, + Reward, +) +from compiler_gym.util.gym_type_hints import ( + ActionType, + ObservationType, + OptionalArgumentValue, + RewardType, + StepType, +) +from compiler_gym.util.runfiles_path import transient_cache_path +from compiler_gym.util.timer import Timer +from compiler_gym.util.version import __version__ +from compiler_gym.validation_result import ValidationResult +from compiler_gym.views import ObservationSpaceSpec, ObservationView, RewardView + +logger = logging.getLogger(__name__) + + +class ServiceMessageConverters: + """Allows for customization of conversion to/from gRPC messages for the + :class:`InProcessClientCompilerEnv + `. + + Supports conversion customizations: + + - :code:`compiler_gym.service.proto.ActionSpace` -> + :code:`gym.spaces.Space`. + - :code:`compiler_gym.util.gym_type_hints.ActionType` -> + :code:`compiler_gym.service.proto.Event`. + """ + + action_space_converter: Callable[[ActionSpaceProto], ActionSpace] + action_converter: Callable[[ActionType], Event] + + def __init__( + self, + action_space_converter: Optional[ + Callable[[ActionSpaceProto], ActionSpace] + ] = None, + action_converter: Optional[Callable[[Any], Event]] = None, + ): + """Constructor.""" + self.action_space_converter = ( + py_converters.make_action_space_wrapper( + py_converters.make_message_default_converter() + ) + if action_space_converter is None + else action_space_converter + ) + self.action_converter = ( + py_converters.to_event_message_default_converter() + if action_converter is None + else action_converter + ) + + +def make_working_directory(session_type: Type[CompilationSession]) -> Path: + random_hash = random.getrandbits(16) + timestamp = datetime.now().strftime(f"s/%m%dT%H%M%S-%f-{random_hash:04x}") + working_directory = transient_cache_path(f"s/{session_type.__name__}-{timestamp}") + logger.debug( + "Created working directory for compilation session: %s", working_directory + ) + return working_directory + + +def action_space_to_proto(action_space: ActionSpace) -> ActionSpaceProto: + # TODO(cummins): This needs to be a true reverse mapping from python to + # proto. Currently it's hardcoded to work only for named discrete spaces. + return ActionSpaceProto( + name=action_space.name, + space=SpaceProto( + named_discrete=NamedDiscreteSpaceProto(name=action_space.names) + ), + ) + + +class InProcessClientCompilerEnv(CompilerEnv): + """Implementation of :class:`CompilerEnv ` + for Python services that run in the same process. + + This uses the same protocol buffer interface as + :class:`InProcessClientCompilerEnv + `, but without the overhead + of running a gRPC service. The tradeoff is reduced robustness in the face of + compiler errors, and the inability to run the service on a different + machine. + """ + + def __init__( + self, + session_type: Type[CompilationSession], + session: Optional[CompilationSession] = None, + rewards: Optional[List[Reward]] = None, + datasets: Optional[Iterable[Dataset]] = None, + benchmark: Optional[Union[str, Benchmark]] = None, + observation_space: Optional[Union[str, ObservationSpaceSpec]] = None, + reward_space: Optional[Union[str, Reward]] = None, + action_space: Optional[str] = None, + derived_observation_spaces: Optional[List[Dict[str, Any]]] = None, + service_message_converters: ServiceMessageConverters = None, + ): + """Construct and initialize a CompilerGym environment. + + In normal use you should use :code:`gym.make(...)` rather than calling + the constructor directly. + + :param rewards: The reward spaces that this environment supports. + Rewards are typically calculated based on observations generated by + the service. See :class:`Reward ` for + details. + + :param benchmark: The benchmark to use for this environment. Either a + URI string, or a :class:`Benchmark + ` instance. If not provided, the + first benchmark as returned by + :code:`next(env.datasets.benchmarks())` will be used as the default. + + :param observation_space: Compute and return observations at each + :func:`step()` from this space. Accepts a string name or an + :class:`ObservationSpaceSpec + `. If not provided, + :func:`step()` returns :code:`None` for the observation value. Can + be set later using :meth:`env.observation_space + `. For available + spaces, see :class:`env.observation.spaces + `. + + :param reward_space: Compute and return reward at each :func:`step()` + from this space. Accepts a string name or a :class:`Reward + `. If not provided, :func:`step()` + returns :code:`None` for the reward value. Can be set later using + :meth:`env.reward_space + `. For available spaces, + see :class:`env.reward.spaces `. + + :param action_space: The name of the action space to use. If not + specified, the default action space for this compiler is used. + + :param derived_observation_spaces: An optional list of arguments to be + passed to :meth:`env.observation.add_derived_space() + `. + + :param service_message_converters: Custom converters for action spaces and actions. + + :param connection_settings: The settings used to establish a connection + with the remote service. + + :param service_connection: An existing compiler gym service connection + to use. + + :raises FileNotFoundError: If service is a path to a file that is not + found. + + :raises TimeoutError: If the compiler service fails to initialize within + the parameters provided in :code:`connection_settings`. + """ + self.session_type = session_type + + self.metadata = {"render.modes": ["human", "ansi"]} + + self._datasets = Datasets(datasets or []) + + self.action_space_name = action_space + + # If no reward space is specified, generate some from numeric observation spaces + rewards = rewards or [ + DefaultRewardFromObservation(obs.name) + for obs in self.session_type.observation_spaces + if obs.default_observation.WhichOneof("value") + and isinstance( + getattr( + obs.default_observation, obs.default_observation.WhichOneof("value") + ), + numbers.Number, + ) + ] + + # The benchmark that is currently being used, and the benchmark that + # will be used on the next call to reset(). These are equal except in + # the gap between the user setting the env.benchmark property while in + # an episode and the next call to env.reset(). + self._benchmark_in_use: Optional[Benchmark] = None + self._benchmark_in_use_proto: BenchmarkProto = BenchmarkProto() + self._next_benchmark: Optional[Benchmark] = None + # Normally when the benchmark is changed the updated value is not + # reflected until the next call to reset(). We make an exception for the + # constructor-time benchmark as otherwise the behavior of the benchmark + # property is counter-intuitive: + # + # >>> env = gym.make("example-v0", benchmark="foo") + # >>> env.benchmark + # None + # >>> env.reset() + # >>> env.benchmark + # "foo" + # + # By forcing the _benchmark_in_use URI at constructor time, the first + # env.benchmark above returns the benchmark as expected. + try: + self.benchmark = benchmark or next(self.datasets.benchmarks()) + self._benchmark_in_use = self._next_benchmark + except StopIteration: + # StopIteration raised on next(self.datasets.benchmarks()) if there + # are no benchmarks available. This is to allow InProcessClientCompilerEnv to be + # used without any datasets by setting a benchmark before/during the + # first reset() call. + pass + + self.service_message_converters = ( + ServiceMessageConverters() + if service_message_converters is None + else service_message_converters + ) + + # Process the available action, observation, and reward spaces. + self.action_spaces = [ + self.service_message_converters.action_space_converter(space) + for space in self.session_type.action_spaces + ] + + self.observation = self._observation_view_type( + raw_step=self.multistep, + spaces=self.session_type.observation_spaces, + ) + self.reward = self._reward_view_type(rewards, self.observation) + + # Register any derived observation spaces now so that the observation + # space can be set below. + for derived_observation_space in derived_observation_spaces or []: + self.observation.add_derived_space_internal(**derived_observation_space) + + self.action_space: Optional[Space] = None + self.observation_space: Optional[Space] = None + + # Mutable state initialized in reset(). + self._reward_range: Tuple[float, float] = (-np.inf, np.inf) + self.episode_reward = None + self.episode_start_time: float = time() + self._actions: List[ActionType] = [] + + # Initialize the default observation/reward spaces. + self.observation_space_spec = None + self.reward_space_spec = None + self.observation_space = observation_space + self.reward_space = reward_space + + self.working_directory: Optional[Path] = None + self.session: Optional[CompilationSession] = session + + def close(self): + if self.working_directory: + shutil.rmtree(self.working_directory, ignore_errors=True) + self.working_directory = None + + def __del__(self): + # Don't let the service be orphaned if user forgot to close(), or + # if an exception was thrown. The conditional guard is because this + # may be called in case of early error. + if hasattr(self, "service") and getattr(self, "service"): + self.close() + + @property + def observation_space_spec(self) -> ObservationSpaceSpec: + return self._observation_space_spec + + @observation_space_spec.setter + def observation_space_spec( + self, observation_space_spec: Optional[ObservationSpaceSpec] + ): + self._observation_space_spec = observation_space_spec + + @property + def observation(self) -> ObservationView: + return self._observation + + @observation.setter + def observation(self, observation: ObservationView) -> None: + self._observation = observation + + @property + def reward_space_spec(self) -> Optional[Reward]: + return self._reward_space_spec + + @reward_space_spec.setter + def reward_space_spec(self, val: Optional[Reward]): + self._reward_space_spec = val + + @property + def datasets(self) -> Iterable[Dataset]: + return self._datasets + + @datasets.setter + def datasets(self, datasets: Iterable[Dataset]): + self._datastes = datasets + + @property + def episode_reward(self) -> Optional[float]: + return self._episode_reward + + @episode_reward.setter + def episode_reward(self, episode_reward: Optional[float]): + self._episode_reward = episode_reward + + @property + def actions(self) -> List[ActionType]: + return self._actions + + @property + def versions(self) -> GetVersionReply: + """Get the version numbers from the compiler service.""" + return GetVersionReply( + service_version=__version__, + compiler_version=self.session_type.compiler_version, + ) + + @property + def version(self) -> str: + """The version string of the compiler service.""" + return self.versions.service_version + + @property + def compiler_version(self) -> str: + """The version string of the underlying compiler that this service supports.""" + return self.versions.compiler_version + + @property + def episode_walltime(self) -> float: + return time() - self.episode_start_time + + @property + def state(self) -> CompilerEnvState: + return CompilerEnvState( + benchmark=str(self.benchmark) if self.benchmark else None, + reward=self.episode_reward, + walltime=self.episode_walltime, + commandline=self.action_space.to_string(self.actions), + ) + + @property + def action_space(self) -> Space: + return self._action_space + + @action_space.setter + def action_space(self, action_space: Optional[str]): + self.action_space_name = action_space + index = ( + [a.name for a in self.action_spaces].index(action_space) + if self.action_space_name + else 0 + ) + self._action_space: NamedDiscrete = self.action_spaces[index] + + @property + def action_spaces(self) -> List[str]: + return self._action_spaces + + @action_spaces.setter + def action_spaces(self, action_spaces: List[str]): + self._action_spaces = action_spaces + + @property + def benchmark(self) -> Benchmark: + return self._benchmark_in_use + + @benchmark.setter + def benchmark(self, benchmark: Union[str, Benchmark, BenchmarkUri]): + warnings.warn("Changing the benchmark has no effect until reset() is called") + if isinstance(benchmark, str): + benchmark_object = self.datasets.benchmark(benchmark) + logger.debug("Setting benchmark by name: %s", benchmark_object) + self._next_benchmark = benchmark_object + elif isinstance(benchmark, Benchmark): + logger.debug("Setting benchmark: %s", benchmark.uri) + self._next_benchmark = benchmark + elif isinstance(benchmark, BenchmarkUri): + benchmark_object = self.datasets.benchmark_from_parsed_uri(benchmark) + logger.debug("Setting benchmark by name: %s", benchmark_object) + self._next_benchmark = benchmark_object + else: + raise TypeError( + f"Expected a Benchmark or str, received: '{type(benchmark).__name__}'" + ) + + @property + def reward_space(self) -> Optional[Reward]: + return self.reward_space_spec + + @reward_space.setter + def reward_space(self, reward_space: Optional[Union[str, Reward]]) -> None: + # Coerce the observation space into a string. + reward_space: Optional[str] = ( + reward_space.name if isinstance(reward_space, Reward) else reward_space + ) + + if reward_space: + if reward_space not in self.reward.spaces: + raise LookupError(f"Reward space not found: {reward_space}") + # The reward space remains unchanged, nothing to do. + if reward_space == self.reward_space: + return + self.reward_space_spec = self.reward.spaces[reward_space] + self._reward_range = ( + self.reward_space_spec.min, + self.reward_space_spec.max, + ) + # Reset any cumulative rewards. + self.episode_reward = 0 + else: + # If no reward space is being used then set the reward range to + # unbounded. + self.reward_space_spec = None + self._reward_range = (-np.inf, np.inf) + + @property + def reward_range(self) -> Tuple[float, float]: + return self._reward_range + + @property + def reward(self) -> RewardView: + return self._reward + + @reward.setter + def reward(self, reward: RewardView) -> None: + self._reward = reward + + @property + def observation_space(self) -> Optional[Space]: + if self.observation_space_spec: + return self.observation_space_spec.space + + @observation_space.setter + def observation_space( + self, observation_space: Optional[Union[str, ObservationSpaceSpec]] + ) -> None: + # Coerce the observation space into a string. + observation_space: Optional[str] = ( + observation_space.id + if isinstance(observation_space, ObservationSpaceSpec) + else observation_space + ) + + if observation_space: + if observation_space not in self.observation.spaces: + raise LookupError(f"Observation space not found: {observation_space}") + self.observation_space_spec = self.observation.spaces[observation_space] + else: + self.observation_space_spec = None + + def _init_kwargs(self) -> Dict[str, Any]: + """Retturn a dictionary of keyword arguments used to initialize the + environment. + """ + return { + "session_type": self.session_type, + "action_space": self.action_space, + "benchmark": self.benchmark, + "connection_settings": self._connection_settings, + } + + def fork(self) -> "InProcessClientCompilerEnv": + try: + new_session: CompilationSession = self.session.fork() + + # Create a new environment that shares the connection. + new_env = type(self)(**self._init_kwargs(), session=new_session) + + # Now that we have initialized the environment with the current + # state, set the benchmark so that calls to new_env.reset() will + # correctly revert the environment to the initial benchmark state. + # + # pylint: disable=protected-access + new_env._next_benchmark = self._benchmark_in_use + + # Set the "visible" name of the current benchmark to hide the fact + # that we loaded from a custom benchmark file. + new_env._benchmark_in_use = self._benchmark_in_use + except NotImplementedError: + # Fallback implementation. If the compiler service does not support + # the Fork() operator then we create a new independent environment + # and apply the sequence of actions in the current environment to + # replay the state. + new_env = type(self)(**self._init_kwargs()) + new_env.reset() + _, _, done, _ = new_env.multistep(self.actions) + assert not done, "Failed to replay action sequence in forked environment" + + # Create copies of the mutable reward and observation spaces. This + # is required to correctly calculate incremental updates. + new_env.reward.spaces = deepcopy(self.reward.spaces) + new_env.observation.spaces = deepcopy(self.observation.spaces) + + # Set the default observation and reward types. Note the use of IDs here + # to prevent passing the spaces by reference. + if self.observation_space: + new_env.observation_space = self.observation_space_spec.id + if self.reward_space: + new_env.reward_space = self.reward_space.name + + # Copy over the mutable episode state. + new_env.episode_reward = self.episode_reward + new_env.episode_start_time = self.episode_start_time + new_env._actions = self.actions.copy() + + return new_env + + def reset( + self, + benchmark: Optional[Union[str, Benchmark]] = None, + action_space: Optional[str] = None, + reward_space: Union[ + OptionalArgumentValue, str, Reward + ] = OptionalArgumentValue.UNCHANGED, + observation_space: Union[ + OptionalArgumentValue, str, ObservationSpaceSpec + ] = OptionalArgumentValue.UNCHANGED, + timeout: Optional[float] = 300, + ) -> Optional[ObservationType]: + shutil.rmtree(self.working_directory, ignore_errors=True) + self.working_directory = make_working_directory(self.session_type) + + if observation_space != OptionalArgumentValue.UNCHANGED: + self.observation_space = observation_space + + if reward_space != OptionalArgumentValue.UNCHANGED: + self.reward_space = reward_space + + if not self._next_benchmark: + raise TypeError( + "No benchmark set. Set a benchmark using " + "`env.reset(benchmark=benchmark)`. Use `env.datasets` to " + "access the available benchmarks." + ) + + self.action_space_name = action_space or self.action_space_name + + # Update the user requested benchmark, if provided. + if benchmark: + self.benchmark = benchmark + self._benchmark_in_use = self._next_benchmark + self._benchmark_in_use_proto = self._benchmark_in_use.proto + + self.session = self.session_type( + working_directory=self.working_directory, + action_space=action_space_to_proto(self.action_space), + benchmark=self._benchmark_in_use_proto, + ) + + self.reward.get_cost = self.observation.__getitem__ + self.episode_start_time = time() + self._actions = [] + + self.reward.reset(benchmark=self.benchmark, observation_view=self.observation) + if self.reward_space: + self.episode_reward = 0.0 + + if self.observation_space: + return self.observation.spaces[self.observation_space_spec.id].translate( + self.session.get_observation( + ObservationSpaceProto(name=self.observation_space_spec.id) + ) + ) + + @property + def in_episode(self) -> bool: + return self.session is not None + + def step( + self, + action: ActionType, + observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None, + reward_spaces: Optional[Iterable[Union[str, Reward]]] = None, + timeout: Optional[float] = 300, + ) -> StepType: + """:raises SessionNotFound: If :meth:`reset() + ` has not been called. + """ + return self.multistep( + [action], observation_spaces, reward_spaces, timeout=timeout + ) + + def multistep( + self, + actions: Iterable[ActionType], + observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None, + reward_spaces: Optional[Iterable[Union[str, Reward]]] = None, + timeout: Optional[float] = 300, + ): + """:raises SessionNotFound: If :meth:`reset() + ` has not been called. + """ + # Coerce observation spaces into a list of ObservationSpaceSpec instances. + if observation_spaces: + observation_spaces_to_compute: List[ObservationSpaceSpec] = [ + obs + if isinstance(obs, ObservationSpaceSpec) + else self.observation.spaces[obs] + for obs in observation_spaces + ] + elif self.observation_space_spec: + observation_spaces_to_compute: List[ObservationSpaceSpec] = [ + self.observation_space_spec + ] + observation_spaces = [self.observation_space_spec] + else: + observation_spaces_to_compute: List[ObservationSpaceSpec] = [] + observation_spaces = [] + + # Coerce reward spaces into a list of Reward instances. + if reward_spaces: + reward_spaces_to_compute: List[Reward] = [ + rew if isinstance(rew, Reward) else self.reward.spaces[rew] + for rew in reward_spaces + ] + elif self.reward_space: + reward_spaces_to_compute: List[Reward] = [self.reward_space] + reward_spaces = [self.reward_space] + else: + reward_spaces_to_compute: List[Reward] = [] + reward_spaces = [] + + reward_observation_spaces: List[ObservationSpaceSpec] = [] + for reward_space in reward_spaces: + reward_observation_spaces += [ + self.observation.spaces[obs] for obs in reward_space.observation_spaces + ] + + observations_to_compute: List[ObservationSpaceSpec] = list( + set(observation_spaces).union(set(reward_observation_spaces)) + ) + + # Record the actions. + self._actions += actions + + done, new_action_space, action_had_no_effect = False, False, True + for action in actions: + ( + done, + new_new_action_space, + new_action_had_no_effect, + ) = self.session.apply_action( + self.service_message_converters.action_converter(action) + ) + new_action_space |= new_new_action_space is not None + action_had_no_effect &= new_action_had_no_effect + + # If the action space has changed, update it. + if new_new_action_space: + self._action_space = ( + self.service_message_converters.action_space_converter( + new_new_action_space + ) + ) + + if done: + default_observations = [ + observation_space.default_value + for observation_space in observation_spaces + ] + default_rewards = [ + float(reward_space.reward_on_error(self.episode_reward)) + for reward_space in reward_spaces + ] + return ( + default_observations, + default_rewards, + True, + { + "episode_ended_by_environment": True, + }, + ) + + # Translate observations to python representations. + computed_observations = { + observation_space.id: observation_space.translate( + self.session.get_observation( + ObservationSpaceProto(name=observation_space.id) + ) + ) + for observation_space in observations_to_compute + } + + # Get the user-requested observation. + observations: List[ObservationType] = [ + computed_observations[observation_space.id] + for observation_space in observation_spaces + ] + + # Update and compute the rewards. + rewards: List[RewardType] = [] + for reward_space in reward_spaces: + reward_observations = [ + computed_observations[observation_space] + for observation_space in reward_space.observation_spaces + ] + rewards.append( + float( + reward_space.update(actions, reward_observations, self.observation) + ) + ) + + info = { + "action_had_no_effect": action_had_no_effect, + "new_action_space": new_action_space, + } + + # Translate observations lists back to the appropriate types. + if observation_spaces is None and self.observation_space_spec: + observations = observations[0] + elif not observation_spaces_to_compute: + observations = None + + # Translate reward lists back to the appropriate types. + if reward_spaces is None and self.reward_space: + rewards = rewards[0] + # Update the cumulative episode reward + self.episode_reward += rewards + elif not reward_spaces_to_compute: + rewards = None + + return observations, rewards, done, info + + def render( + self, + mode="human", + ) -> Optional[str]: + """Render the environment. + + InProcessClientCompilerEnv instances support two render modes: "human", which prints + the current environment state to the terminal and return nothing; and + "ansi", which returns a string representation of the current environment + state. + + :param mode: The render mode to use. + :raises TypeError: If a default observation space is not set, or if the + requested render mode does not exist. + """ + if not self.observation_space: + raise ValueError("Cannot call render() when no observation space is used") + observation = self.observation[self.observation_space_spec.id] + if mode == "human": + print(observation) + elif mode == "ansi": + return str(observation) + else: + raise ValueError(f"Invalid mode: {mode}") + + @property + def _observation_view_type(self): + """Returns the type for observation views. + + Subclasses may override this to extend the default observation view. + """ + return ObservationView + + @property + def _reward_view_type(self): + """Returns the type for reward views. + + Subclasses may override this to extend the default reward view. + """ + return RewardView + + def apply(self, state: CompilerEnvState) -> None: # noqa + # TODO(cummins): Does this behavior make sense? Take, for example: + # + # >>> env.apply(state) + # >>> env.benchmark == state.benchmark + # False + # + # I think most users would reasonable expect `env.apply(state)` to fully + # apply the state, not just the sequence of actions. And what about the + # choice of observation space, reward space, etc? + if self.benchmark != state.benchmark: + warnings.warn( + f"Applying state from environment for benchmark '{state.benchmark}' " + f"to environment for benchmark '{self.benchmark}'" + ) + self.reset(benchmark=state.benchmark) + + actions = self.action_space.from_string(state.commandline) + done = False + for action in actions: + _, _, done, info = self.step(action) + if done: + raise ValueError( + f"Environment terminated with error: `{info.get('error_details')}`" + ) + + def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult: + if state: + self.reset(benchmark=state.benchmark) + in_place = False + else: + state = self.state + in_place = True + + errors: ValidationError = [] + validation = { + "state": state, + "actions_replay_failed": False, + "reward_validated": False, + "reward_validation_failed": False, + "benchmark_semantics_validated": False, + "benchmark_semantics_validation_failed": False, + } + + fkd = self.fork() + try: + with Timer() as walltime: + replay_target = self if in_place else fkd + replay_target.reset(benchmark=state.benchmark) + # Use a while loop here so that we can `break` early out of the + # validation process in case a step fails. + while True: + try: + replay_target.apply(state) + except (ValueError, OSError) as e: + validation["actions_replay_failed"] = True + errors.append( + ValidationError( + type="Action replay failed", + data={ + "exception": str(e), + "exception_type": type(e).__name__, + }, + ) + ) + break + + if state.reward is not None and self.reward_space is None: + warnings.warn( + "Validating state with reward, but " + "environment has no reward space set" + ) + elif ( + state.reward is not None + and self.reward_space + and self.reward_space.deterministic + ): + validation["reward_validated"] = True + # If reward deviates from the expected amount record the + # error but continue with the remainder of the validation. + if not isclose( + state.reward, + replay_target.episode_reward, + rel_tol=1e-5, + abs_tol=1e-10, + ): + validation["reward_validation_failed"] = True + errors.append( + ValidationError( + type=( + f"Expected reward {state.reward} but " + f"received reward {replay_target.episode_reward}" + ), + data={ + "expected_reward": state.reward, + "actual_reward": replay_target.episode_reward, + }, + ) + ) + + benchmark = replay_target.benchmark + if benchmark.is_validatable(): + validation["benchmark_semantics_validated"] = True + semantics_errors = benchmark.validate(replay_target) + if semantics_errors: + validation["benchmark_semantics_validation_failed"] = True + errors += semantics_errors + + # Finished all checks, break the loop. + break + finally: + fkd.close() + + return ValidationResult.construct( + walltime=walltime.time, + errors=errors, + **validation, + ) + + def send_param(self, key: str, value: str) -> str: + """Send a single parameter to the compiler service. + + See :meth:`send_params() + ` for more + information. + + :param key: The parameter key. + + :param value: The parameter value. + + :return: The response from the compiler service. + + :raises SessionNotFound: If called before :meth:`reset() + `. + """ + return self.session.handle_session_parameter(key, value) + + def send_params(self, *params: Iterable[Tuple[str, str]]) -> List[str]: + """Send a list of parameters to the compiler service. + + This provides a mechanism to send messages to the backend compilation + session in a way that doesn't conform to the normal communication + pattern. This can be useful for things like configuring runtime + debugging settings, or applying "meta actions" to the compiler that are + not exposed in the compiler's action space. Consult the documentation + for a specific compiler service to see what parameters, if any, are + supported. + + Must have called :meth:`reset() + ` first. + + :param params: A list of parameters, where each parameter is a + :code:`(key, value)` tuple. + + :return: A list of string responses, one per parameter. + + :raises SessionNotFound: If called before :meth:`reset() + `. + """ + return [ + self.session.handle_session_parameter(key, value) for key, value in params + ] + + def __copy__(self) -> "InProcessClientCompilerEnv": + raise TypeError( + "InProcessClientCompilerEnv instances do not support shallow copies. Use deepcopy()" + ) + + def __deepcopy__(self, memo) -> "InProcessClientCompilerEnv": + del memo # unused + return self.fork() diff --git a/compiler_gym/service/proto/BUILD b/compiler_gym/service/proto/BUILD index caec4df87..ca401a00f 100644 --- a/compiler_gym/service/proto/BUILD +++ b/compiler_gym/service/proto/BUILD @@ -18,6 +18,7 @@ py_library( deps = [ ":compiler_gym_service_py", ":compiler_gym_service_py_grpc", + "//compiler_gym/spaces:action_space", "//compiler_gym/spaces:box", "//compiler_gym/spaces:commandline", "//compiler_gym/spaces:dict", diff --git a/compiler_gym/service/proto/CMakeLists.txt b/compiler_gym/service/proto/CMakeLists.txt index 7887e3f4b..0096dd184 100644 --- a/compiler_gym/service/proto/CMakeLists.txt +++ b/compiler_gym/service/proto/CMakeLists.txt @@ -14,6 +14,7 @@ cg_py_library( DEPS "::compiler_gym_service_py" "::compiler_gym_service_py_grpc" + compiler_gym::spaces::action_space compiler_gym::spaces::box compiler_gym::spaces::commandline compiler_gym::spaces::dict diff --git a/compiler_gym/service/proto/py_converters.py b/compiler_gym/service/proto/py_converters.py index 468f0d554..c5734f626 100644 --- a/compiler_gym/service/proto/py_converters.py +++ b/compiler_gym/service/proto/py_converters.py @@ -25,7 +25,9 @@ from gym.spaces import Space as GymSpace from compiler_gym.service.proto.compiler_gym_service_pb2 import ( - ActionSpace, + ActionSpace as ActionSpaceProto, +) +from compiler_gym.service.proto.compiler_gym_service_pb2 import ( BooleanBox, BooleanRange, BooleanSequenceSpace, @@ -62,6 +64,7 @@ StringSpace, StringTensor, ) +from compiler_gym.spaces.action_space import ActionSpace from compiler_gym.spaces.box import Box from compiler_gym.spaces.commandline import Commandline, CommandlineFlag from compiler_gym.spaces.dict import Dict @@ -442,6 +445,12 @@ def __call__(self, message: ObservationSpace) -> GymSpace: return res +def make_action_space_wrapper( + converter: Callable[[Any], Any] +) -> Callable[[Any], ActionSpace]: + return lambda msg: ActionSpace(space=converter(msg)) + + def make_message_default_converter() -> Callable[[Any], Any]: conversion_map = { bool: convert_trivial, @@ -493,7 +502,7 @@ def make_message_default_converter() -> Callable[[Any], Any]: conversion_map[ListSpace] = ListSpaceMessageConverter(conversion_map[Space]) conversion_map[DictSpace] = DictSpaceMessageConverter(conversion_map[Space]) conversion_map[SpaceSequenceSpace] = SpaceSequenceSpaceMessageConverter(res) - conversion_map[ActionSpace] = ActionSpaceMessageConverter(res) + conversion_map[ActionSpaceProto] = ActionSpaceMessageConverter(res) conversion_map[ObservationSpace] = ObservationSpaceMessageConverter(res) conversion_map[any_pb2.Any] = ProtobufAnyConverter( diff --git a/compiler_gym/spaces/BUILD b/compiler_gym/spaces/BUILD index 88e74e1a2..c5e4c8200 100644 --- a/compiler_gym/spaces/BUILD +++ b/compiler_gym/spaces/BUILD @@ -11,6 +11,7 @@ py_library( srcs = ["__init__.py"], visibility = ["//visibility:public"], deps = [ + ":action_space", ":box", ":commandline", ":common", @@ -28,8 +29,8 @@ py_library( ) py_library( - name = "common", - srcs = ["common.py"], + name = "action_space", + srcs = ["action_space.py"], visibility = ["//compiler_gym:__subpackages__"], ) @@ -48,6 +49,12 @@ py_library( ], ) +py_library( + name = "common", + srcs = ["common.py"], + visibility = ["//compiler_gym:__subpackages__"], +) + py_library( name = "dict", srcs = ["dict.py"], diff --git a/compiler_gym/spaces/CMakeLists.txt b/compiler_gym/spaces/CMakeLists.txt index e8d3bc69c..c9a36f5ab 100644 --- a/compiler_gym/spaces/CMakeLists.txt +++ b/compiler_gym/spaces/CMakeLists.txt @@ -11,6 +11,7 @@ cg_py_library( SRCS "__init__.py" DEPS + ::action_space ::common ::box ::commandline @@ -28,8 +29,8 @@ cg_py_library( ) cg_py_library( - NAME common - SRCS "common.py" + NAME action_space + SRCS action_space.py ) cg_py_library( @@ -47,6 +48,11 @@ cg_py_library( PUBLIC ) +cg_py_library( + NAME common + SRCS "common.py" +) + cg_py_library( NAME dict SRCS dict.py diff --git a/compiler_gym/spaces/__init__.py b/compiler_gym/spaces/__init__.py index f52ca0da2..385928649 100644 --- a/compiler_gym/spaces/__init__.py +++ b/compiler_gym/spaces/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from compiler_gym.spaces.action_space import ActionSpace from compiler_gym.spaces.box import Box from compiler_gym.spaces.commandline import Commandline, CommandlineFlag from compiler_gym.spaces.dict import Dict @@ -16,6 +17,7 @@ from compiler_gym.spaces.tuple import Tuple __all__ = [ + "ActionSpace", "Box", "Commandline", "CommandlineFlag", diff --git a/compiler_gym/spaces/action_space.py b/compiler_gym/spaces/action_space.py new file mode 100644 index 000000000..edb8975dc --- /dev/null +++ b/compiler_gym/spaces/action_space.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import List, Optional + +from gym.spaces import Space + +from compiler_gym.util.gym_type_hints import ActionType + + +class ActionSpace(Space): + """A wrapper around a :code:`gym.spaces.Space` with additional functionality + for action spaces. + """ + + def __init__(self, space: Space): + """Constructor. + + :param space: The space that this action space wraps. + """ + self.wrapped = space + + def __getattr__(self, name: str): + return getattr(self.wrapped, name) + + def __getitem__(self, name: str): + return self.wrapped[name] + + def sample(self) -> ActionType: + return self.wrapped.sample() + + def seed(self, seed: Optional[int] = None) -> ActionType: + return self.wrapped.seed(seed) + + def contains(self, x: ActionType) -> bool: + """Return boolean specifying if x is a valid member of this space.""" + raise self.wrapped.contains(x) + + def __contains__(self, x: ActionType) -> bool: + """Return boolean specifying if x is a valid member of this space.""" + return self.wrapped.contains(x) + + def __eq__(self, rhs) -> bool: + if isinstance(rhs, ActionSpace): + return self.wrapped == rhs.wrapped + else: + return self.wrapped == rhs + + def __ne__(self, rhs) -> bool: + if isinstance(rhs, ActionSpace): + return self.wrapped != rhs.wrapped + else: + return self.wrapped != rhs + + def to_string(self, actions: List[ActionType]) -> str: + """Render the provided list of actions to a string. + + This method is used to produce a human-readable string to represent a + sequence of actions. Subclasses may override the default implementation + to provide custom rendering. + + This is the complement of :meth:`from_string() + `. The two methods + are bidirectional: + + >>> actions = env.actions + >>> s = env.action_space.to_string(actions) + >>> actions == env.action_space.from_string(s) + True + + :param actions: A list of actions drawn from this space. + + :return: A string representation that can be decoded using + :meth:`from_string() + `. + """ + if hasattr(self.wrapped, "to_string"): + return self.wrapped.to_string(actions) + + return ",".join(str(x) for x in actions) + + def from_string(self, string: str) -> List[ActionType]: + """Return a list of actions from the given string. + + This is the complement of :meth:`to_string() + `. The two methods are + bidirectional: + + >>> actions = env.actions + >>> s = env.action_space.to_string(actions) + >>> actions == env.action_space.from_string(s) + True + + :param string: A string. + + :return: A list of actions. + """ + if hasattr(self.wrapped, "from_string"): + return self.wrapped.from_string(string) + + return [self.dtype.type(x) for x in string.split(",")] + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.wrapped})" diff --git a/compiler_gym/spaces/commandline.py b/compiler_gym/spaces/commandline.py index 58a7c2f84..f89235a49 100644 --- a/compiler_gym/spaces/commandline.py +++ b/compiler_gym/spaces/commandline.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from typing import Iterable, List, NamedTuple, Union +from typing import Iterable, List, NamedTuple from compiler_gym.spaces.named_discrete import NamedDiscrete @@ -65,22 +65,20 @@ def __init__(self, items: Iterable[CommandlineFlag], name: str): def __repr__(self) -> str: return f"Commandline([{' '.join(self.flags)}])" - def commandline(self, values: Union[int, Iterable[int]]) -> str: + def to_string(self, values: List[int]) -> str: """Produce a commandline invocation from a sequence of values. :param values: A numeric value from the space, or sequence of values. :return: A string commandline invocation. """ - if isinstance(values, int): - return self.flags[values] - else: - return " ".join([self.flags[v] for v in values]) + return " ".join([self.flags[v] for v in values]) - def from_commandline(self, commandline: str) -> List[int]: + def from_string(self, commandline: str) -> List[int]: """Produce a sequence of actions from a commandline. :param commandline: A string commandline invocation, as produced by - :func:`commandline() `. + :func:`to_string() + `. :return: A list of action values. :raises LookupError: If any of the flags in the commandline are not recognized. diff --git a/compiler_gym/spaces/named_discrete.py b/compiler_gym/spaces/named_discrete.py index afe67d6b8..5af703fe0 100644 --- a/compiler_gym/spaces/named_discrete.py +++ b/compiler_gym/spaces/named_discrete.py @@ -65,18 +65,13 @@ def to_string(self, values: Union[int, Iterable[ActionType]]) -> str: else: return self.names[values] - def from_string( - self, values: Union[str, Iterable[str]] - ) -> Union[ActionType, List[ActionType]]: + def from_string(self, string: str) -> Union[ActionType, List[ActionType]]: """Convert a name, or list of names, to numeric values. :param values: A name, or list of names. :return: A numeric value, or list of numeric values. """ - if isinstance(values, str): - return self.names.index(values) - else: - return [self.names.index(v) for v in values] + return [self.names.index(v) for v in string.split(" ")] def __eq__(self, other) -> bool: return ( diff --git a/compiler_gym/spaces/reward.py b/compiler_gym/spaces/reward.py index 75c9d5dad..ac164d988 100644 --- a/compiler_gym/spaces/reward.py +++ b/compiler_gym/spaces/reward.py @@ -2,7 +2,6 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import warnings from typing import List, Optional, Tuple, Union import numpy as np @@ -39,10 +38,7 @@ class Reward(Scalar): def __init__( self, - # NOTE(github.com/facebookresearch/CompilerGym/issues/381): Once `id` - # argument has been removed, the default value for `name` can be - # removed. - name: str = None, + name: str, observation_spaces: Optional[List[str]] = None, default_value: RewardType = 0, min: Optional[RewardType] = None, @@ -51,10 +47,6 @@ def __init__( success_threshold: Optional[RewardType] = None, deterministic: bool = False, platform_dependent: bool = True, - # NOTE(github.com/facebookresearch/CompilerGym/issues/381): Backwards - # compatability workaround for deprecated parameter, will be removed in - # v0.2.4. - id: Optional[str] = None, ): """Constructor. @@ -83,10 +75,6 @@ def __init__( :param deterministic: Whether the reward space is deterministic. :param platform_dependent: Whether the reward values depend on the execution environment of the service. - :param id: The name of the reward space. - - .. deprecated:: 0.2.3 - Use :code:`name` instead. """ super().__init__( name=name, @@ -95,19 +83,7 @@ def __init__( dtype=np.float64, ) - # NOTE(github.com/facebookresearch/CompilerGym/issues/381): Backwards - # compatability workaround for deprecated parameter, will be removed in - # v0.2.4. - if id is not None: - warnings.warn( - "The `id` argument of " - "compiler_gym.spaces.Reward.__init__() " - "has been renamed `name`. This will break in a future release, " - "please update your code.", - DeprecationWarning, - ) self.name = name or id - self.id = self.name if not self.name: raise TypeError("No name given") diff --git a/compiler_gym/util/BUILD b/compiler_gym/util/BUILD index 23893e78a..7ae046be5 100644 --- a/compiler_gym/util/BUILD +++ b/compiler_gym/util/BUILD @@ -18,7 +18,6 @@ py_library( "filesystem.py", "gym_type_hints.py", "logging.py", - "logs.py", "minimize_trajectory.py", "parallelization.py", "permutation.py", diff --git a/compiler_gym/util/CMakeLists.txt b/compiler_gym/util/CMakeLists.txt index 892d752f5..68ec9c681 100644 --- a/compiler_gym/util/CMakeLists.txt +++ b/compiler_gym/util/CMakeLists.txt @@ -19,7 +19,6 @@ cg_py_library( "filesystem.py" "gym_type_hints.py" "logging.py" - "logs.py" "minimize_trajectory.py" "parallelization.py" "permutation.py" diff --git a/compiler_gym/util/logs.py b/compiler_gym/util/logs.py deleted file mode 100644 index 6dfbda7b8..000000000 --- a/compiler_gym/util/logs.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from pathlib import Path - -from deprecated.sphinx import deprecated - -from compiler_gym.util.runfiles_path import create_user_logs_dir - -# File names of generated logs. -METADATA_NAME = "random_search.json" -PROGRESS_LOG_NAME = "random_search_progress.csv" -BEST_ACTIONS_NAME = "random_search_best_actions.txt" -BEST_COMMANDLINE_NAME = "random_search_best_actions_commandline.txt" -# The name of the LLVM bitcode file generated by -BEST_ACTIONS_PROGRESS_NAME = "random_search_best_actions_progress.csv" - - -@deprecated( - version="0.2.1", - reason="Use compiler_gym.util.create_user_logs_dir() instead", -) -def create_logging_dir(name: str) -> Path: - """Deprecated function to create a directory for writing logs to. - - Use :code:`compiler_gym.util.runfiles_path.create_user_logs_dir()` instead. - """ - return create_user_logs_dir(name) diff --git a/compiler_gym/views/observation.py b/compiler_gym/views/observation.py index b3edfdd48..0338dc19e 100644 --- a/compiler_gym/views/observation.py +++ b/compiler_gym/views/observation.py @@ -4,8 +4,6 @@ # LICENSE file in the root directory of this source tree. from typing import Callable, Dict, List -from deprecated.sphinx import deprecated - from compiler_gym.errors import ServiceError from compiler_gym.service.proto import ObservationSpace from compiler_gym.util.gym_type_hints import ( @@ -95,54 +93,11 @@ def _add_space(self, space: ObservationSpaceSpec): # env.observation.FooBar(). setattr(self, space.id, lambda: self[space.id]) - @deprecated( - version="0.2.1", - reason=( - "Use the derived_observation_spaces argument to CompilerEnv constructor. " - "See ." - ), - ) def add_derived_space( self, id: str, base_id: str, **kwargs, - ) -> None: - """Alias to :func:`ObservationSpaceSpec.make_derived_space() - ` that adds - the derived space to the observation view. - - Example usage: - - >>> env.observation.add_derived_space( - id="src_len", - base_id="src", - translate=lambda src: np.array([len(src)], dtype=np.int32), - shape=Box(shape=(1,), dtype=np.int32), - ) - >>> env.observation["src_len"] - 1029 - - :param id: The name of the new observation space. - - :param base_id: The name of the observation space that this is derived - from. - - :param \\**kwargs: Arguments passed to - :func:`ObservationSpaceSpec.make_derived_space - `. - """ - base_space = self.spaces[base_id] - self._add_space(base_space.make_derived_space(id=id, **kwargs)) - - # NOTE(github.com/facebookresearch/CompilerGym/issues/461): This method will - # be renamed to add_derived_space() once the current method with that name - # is removed. - def add_derived_space_internal( - self, - id: str, - base_id: str, - **kwargs, ) -> None: """Internal API for adding a new observation space.""" base_space = self.spaces[base_id] diff --git a/compiler_gym/wrappers/commandline.py b/compiler_gym/wrappers/commandline.py index 976b339e5..ef920d407 100644 --- a/compiler_gym/wrappers/commandline.py +++ b/compiler_gym/wrappers/commandline.py @@ -36,7 +36,7 @@ def __init__( """ super().__init__(env) - if not isinstance(env.action_space, Commandline): + if not isinstance(env.action_space.wrapped, Commandline): raise TypeError( f"Unsupported action space: {type(env.action_space).__name__}" ) @@ -124,10 +124,10 @@ def __init__( if not flags: raise TypeError("No flags provided") - if not isinstance(env.action_space, Commandline): + if not issubclass(type(env.action_space.wrapped), Commandline): raise TypeError( "Can only wrap Commandline action space. " - f"Received: {type(env.action_space).__name__}" + f"Received: {type(env.action_space.wrapped).__name__}" ) self._forward_translation: List[int] = [self.action_space[f] for f in flags] diff --git a/compiler_gym/wrappers/core.py b/compiler_gym/wrappers/core.py index aebf41c75..b9323ea4d 100644 --- a/compiler_gym/wrappers/core.py +++ b/compiler_gym/wrappers/core.py @@ -7,6 +7,7 @@ from collections.abc import Iterable as IterableType from typing import Any, Iterable, List, Optional, Tuple, Union +from deprecated.sphinx import deprecated from gym import Wrapper from gym.spaces import Space @@ -260,9 +261,15 @@ def compiler_version(self) -> str: def state(self) -> CompilerEnvState: return self.env.state + @deprecated( + version="0.2.5", reason="Use env.action_space.to_string(env.actions) instead" + ) def commandline(self) -> str: return self.env.commandline() + @deprecated( + version="0.2.5", reason='Use env.action_space.from_string("...") instead' + ) def commandline_to_actions(self, commandline: str) -> List[ActionType]: return self.env.commandline_to_actions(commandline) diff --git a/docs/source/compiler_gym/service.rst b/docs/source/compiler_gym/service.rst index af8eab5c8..71749a860 100644 --- a/docs/source/compiler_gym/service.rst +++ b/docs/source/compiler_gym/service.rst @@ -40,6 +40,15 @@ ClientServiceCompilerEnv .. automethod:: __init__ +InProcessClientCompilerEnv +-------------------------- + +.. autoclass:: compiler_gym.service.client_service_compiler_env.InProcessClientCompilerEnv + :members: + + .. automethod:: __init__ + + The connection object --------------------- diff --git a/docs/source/compiler_gym/spaces.rst b/docs/source/compiler_gym/spaces.rst index f3ace14dd..e3953fb49 100644 --- a/docs/source/compiler_gym/spaces.rst +++ b/docs/source/compiler_gym/spaces.rst @@ -11,6 +11,13 @@ observation spaces available to compilers. .. currentmodule:: compiler_gym.spaces +ActionSpace +----------- + +.. autoclass:: ActionSpace + :members: + + Commandline ----------- diff --git a/examples/example_compiler_gym_service/env_tests.py b/examples/example_compiler_gym_service/env_tests.py index 0d1448cbb..6befc9ca0 100644 --- a/examples/example_compiler_gym_service/env_tests.py +++ b/examples/example_compiler_gym_service/env_tests.py @@ -16,7 +16,7 @@ import examples.example_compiler_gym_service as example from compiler_gym.envs import CompilerEnv from compiler_gym.errors import SessionNotFound -from compiler_gym.spaces import Box, NamedDiscrete, Scalar, Sequence +from compiler_gym.spaces import ActionSpace, Box, NamedDiscrete, Scalar, Sequence from compiler_gym.util.commands import Popen from tests.test_main import main @@ -72,9 +72,11 @@ def test_versions(env: CompilerEnv): def test_action_space(env: CompilerEnv): """Test that the environment reports the service's action spaces.""" assert env.action_spaces == [ - NamedDiscrete( - name="default", - items=["a", "b", "c"], + ActionSpace( + NamedDiscrete( + name="default", + items=["a", "b", "c"], + ) ) ] diff --git a/examples/llvm_autotuning/autotuners/__init__.py b/examples/llvm_autotuning/autotuners/__init__.py index 6eab0f9b9..703399b26 100644 --- a/examples/llvm_autotuning/autotuners/__init__.py +++ b/examples/llvm_autotuning/autotuners/__init__.py @@ -97,7 +97,7 @@ def __call__(self, env: CompilerEnv, seed: int = 0xCC) -> CompilerEnvState: return CompilerEnvState( benchmark=env.benchmark.uri, - commandline=env.commandline(), + commandline=env.action_space.to_string(env.actions), walltime=timer.time, reward=self.optimization_target.final_reward(env), ) diff --git a/examples/llvm_autotuning/autotuners/greedy_test.py b/examples/llvm_autotuning/autotuners/greedy_test.py index bb462ff22..d5045ff2b 100644 --- a/examples/llvm_autotuning/autotuners/greedy_test.py +++ b/examples/llvm_autotuning/autotuners/greedy_test.py @@ -24,7 +24,7 @@ def test_autotune(): print(result) assert result.benchmark == "benchmark://cbench-v1/crc32" assert result.walltime >= 3 - assert result.commandline == env.commandline() + assert result.commandline == env.action_space.to_string(env.actions) assert env.episode_reward assert env.benchmark == "benchmark://cbench-v1/crc32" assert env.reward_space == "IrInstructionCount" diff --git a/examples/llvm_autotuning/autotuners/nevergrad_test.py b/examples/llvm_autotuning/autotuners/nevergrad_test.py index 485e781aa..a3d9ed453 100644 --- a/examples/llvm_autotuning/autotuners/nevergrad_test.py +++ b/examples/llvm_autotuning/autotuners/nevergrad_test.py @@ -23,7 +23,7 @@ def test_autotune(): print(result) assert result.benchmark == "benchmark://cbench-v1/crc32" assert result.walltime >= 3 - assert result.commandline == env.commandline() + assert result.commandline == env.action_space.to_string(env.actions) assert env.episode_reward >= 0 assert env.benchmark == "benchmark://cbench-v1/crc32" assert env.reward_space == "IrInstructionCount" diff --git a/examples/llvm_autotuning/autotuners/opentuner_test.py b/examples/llvm_autotuning/autotuners/opentuner_test.py index 28bda8456..675956d91 100644 --- a/examples/llvm_autotuning/autotuners/opentuner_test.py +++ b/examples/llvm_autotuning/autotuners/opentuner_test.py @@ -24,7 +24,7 @@ def test_autotune(): print(result) assert result.benchmark == "benchmark://cbench-v1/crc32" assert result.walltime >= 3 - assert result.commandline == env.commandline() + assert result.commandline == env.action_space.to_string(env.actions) assert env.episode_reward >= 0 assert env.benchmark == "benchmark://cbench-v1/crc32" assert env.reward_space == "IrInstructionCount" diff --git a/examples/llvm_autotuning/autotuners/random_test.py b/examples/llvm_autotuning/autotuners/random_test.py index 82f85bf46..9f92fdf68 100644 --- a/examples/llvm_autotuning/autotuners/random_test.py +++ b/examples/llvm_autotuning/autotuners/random_test.py @@ -22,7 +22,7 @@ def test_autotune(): print(result) assert result.benchmark == "benchmark://cbench-v1/crc32" assert result.walltime >= 3 - assert result.commandline == env.commandline() + assert result.commandline == env.action_space.to_string(env.actions) assert env.episode_reward >= 0 assert env.benchmark == "benchmark://cbench-v1/crc32" assert env.reward_space == "IrInstructionCount" diff --git a/examples/llvm_rl/model/inference_result.py b/examples/llvm_rl/model/inference_result.py index 34572f8d7..f53ddae41 100644 --- a/examples/llvm_rl/model/inference_result.py +++ b/examples/llvm_rl/model/inference_result.py @@ -106,7 +106,7 @@ def from_agent( return cls( benchmark=env.benchmark.uri, inference_walltime_seconds=inference_timer.time, - commandline=env.commandline(), + commandline=env.action_space.to_string(env.actions), episode_len=len(env.actions), instruction_count_init=instruction_count_init, instruction_count_final=instruction_count_final, diff --git a/examples/loop_optimizations_service/env_tests.py b/examples/loop_optimizations_service/env_tests.py index 6e20025ad..78e396af9 100644 --- a/examples/loop_optimizations_service/env_tests.py +++ b/examples/loop_optimizations_service/env_tests.py @@ -15,7 +15,7 @@ from compiler_gym.envs import CompilerEnv from compiler_gym.errors import SessionNotFound from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv -from compiler_gym.spaces import Dict, NamedDiscrete, Scalar, Sequence +from compiler_gym.spaces import ActionSpace, Dict, NamedDiscrete, Scalar, Sequence from compiler_gym.third_party.autophase import AUTOPHASE_FEATURE_NAMES from tests.test_main import main @@ -64,20 +64,22 @@ def test_versions(env: ClientServiceCompilerEnv): def test_action_space(env: CompilerEnv): """Test that the environment reports the service's action spaces.""" assert env.action_spaces == [ - NamedDiscrete( - name="loop-opt", - items=[ - "--loop-unroll --unroll-count=2", - "--loop-unroll --unroll-count=4", - "--loop-unroll --unroll-count=8", - "--loop-unroll --unroll-count=16", - "--loop-unroll --unroll-count=32", - "--loop-vectorize -force-vector-width=2", - "--loop-vectorize -force-vector-width=4", - "--loop-vectorize -force-vector-width=8", - "--loop-vectorize -force-vector-width=16", - "--loop-vectorize -force-vector-width=32", - ], + ActionSpace( + NamedDiscrete( + name="loop-opt", + items=[ + "--loop-unroll --unroll-count=2", + "--loop-unroll --unroll-count=4", + "--loop-unroll --unroll-count=8", + "--loop-unroll --unroll-count=16", + "--loop-unroll --unroll-count=32", + "--loop-vectorize -force-vector-width=2", + "--loop-vectorize -force-vector-width=4", + "--loop-vectorize -force-vector-width=8", + "--loop-vectorize -force-vector-width=16", + "--loop-vectorize -force-vector-width=32", + ], + ) ) ] diff --git a/tests/bin/manual_env_bin_test.py b/tests/bin/manual_env_bin_test.py index cca349e29..791fd0fa0 100644 --- a/tests/bin/manual_env_bin_test.py +++ b/tests/bin/manual_env_bin_test.py @@ -347,7 +347,7 @@ def test_greedy(): ) -def test_commandline(): +def test_actions_string(): FLAGS.unparse_flags() io_check( """set_benchmark cbench-v1/adpcm diff --git a/tests/compiler_env_test.py b/tests/compiler_env_test.py index 0049dbaa9..8f5a969bf 100644 --- a/tests/compiler_env_test.py +++ b/tests/compiler_env_test.py @@ -74,13 +74,6 @@ def test_observation_space_set_in_reset(env: LlvmEnv, observation_space: str): assert env.observation_space_spec == observation_space -def test_logger_is_deprecated(env: LlvmEnv): - with pytest.deprecated_call( - match="The `ClientServiceCompilerEnv.logger` attribute is deprecated" - ): - env.logger - - def test_uri_substring_no_match(env: LlvmEnv): env.reset(benchmark="benchmark://cbench-v1/crc32") assert env.benchmark == "benchmark://cbench-v1/crc32" diff --git a/tests/fuzzing/llvm_commandline_opt_equivalence_fuzz_test.py b/tests/fuzzing/llvm_commandline_opt_equivalence_fuzz_test.py index b8e633375..0bb28bbc1 100644 --- a/tests/fuzzing/llvm_commandline_opt_equivalence_fuzz_test.py +++ b/tests/fuzzing/llvm_commandline_opt_equivalence_fuzz_test.py @@ -2,7 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -"""Fuzz test for LlvmEnv.commandline().""" +"""Fuzz test for LlvmEnv.action_space.to_string().""" import os import subprocess from pathlib import Path @@ -25,7 +25,7 @@ @pytest.mark.timeout(600) def test_fuzz(env: LlvmEnv, tmpwd: Path, llvm_opt: Path, llvm_diff: Path): - """This test produces a random trajectory and then uses the commandline() + """This test produces a random trajectory and then uses the commandline generated with opt to check that the states are equivalent. """ del tmpwd @@ -39,7 +39,7 @@ def test_fuzz(env: LlvmEnv, tmpwd: Path, llvm_opt: Path, llvm_diff: Path): apply_random_trajectory( env, random_trajectory_length_range=RANDOM_TRAJECTORY_LENGTH_RANGE, timeout=30 ) - commandline = env.commandline(textformat=True) + commandline = env.action_space.to_string(env.actions) print(env.state) # For debugging in case of failure. # Write the post-trajectory state to file. diff --git a/tests/gcc/gcc_env_test.py b/tests/gcc/gcc_env_test.py index 7e6747a40..683491ca1 100644 --- a/tests/gcc/gcc_env_test.py +++ b/tests/gcc/gcc_env_test.py @@ -319,11 +319,14 @@ def test_choices_observation(): @with_docker -def test_commandline(): +def test_action_space_string(): """Test observation spaces.""" with gym.make("gcc-v0") as env: env.reset() - assert env.commandline() == "docker:gcc:11.2.0 -w -c src.c -o obj.o" + assert ( + env.action_space.to_string(env.actions) + == "docker:gcc:11.2.0 -w -c src.c -o obj.o" + ) @with_docker @@ -341,9 +344,11 @@ def test_set_choices(): with gym.make("gcc-v0") as env: env.reset() env.choices = [-1] * len(env.gcc_spec.options) - assert env.commandline().startswith("docker:gcc:11.2.0 -w -c src.c -o obj.o") + assert env.action_space.to_string(env.actions).startswith( + "docker:gcc:11.2.0 -w -c src.c -o obj.o" + ) env.choices = [0] * len(env.gcc_spec.options) - assert env.commandline().startswith( + assert env.action_space.to_string(env.actions).startswith( "docker:gcc:11.2.0 -O0 -faggressive-loop-optimizations -falign-functions -falign-jumps -falign-labels" ) diff --git a/tests/llvm/action_space_test.py b/tests/llvm/action_space_test.py index 669958052..827929eed 100644 --- a/tests/llvm/action_space_test.py +++ b/tests/llvm/action_space_test.py @@ -9,18 +9,21 @@ pytest_plugins = ["tests.pytest_plugins.llvm"] -def test_commandline_no_actions(env: LlvmEnv): +def test_to_and_from_string_no_actions(env: LlvmEnv): env.reset(benchmark="cbench-v1/crc32") - assert env.commandline() == "opt input.bc -o output.bc" - assert env.commandline_to_actions(env.commandline()) == [] + assert env.action_space.to_string(env.actions) == "opt input.bc -o output.bc" + assert env.action_space.from_string(env.action_space.to_string(env.actions)) == [] -def test_commandline(env: LlvmEnv): +def test_to_and_from_string(env: LlvmEnv): env.reset(benchmark="cbench-v1/crc32") env.step(env.action_space.flags.index("-mem2reg")) env.step(env.action_space.flags.index("-reg2mem")) - assert env.commandline() == "opt -mem2reg -reg2mem input.bc -o output.bc" - assert env.commandline_to_actions(env.commandline()) == [ + assert ( + env.action_space.to_string(env.actions) + == "opt -mem2reg -reg2mem input.bc -o output.bc" + ) + assert env.action_space.from_string(env.action_space.to_string(env.actions)) == [ env.action_space.flags.index("-mem2reg"), env.action_space.flags.index("-reg2mem"), ] diff --git a/tests/llvm/fuzzing_regression_test.py b/tests/llvm/fuzzing_regression_test.py index 8a628bf8c..569fa40d3 100644 --- a/tests/llvm/fuzzing_regression_test.py +++ b/tests/llvm/fuzzing_regression_test.py @@ -22,7 +22,7 @@ def test_regression_test_const_offset_from_gep(env, tmpwd, llvm_diff, llvm_opt): env.write_ir("input.ll") # FIXME: Removing the -separate-const-offset-from-gep actions from the below # commandline "fixes" the test. - actions = env.commandline_to_actions( + actions = env.action_space.from_string( "opt -objc-arc-apelim -separate-const-offset-from-gep -sancov -indvars -loop-reduce -dse -inferattrs -loop-fusion -dce -break-crit-edges -constmerge -indvars -mem2reg -objc-arc-expand -ee-instrument -loop-reroll -break-crit-edges -separate-const-offset-from-gep -loop-idiom -float2int -dce -float2int -ipconstprop -simple-loop-unswitch -coro-cleanup -early-cse-memssa -strip -functionattrs -objc-arc-contract -sink -loop-distribute -loop-reroll -slsr -separate-const-offset-from-gep input.bc -o output.bc" ) @@ -32,7 +32,7 @@ def test_regression_test_const_offset_from_gep(env, tmpwd, llvm_diff, llvm_opt): env.write_ir("env.ll") subprocess.check_call( - env.commandline(textformat=True), + env.action_space.to_string(env.actions) + " -S -o output.ll", env={"PATH": str(llvm_opt.parent)}, shell=True, timeout=60, diff --git a/tests/spaces/BUILD b/tests/spaces/BUILD index 99465c6cd..83ed4dfb3 100644 --- a/tests/spaces/BUILD +++ b/tests/spaces/BUILD @@ -4,6 +4,16 @@ # LICENSE file in the root directory of this source tree. load("@rules_python//python:defs.bzl", "py_test") +py_test( + name = "action_space_test", + timeout = "short", + srcs = ["action_space_test.py"], + deps = [ + "//compiler_gym/spaces", + "//tests:test_main", + ], +) + py_test( name = "box_test", timeout = "short", diff --git a/tests/spaces/CMakeLists.txt b/tests/spaces/CMakeLists.txt index b27a6f167..a28c7d3d3 100644 --- a/tests/spaces/CMakeLists.txt +++ b/tests/spaces/CMakeLists.txt @@ -5,6 +5,16 @@ cg_add_all_subdirs() +cg_py_test( + NAME + action_space_test + SRCS + "action_space_test.py" + DEPS + compiler_gym::spaces::spaces + tests::test_main +) + cg_py_test( NAME box_test diff --git a/tests/spaces/action_space_test.py b/tests/spaces/action_space_test.py new file mode 100644 index 000000000..7bda9b717 --- /dev/null +++ b/tests/spaces/action_space_test.py @@ -0,0 +1,81 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Unit tests for compiler_gym/spaces/action_space.py.""" +from compiler_gym.spaces import ActionSpace, Discrete, NamedDiscrete +from tests.test_main import main + + +class MockActionSpace: + name = "mock" + foo = 1 + + def sample(self): + return 1 + + def seed(self, s): + pass + + def contains(self, x): + pass + + def __repr__(self) -> str: + return self.name + + +def test_action_space_forward(mocker): + a = MockActionSpace() + ma = ActionSpace(a) + + assert ma.name == "mock" + assert ma.foo == 1 + + mocker.spy(a, "sample") + assert ma.sample() == 1 + assert a.sample.call_count == 1 + + mocker.spy(a, "seed") + ma.seed(10) + assert a.seed.call_count == 1 + + mocker.spy(a, "contains") + 10 in ma + assert a.contains.call_count == 1 + + +def test_action_space_comparison(): + a = MockActionSpace() + b = ActionSpace(a) + c = MockActionSpace() + + assert b == a + assert b.wrapped == a + assert b != c + + +def test_action_space_default_string_conversion(): + """Test that to_string() and from_string() are forward to subclasses.""" + a = Discrete(name="a", n=3) + ma = ActionSpace(a) + + assert ma.to_string([0, 1, 0]) == "0,1,0" + assert ma.from_string("0,1,0") == [0, 1, 0] + + +def test_action_space_forward_string_conversion(): + """Test that to_string() and from_string() are forward to subclasses.""" + a = NamedDiscrete(name="a", items=["a", "b", "c"]) + ma = ActionSpace(a) + + assert ma.to_string([0, 1, 2, 0]) == "a b c a" + assert ma.from_string("a b c a") == [0, 1, 2, 0] + + +def test_action_space_str(): + ma = ActionSpace(MockActionSpace()) + assert str(ma) == "ActionSpace(mock)" + + +if __name__ == "__main__": + main() diff --git a/tests/spaces/commandline_test.py b/tests/spaces/commandline_test.py index 7d0801720..4591e131d 100644 --- a/tests/spaces/commandline_test.py +++ b/tests/spaces/commandline_test.py @@ -36,7 +36,7 @@ def test_contains(): assert not space.contains(4) -def test_commandline(): +def test_to_and_from_string(): space = Commandline( [ CommandlineFlag(name="a", flag="-a", description=""), @@ -46,8 +46,8 @@ def test_commandline(): name="test", ) - assert space.commandline([0, 1, 2]) == "-a -b -c" - assert space.from_commandline(space.commandline([0, 1, 2])) == [0, 1, 2] + assert space.to_string([0, 1, 2]) == "-a -b -c" + assert space.from_string(space.to_string([0, 1, 2])) == [0, 1, 2] if __name__ == "__main__": diff --git a/www/www.py b/www/www.py index d080c07e4..2fcd97e5d 100644 --- a/www/www.py +++ b/www/www.py @@ -281,7 +281,7 @@ def _step(request: StepRequest) -> StepReply: ) ) return StepReply( - commandline=env.commandline(), + commandline=env.action_space.to_string(env.actions), done=done, ir=truncate(ir, max_line_len=250, max_lines=1024), states=states,