From 7a35872a6677da7b8e00fb3eb331a40c2c2baee8 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Sun, 20 Feb 2022 17:22:11 +0000 Subject: [PATCH 01/41] [service] Make ConnectionOpts hashable. --- compiler_gym/bin/service.py | 4 +++- compiler_gym/service/connection.py | 13 +++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/compiler_gym/bin/service.py b/compiler_gym/bin/service.py index f49e9f7e6..9bdfd1c97 100644 --- a/compiler_gym/bin/service.py +++ b/compiler_gym/bin/service.py @@ -266,7 +266,9 @@ def main(argv): if FLAGS.run_on_port: assert FLAGS.env, "Must specify an --env to run" - settings = ConnectionOpts(script_args=["--port", str(FLAGS.run_on_port)]) + settings = ConnectionOpts( + script_args=frozenset(["--port", str(FLAGS.run_on_port)]) + ) with gym.make(FLAGS.env, connection_settings=settings) as env: print( f"=== Started a service on port {FLAGS.run_on_port}. Use C-c to terminate. ===" diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index e8fd9778d..22b22cb0e 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -10,11 +10,12 @@ from pathlib import Path from signal import Signals from time import sleep, time -from typing import Dict, Iterable, List, Optional, TypeVar, Union +from typing import Dict, FrozenSet, Iterable, List, NamedTuple, Optional, TypeVar, Union import grpc from deprecated.sphinx import deprecated from pydantic import BaseModel +from frozendict import frozendict import compiler_gym.errors from compiler_gym.service.proto import ( @@ -49,7 +50,7 @@ logger = logging.getLogger(__name__) -class ConnectionOpts(BaseModel): +class ConnectionOpts(NamedTuple): """The options used to configure a connection to a service.""" rpc_call_max_seconds: float = 300 @@ -96,11 +97,11 @@ class ConnectionOpts(BaseModel): benchmark. In case of benchmark re-use, leave this :code:`False`. """ - script_args: List[str] = [] + script_args: FrozenSet[str] = frozenset([]) """If the service is started from a local script, this set of args is used on the command line. No effect when used for existing sockets.""" - script_env: Dict[str, str] = {} + script_env: Dict[str, str] = frozendict({}) """If the service is started from a local script, this set of env vars is used on the command line. No effect when used for existing sockets.""" @@ -301,7 +302,7 @@ def __init__( port_init_max_seconds: float, rpc_init_max_seconds: float, process_exit_max_seconds: float, - script_args: List[str], + script_args: FrozenSet[str], script_env: Dict[str, str], ): """Constructor. @@ -323,7 +324,7 @@ def __init__( f"--working_dir={self.cache.path}", ] # Add any custom arguments - cmd += script_args + cmd += list(script_args) # Set the root of the runfiles directory. env = os.environ.copy() From dda3b61c94eb6a0ce011816bb79b09587b568bc4 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Sun, 20 Feb 2022 17:24:08 +0000 Subject: [PATCH 02/41] [datasets] Add a note about installing LLVM runtime data. --- compiler_gym/envs/llvm/datasets/cbench.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compiler_gym/envs/llvm/datasets/cbench.py b/compiler_gym/envs/llvm/datasets/cbench.py index 9cd97c2df..fe979f1e5 100644 --- a/compiler_gym/envs/llvm/datasets/cbench.py +++ b/compiler_gym/envs/llvm/datasets/cbench.py @@ -234,6 +234,10 @@ def download_cBench_runtime_data() -> bool: if (cbench_data / "unpacked").is_file(): return False else: + logger.warning( + "Installing the cBench runtime inputs. This may take a few moments ..." + ) + # Clean up any partially-extracted data directory. if cbench_data.is_dir(): shutil.rmtree(cbench_data) From 07973f3a4381e8740f9931fc12c0cf6ddfce8fa0 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Sun, 20 Feb 2022 18:31:21 +0000 Subject: [PATCH 03/41] [tests] Add a new corner case test for forked environments. --- tests/llvm/fork_env_test.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/llvm/fork_env_test.py b/tests/llvm/fork_env_test.py index 55779a945..5b17de3ad 100644 --- a/tests/llvm/fork_env_test.py +++ b/tests/llvm/fork_env_test.py @@ -9,6 +9,7 @@ import pytest from compiler_gym.envs import LlvmEnv +from compiler_gym.service import ServiceError from compiler_gym.util.runfiles_path import runfiles_path from tests.test_main import main @@ -250,5 +251,22 @@ def test_fork_previous_cost_lazy_reward_update(env: LlvmEnv): assert env.reward["IrInstructionCount"] == fkd.reward["IrInstructionCount"] +def test_forked_service_dies(env: LlvmEnv): + """Test that if the service dies on a forked environment, each environment + creates new, independent services. + """ + with env.fork() as fkd: + assert env.service == fkd.service + try: + fkd.service.shutdown() + except ServiceError: + pass # shutdown() raises service error if in-episode. + fkd.service.close() + + env.reset() + fkd.reset() + assert env.service != fkd.service + + if __name__ == "__main__": main() From dcd7708674cd5b3e64e30ec8f19aec12722cbf34 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Sat, 19 Feb 2022 16:53:50 +0000 Subject: [PATCH 04/41] [service] Add a reusable pool of service connections. --- compiler_gym/envs/gcc/gcc_env.py | 13 +- compiler_gym/requirements.txt | 1 + compiler_gym/service/BUILD | 10 + compiler_gym/service/__init__.py | 4 +- .../service/client_service_compiler_env.py | 97 ++++----- compiler_gym/service/connection.py | 77 ++++++-- compiler_gym/service/connection_pool.py | 179 +++++++++++++++++ docs/source/compiler_gym/service.rst | 8 + tests/llvm/BUILD | 2 +- tests/llvm/CMakeLists.txt | 2 +- tests/llvm/custom_benchmarks_test.py | 2 +- tests/llvm/fork_env_test.py | 125 ++++++------ tests/llvm/service_connection_test.py | 7 +- tests/pytest_plugins/llvm.py | 2 +- tests/service/BUILD | 11 ++ tests/service/CMakeLists.txt | 12 ++ tests/service/connection_pool_test.py | 184 ++++++++++++++++++ tests/service/connection_test.py | 9 +- 18 files changed, 617 insertions(+), 128 deletions(-) create mode 100644 compiler_gym/service/connection_pool.py create mode 100644 tests/service/connection_pool_test.py diff --git a/compiler_gym/envs/gcc/gcc_env.py b/compiler_gym/envs/gcc/gcc_env.py index fa77837a1..52049a728 100644 --- a/compiler_gym/envs/gcc/gcc_env.py +++ b/compiler_gym/envs/gcc/gcc_env.py @@ -9,6 +9,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union +from frozendict import frozendict + from compiler_gym.datasets import Benchmark from compiler_gym.envs.gcc.datasets import get_gcc_datasets from compiler_gym.envs.gcc.gcc import Gcc, GccSpec @@ -63,9 +65,13 @@ def __init__( :raises ServiceInitError: If the requested GCC version cannot be used. """ - connection_settings = connection_settings or ConnectionOpts() # Pass the executable path via an environment variable - connection_settings.script_env = {"CC": gcc_bin} + if connection_settings is None: + connection_settings = ConnectionOpts(script_env=frozendict({"CC": gcc_bin})) + else: + connection_settings = ConnectionOpts( + script_env=frozendict({"CC": gcc_bin}, **connection_settings._asdict()) + ) # Eagerly create a GCC compiler instance now because: # @@ -88,6 +94,9 @@ def __init__( ) self._timeout = timeout + def commandline_to_actions(self, commandline: str) -> List[int]: + return NotImplementedError + def reset( self, benchmark: Optional[Union[str, Benchmark]] = None, diff --git a/compiler_gym/requirements.txt b/compiler_gym/requirements.txt index b2f06105f..45d51cfb4 100644 --- a/compiler_gym/requirements.txt +++ b/compiler_gym/requirements.txt @@ -2,6 +2,7 @@ absl-py>=0.10.0 deprecated>=1.2.12 docker>=4.0.0 fasteners>=0.15 +frozendict>=1.0.0 grpcio>=1.32.0,<1.44.0 gym>=0.18.0,<0.21 humanize>=2.6.0 diff --git a/compiler_gym/service/BUILD b/compiler_gym/service/BUILD index 194882571..958624d44 100644 --- a/compiler_gym/service/BUILD +++ b/compiler_gym/service/BUILD @@ -14,6 +14,7 @@ py_library( deps = [ ":compilation_session", ":connection", + ":connection_pool", # TODO(github.com/facebookresearch/CompilerGym/pull/633): # add this after circular dependencies are resolved # ":client_service_compiler_env", @@ -76,6 +77,15 @@ py_library( ], ) +py_library( + name = "connection_pool", + srcs = ["connection_pool.py"], + visibility = ["//visibility:public"], + deps = [ + ":connection", + ] +) + py_library( name = "service_cache", srcs = ["service_cache.py"], diff --git a/compiler_gym/service/__init__.py b/compiler_gym/service/__init__.py index ddd10fc14..98074b0d1 100644 --- a/compiler_gym/service/__init__.py +++ b/compiler_gym/service/__init__.py @@ -14,12 +14,14 @@ ServiceTransportError, SessionNotFound, ) +from compiler_gym.service.connection_pool import ServiceConnectionPool __all__ = [ - "CompilerGymServiceConnection", "CompilationSession", + "CompilerGymServiceConnection", "ConnectionOpts", "EnvironmentNotSupported", + "ServiceConnectionPool", "ServiceError", "ServiceInitError", "ServiceIsClosed", diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index d46ea0527..23bacd946 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -14,6 +14,7 @@ from time import time from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from compiler_gym.service.connection_pool import ServiceConnectionPool import numpy as np from deprecated.sphinx import deprecated from gym.spaces import Space @@ -135,9 +136,9 @@ def __init__( reward_space: Optional[Union[str, Reward]] = None, action_space: Optional[str] = None, derived_observation_spaces: Optional[List[Dict[str, Any]]] = None, - service_message_converters: ServiceMessageConverters = None, connection_settings: Optional[ConnectionOpts] = None, service_connection: Optional[CompilerGymServiceConnection] = None, + service_pool: Optional[ServiceConnectionPool] = None, logger: Optional[logging.Logger] = None, ): """Construct and initialize a CompilerGym environment. @@ -167,7 +168,7 @@ def __init__( `. If not provided, :func:`step()` returns :code:`None` for the observation value. Can be set later using :meth:`env.observation_space - `. For available + `. For available spaces, see :class:`env.observation.spaces `. @@ -176,7 +177,7 @@ def __init__( `. If not provided, :func:`step()` returns :code:`None` for the reward value. Can be set later using :meth:`env.reward_space - `. For available spaces, + `. For available spaces, see :class:`env.reward.spaces `. :param action_space: The name of the action space to use. If not @@ -186,14 +187,15 @@ def __init__( passed to :meth:`env.observation.add_derived_space() `. - :param service_message_converters: Custom converters for action spaces and actions. - :param connection_settings: The settings used to establish a connection with the remote service. :param service_connection: An existing compiler gym service connection to use. + :param service_pool: A service pool to use for acquiring a service + connection. If not specified, the global service pool is used. + :raises FileNotFoundError: If service is a path to a file that is not found. @@ -204,9 +206,9 @@ def __init__( # in release 0.2.3. if logger: warnings.warn( - "The `logger` argument is deprecated on ClientServiceCompilerEnv.__init__() " - "and will be removed in a future release. All ClientServiceCompilerEnv " - "instances share a logger named compiler_gym.service.client_service_compiler_env", + "The `logger` argument is deprecated on CompilerEnv.__init__() " + "and will be removed in a future release. All CompilerEnv " + "instances share a logger named compiler_gym.envs.compiler_env", DeprecationWarning, ) @@ -219,11 +221,19 @@ def __init__( self._service_endpoint: Union[str, Path] = service self._connection_settings = connection_settings or ConnectionOpts() - self.service = service_connection or CompilerGymServiceConnection( - endpoint=self._service_endpoint, - opts=self._connection_settings, - ) - self._datasets = Datasets(datasets or []) + if service_connection is None: + self._service_pool = ( + ServiceConnectionPool.get() if service_pool is None else service_pool + ) + self.service = self._service_pool.acquire( + endpoint=self._service_endpoint, + opts=self._connection_settings, + ) + else: + self._service_pool = service_pool + self.service = service_connection + + self.datasets = Datasets(datasets or []) self.action_space_name = action_space @@ -266,21 +276,14 @@ def __init__( self._benchmark_in_use = self._next_benchmark except StopIteration: # StopIteration raised on next(self.datasets.benchmarks()) if there - # are no benchmarks available. This is to allow ClientServiceCompilerEnv to be + # are no benchmarks available. This is to allow CompilerEnv to be # used without any datasets by setting a benchmark before/during the # first reset() call. pass - self.service_message_converters = ( - ServiceMessageConverters() - if service_message_converters is None - else service_message_converters - ) - # Process the available action, observation, and reward spaces. self.action_spaces = [ - self.service_message_converters.action_space_converter(space) - for space in self.service.action_spaces + proto_to_action_space(space) for space in self.service.action_spaces ] self.observation = self._observation_view_type( @@ -302,13 +305,13 @@ def __init__( # Mutable state initialized in reset(). self._reward_range: Tuple[float, float] = (-np.inf, np.inf) - self.episode_reward = None + self.episode_reward: Optional[float] = None self.episode_start_time: float = time() - self._actions: List[ActionType] = [] + self.actions: List[ActionType] = [] # Initialize the default observation/reward spaces. - self.observation_space_spec = None - self.reward_space_spec = None + self.observation_space_spec: Optional[ObservationSpaceSpec] = None + self.reward_space_spec: Optional[Reward] = None self.observation_space = observation_space self.reward_space = reward_space @@ -545,7 +548,7 @@ def _init_kwargs(self) -> Dict[str, Any]: } def fork(self) -> "ClientServiceCompilerEnv": - if not self.in_episode: + if not self.in_episode: actions = self.actions.copy() self.reset() if actions: @@ -601,7 +604,7 @@ def fork(self) -> "ClientServiceCompilerEnv": # Copy over the mutable episode state. new_env.episode_reward = self.episode_reward new_env.episode_start_time = self.episode_start_time - new_env._actions = self.actions.copy() + new_env.actions = self.actions.copy() return new_env @@ -687,7 +690,7 @@ def _retry(error) -> Optional[ObservationType]: ) log_severity("%s during reset(): %s", type(error).__name__, error) - if self.service: + if self.service is not None: try: self.service.close() except ServiceError as e: @@ -699,6 +702,7 @@ def _retry(error) -> Optional[ObservationType]: e, type(e).__name__, ) + self.service = None if retry_count >= self._connection_settings.init_max_attempts: @@ -734,8 +738,15 @@ def _call_with_error( # Start a new service if required. if self.service is None: - self.service = CompilerGymServiceConnection( - self._service_endpoint, self._connection_settings + self.service = ( + CompilerGymServiceConnection( + self._service_endpoint, self._connection_settings + ) + if self._service_pool is None + else self._service_pool.acquire( + endpoint=self._service_endpoint, + opts=self._connection_settings, + ) ) self.action_space_name = action_space or self.action_space_name @@ -810,13 +821,11 @@ def _call_with_error( self.observation.session_id = reply.session_id self.reward.get_cost = self.observation.__getitem__ self.episode_start_time = time() - self._actions = [] + self.actions = [] # If the action space has changed, update it. if reply.HasField("new_action_space"): - self.action_space = self.service_message_converters.action_space_converter( - reply.new_action_space - ) + self.action_space = proto_to_action_space(reply.new_action_space) self.reward.reset(benchmark=self.benchmark, observation_view=self.observation) if self.reward_space: @@ -852,14 +861,14 @@ def raw_step( and rewards are lists. :raises SessionNotFound: If :meth:`reset() - ` has not been called. + ` has not been called. .. warning:: Don't call this method directly, use :meth:`step() - ` or :meth:`multistep() - ` instead. The - :meth:`raw_step() ` method is an + ` or :meth:`multistep() + ` instead. The + :meth:`raw_step() ` method is an implementation detail. """ if not self.in_episode: @@ -880,14 +889,12 @@ def raw_step( } # Record the actions. - self._actions += actions + self.actions += actions # Send the request to the backend service. request = StepRequest( session_id=self._session_id, - action=[ - self.service_message_converters.action_converter(a) for a in actions - ], + action=[Event(int64_value=a) for a in actions], observation_space=[ observation_space.index for observation_space in observations_to_compute ], @@ -931,9 +938,7 @@ def raw_step( # If the action space has changed, update it. if reply.HasField("new_action_space"): - self.action_space = self.service_message_converters.action_space_converter( - reply.new_action_space - ) + self.action_space = proto_to_action_space(reply.new_action_space) # Translate observations to python representations. if len(reply.observation) != len(observations_to_compute): diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index 22b22cb0e..fd08f4024 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -533,10 +533,8 @@ def close(self): def __repr__(self): if self.process.poll() is None: - return ( - f"Connection to service at {self.url} running on PID {self.process.pid}" - ) - return f"Connection to dead service at {self.url}" + return f"ManagedConnection({self.url}, pid={self.process.pid})" + return f"ManagedConnection({self.url}, not running)" class UnmanagedConnection(Connection): @@ -576,7 +574,7 @@ def __init__(self, url: str, rpc_init_max_seconds: float): super().__init__(channel, url) def __repr__(self): - return f"Connection to unmanaged service {self.url}" + return f"UnmanagedConnection({self.url})" class CompilerGymServiceConnection: @@ -631,20 +629,33 @@ class CompilerGymServiceConnection: def __init__( self, endpoint: Union[str, Path], - opts: ConnectionOpts = None, + opts: ConnectionOpts, + owning_service_pool: Optional["ServiceConnectionPool"] = None, # noqa: F821 ): """Constructor. :param endpoint: The connection endpoint. Either the URL of a service, e.g. "localhost:8080", or the path of a local service binary. + :param opts: The connection options. + + :param owning_service_pool: A backref to the owning + :class:`ServiceConnectionPool + `, if this service is + managed by one. + :raises ValueError: If the provided options are invalid. - :raises FileNotFoundError: In case opts.local_service_binary is not found. + + :raises FileNotFoundError: In case opts.local_service_binary is not + found. + :raises TimeoutError: In case the service failed to start within opts.init_max_seconds seconds. """ + self.released = False self.endpoint = endpoint self.opts = opts or ConnectionOpts() + self.owning_service_pool = owning_service_pool self.connection = None self.stub = None self._establish_connection() @@ -727,21 +738,56 @@ def _create_connection( ) def __repr__(self): - if self.connection is None: - return f"Closed connection to {self.endpoint}" - return str(self.endpoint) + return f"CompilerGymServiceConnection({self.connection or 'detached'})" @property def closed(self) -> bool: """Whether the connection is closed.""" return self.connection is None - def close(self): + def acquire(self) -> "CompilerGymServiceConnection": + """Mark this connection as in-use.""" + if not self.released: + raise TypeError( + "Attempting to acquire a connection that is already acquired." + ) + self.released = False + return self + + def release(self) -> None: + """Mark this connection as not in-use.""" + if not self.released: + self.owning_service_pool.release(self) + self.released = True + + def shutdown(self): + """Shut down the connection. + + Once a connection has been shutdown, it cannot be re-used. + """ if self.closed: return + self.connection.close() self.connection = None + def close(self): + """Mark this connection as closed. + + If the service is managed by a :class:`ServiceConnectionPool + `, this will indicate to + the pool that the connection is safe to re-use. If the service is not + managed by a pool, this will shut it down. + """ + if self.owned_by_service_pool: + self.release() + else: + self.shutdown() + + @property + def owned_by_service_pool(self): + return self.owning_service_pool is not None + def __del__(self): # Don't let the subprocess be orphaned if user forgot to close(), or # if an exception was thrown. @@ -826,3 +872,12 @@ def __call__( retry_wait_backoff_exponent or self.opts.retry_wait_backoff_exponent ), ) + + def __enter__(self) -> "CompilerGymServiceConnection": + """Support for 'with' statements.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Support for 'with' statements.""" + self.close() + return False diff --git a/compiler_gym/service/connection_pool.py b/compiler_gym/service/connection_pool.py new file mode 100644 index 000000000..4b0f29a09 --- /dev/null +++ b/compiler_gym/service/connection_pool.py @@ -0,0 +1,179 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""This module contains a reusable pool of service connections.""" +import atexit +import logging +from collections import defaultdict +from pathlib import Path +from threading import Lock +from typing import Dict, List, Set, Tuple + +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts + +logger = logging.getLogger(__name__) + +# We identify connections by the binary path and set of connection opts. +ServiceConnectionCacheKey = Tuple[Path, ConnectionOpts] + + +class ServiceConnectionPool: + """An object pool for compiler service connections. + + This class implements a thread-safe pool for compiler service connections. + This enables compiler service connections to be reused, avoiding the + expensive initialization of a new service. + + There is a global instance of this class, available via the static + :meth:`ServiceConnectionPool.get() + ` method. + + To use the pool, acquire a reference to the global instance, and call the + :meth:`ServiceConnectionPool.acquire() + ` method to construct and + return service connections: + + >>> pool = ServiceConnectionPool.get() + >>> with pool.acquire(Path("/path/to/service"), ConnectionOpts()) as service: + ... # Do something with the service. + + When a service is closed (by calling :meth:`service.close() + `), it is + automatically released back to the pool so that a future request for the + same type of service will reuse the connection. + """ + + def __init__(self) -> None: + """""" + self._lock = Lock() + self.pool: Dict[ + ServiceConnectionCacheKey, List[CompilerGymServiceConnection] + ] = defaultdict(list) + self.allocated: Set[CompilerGymServiceConnection] = set() + + # Add a flag to indicate a closed connection pool because of + # out-of-order execution of destructors and the atexit callback. + self.closed = False + + atexit.register(self.close) + + def acquire( + self, endpoint: Path, opts: ConnectionOpts + ) -> CompilerGymServiceConnection: + """Acquire a service connection from the pool. + + If an existing connection is available in the pool, it is returned. + Otherwise, a new connection is created. + """ + key: ServiceConnectionCacheKey = (endpoint, opts) + with self._lock: + if self.closed: + # This should never happen. + raise TypeError("ServiceConnectionPool is closed") + + if self.pool[key]: + service = self.pool[key].pop().acquire() + logger.debug( + "Reusing %s, %d environments remaining in pool", + service.connection.url, + len(self.pool[key]), + ) + else: + # No free service connections, construct a new one. + service = CompilerGymServiceConnection( + endpoint=endpoint, opts=opts, owning_service_pool=self + ) + logger.debug("Created %s", service.connection.url) + + self.allocated.add(service) + + return service + + def release(self, service: CompilerGymServiceConnection) -> None: + """Release a service connection back to the pool. + + .. note:: + + This method is called automatically by the :meth:`service.close() + ` method of + acquired service connections. You do not have to call this method + yourself. + """ + key: ServiceConnectionCacheKey = (service.endpoint, service.opts) + with self._lock: + # During shutdown, the shutdown routine for this + # ServiceConnectionPool may be called before the destructor of + # the managed CompilerGymServiceConnection objects. + if self.closed: + return + + self.allocated.remove(service) + + # A dead service cannot be reused, discard it. + if service.closed or service.connection.process.poll() is not None: + return + + self.pool[key].append(service) + + logger.debug("Released %s, pool size %d", service.connection.url, self.size) + + def __contains__(self, service: CompilerGymServiceConnection): + """Check if a service connection is managed by the pool.""" + key: ServiceConnectionCacheKey = (service.endpoint, service.opts) + return service in self.allocated or service in self.pool[key] + + @property + def size(self): + """Return the total number of connections in the pool.""" + return sum(len(x) for x in self.pool.values()) + len(self.allocated) + + def __len__(self): + return self.size + + def close(self) -> None: + """Close the pool, terminating all connections. + + Once closed, the pool cannot be used again. It is safe to call this + method more than once. + """ + with self._lock: + if self.closed: + return + + logging.debug( + "Closing the service connection pool with %d cached and %d live connections", + self.size, + len(self.allocated), + ) + for connections in self.pool.values(): + for connection in connections: + connection.shutdown() + self.pool = defaultdict(list) + for connection in self.allocated: + connection.shutdown() + self.allocated = set() + self.closed = True + + def __del__(self) -> None: + self.close() + + def __enter__(self) -> "ServiceConnectionPool": + """Support for "with" statement.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Support for "with" statement.""" + self.close() + return False + + @staticmethod + def get() -> "ServiceConnectionPool": + """Return the global instance of the service connection pool.""" + return _SERVICE_CONNECTION_POOL + + def __repr__(self) -> str: + return f"ServiceConnectionPool(size={self.size})" + + +_SERVICE_CONNECTION_POOL = ServiceConnectionPool() diff --git a/docs/source/compiler_gym/service.rst b/docs/source/compiler_gym/service.rst index 2f474dc1c..1499a2877 100644 --- a/docs/source/compiler_gym/service.rst +++ b/docs/source/compiler_gym/service.rst @@ -47,6 +47,7 @@ The connection object .. automethod:: __init__ .. automethod:: __call__ + Configuring the connection -------------------------- @@ -57,6 +58,13 @@ to configure the options used for managing a service connection. :members: +The connection pool +------------------- + +.. autoclass:: ServiceConnectionPool + :members: + + Exceptions ---------- diff --git a/tests/llvm/BUILD b/tests/llvm/BUILD index d3639abde..be4e75f97 100644 --- a/tests/llvm/BUILD +++ b/tests/llvm/BUILD @@ -110,7 +110,7 @@ py_test( "//compiler_gym/third_party/cbench:crc32", ], deps = [ - "//compiler_gym/envs", + "//compiler_gym", "//tests:test_main", "//tests/pytest_plugins:llvm", ], diff --git a/tests/llvm/CMakeLists.txt b/tests/llvm/CMakeLists.txt index f1c5403ff..a64199fee 100644 --- a/tests/llvm/CMakeLists.txt +++ b/tests/llvm/CMakeLists.txt @@ -108,7 +108,7 @@ cg_py_test( DATA compiler_gym::third_party::cbench::crc32 DEPS - compiler_gym::envs::envs + compiler_gym tests::pytest_plugins::llvm tests::test_main ) diff --git a/tests/llvm/custom_benchmarks_test.py b/tests/llvm/custom_benchmarks_test.py index e7372710e..15849b0f0 100644 --- a/tests/llvm/custom_benchmarks_test.py +++ b/tests/llvm/custom_benchmarks_test.py @@ -55,7 +55,7 @@ def test_invalid_benchmark_missing_file(env: LlvmEnv): ) ) - with pytest.raises(ValueError, match="No program set"): + with pytest.raises(ValueError, match="No program set in Benchmark:"): env.reset(benchmark=benchmark) diff --git a/tests/llvm/fork_env_test.py b/tests/llvm/fork_env_test.py index 5b17de3ad..09fe2263c 100644 --- a/tests/llvm/fork_env_test.py +++ b/tests/llvm/fork_env_test.py @@ -8,8 +8,13 @@ import pytest -from compiler_gym.envs import LlvmEnv -from compiler_gym.service import ServiceError +import compiler_gym +from compiler_gym.envs.llvm import LLVM_SERVICE_BINARY, LlvmEnv +from compiler_gym.service import ( + CompilerGymServiceConnection, + ConnectionOpts, + ServiceError, +) from compiler_gym.util.runfiles_path import runfiles_path from tests.test_main import main @@ -32,70 +37,78 @@ def test_with_statement(env: LlvmEnv): assert env.in_episode -def test_fork_child_process_is_not_orphaned(env: LlvmEnv): - env.reset("cbench-v1/crc32") - with env.fork() as fkd: - # Check that both environments share the same service. - assert isinstance(env.service.connection.process, subprocess.Popen) - assert isinstance(fkd.service.connection.process, subprocess.Popen) +def test_fork_child_process_is_not_orphaned(): + service = CompilerGymServiceConnection(LLVM_SERVICE_BINARY, ConnectionOpts()) - assert env.service.connection.process.pid == fkd.service.connection.process.pid - process = env.service.connection.process + with compiler_gym.make("llvm-v0", service_connection=service) as env: + env.reset("cbench-v1/crc32") + with env.fork() as fkd: + # Check that both environments share the same service. + assert isinstance(env.service.connection.process, subprocess.Popen) + assert isinstance(fkd.service.connection.process, subprocess.Popen) - # Sanity check that both services are alive. - assert not env.service.connection.process.poll() - assert not fkd.service.connection.process.poll() + assert ( + env.service.connection.process.pid == fkd.service.connection.process.pid + ) + process = env.service.connection.process - # Close the parent service. - env.close() + # Sanity check that both services are alive. + assert not env.service.connection.process.poll() + assert not fkd.service.connection.process.poll() - # Check that the service is still alive. - assert not env.service - assert not fkd.service.connection.process.poll() + # Close the parent service. + env.close() - # Close the forked service. - fkd.close() + # Check that the service is still alive. + assert not env.service + assert not fkd.service.connection.process.poll() - # Check that the service has been killed. - assert process.poll() is not None + # Close the forked service. + fkd.close() + + # Check that the service has been killed. + assert process.poll() is not None def test_fork_chain_child_processes_are_not_orphaned(env: LlvmEnv): - env.reset("cbench-v1/crc32") + service = CompilerGymServiceConnection(LLVM_SERVICE_BINARY, ConnectionOpts()) - # Create a chain of forked environments. - a = env.fork() - b = a.fork() - c = b.fork() - d = c.fork() + with compiler_gym.make("llvm-v0", service_connection=service) as env: + env.reset() - try: - # Sanity check that they share the same underlying service. - assert ( - env.service.connection.process - == a.service.connection.process - == b.service.connection.process - == c.service.connection.process - == d.service.connection.process - ) - proc = env.service.connection.process - # Kill the forked environments one by one. - a.close() - assert proc.poll() is None - b.close() - assert proc.poll() is None - c.close() - assert proc.poll() is None - d.close() - assert proc.poll() is None - # Kill the final environment, refcount 0, service is closed. - env.close() - assert proc.poll() is not None - finally: - a.close() - b.close() - c.close() - d.close() + # Create a chain of forked environments. + a = env.fork() + b = a.fork() + c = b.fork() + d = c.fork() + + try: + # Sanity check that they share the same underlying service. + assert ( + env.service.connection.process + == a.service.connection.process + == b.service.connection.process + == c.service.connection.process + == d.service.connection.process + ) + proc = env.service.connection.process + # Kill the forked environments one by one. + a.close() + assert proc.poll() is None + b.close() + assert proc.poll() is None + c.close() + assert proc.poll() is None + d.close() + assert proc.poll() is None + # Kill the final environment, refcount 0, service is closed. + env.close() + assert proc.poll() is not None + finally: + a.close() + b.close() + c.close() + d.close() def test_fork_before_reset(env: LlvmEnv): @@ -243,7 +256,7 @@ def test_fork_previous_cost_lazy_reward_update(env: LlvmEnv): env.reset("cbench-v1/crc32") env.step(env.action_space.flags.index("-mem2reg")) - env.reward["IrInstructionCount"] + env.reward["IrInstructionCount"] # noqa with env.fork() as fkd: env.step(env.action_space.flags.index("-mem2reg")) fkd.step(env.action_space.flags.index("-mem2reg")) diff --git a/tests/llvm/service_connection_test.py b/tests/llvm/service_connection_test.py index d7a4f87e2..581671253 100644 --- a/tests/llvm/service_connection_test.py +++ b/tests/llvm/service_connection_test.py @@ -13,6 +13,7 @@ from compiler_gym.errors import ServiceError from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv from compiler_gym.service.connection import CompilerGymServiceConnection +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts from compiler_gym.third_party.autophase import AUTOPHASE_FEATURE_DIM from tests.test_main import main @@ -27,7 +28,9 @@ def env(request) -> ClientServiceCompilerEnv: with gym.make("llvm-v0") as env: yield env else: - service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY) + service = CompilerGymServiceConnection( + llvm.LLVM_SERVICE_BINARY, ConnectionOpts() + ) try: with LlvmEnv(service=service.connection.url) as env: yield env @@ -45,7 +48,7 @@ def test_service_env_dies_reset(env: ClientServiceCompilerEnv): # with env.reset() above. For UnmanagedConnection, this error will not be # raised. try: - env.service.close() + env.service.shutdown() except ServiceError as e: assert "Service exited with returncode " in str(e) diff --git a/tests/pytest_plugins/llvm.py b/tests/pytest_plugins/llvm.py index 3eeccf313..b8ebacdf5 100644 --- a/tests/pytest_plugins/llvm.py +++ b/tests/pytest_plugins/llvm.py @@ -93,7 +93,7 @@ def non_validatable_cbench_uri(request) -> str: @pytest.fixture(scope="function") def env() -> LlvmEnv: - """Create an LLVM environment.""" + """Test fixture that yields an environment.""" with gym.make("llvm-v0") as env_: yield env_ diff --git a/tests/service/BUILD b/tests/service/BUILD index c5b115c42..d8fb8117a 100644 --- a/tests/service/BUILD +++ b/tests/service/BUILD @@ -17,6 +17,17 @@ py_test( ], ) +py_test( + name = "connection_pool_test", + srcs = ["connection_pool_test.py"], + deps = [ + "//compiler_gym", + "//compiler_gym/service", + "//tests:test_main", + "//tests/pytest_plugins:llvm", + ], +) + py_test( name = "service_cache_test", timeout = "short", diff --git a/tests/service/CMakeLists.txt b/tests/service/CMakeLists.txt index b0fd0fd3b..6d1b77a5c 100644 --- a/tests/service/CMakeLists.txt +++ b/tests/service/CMakeLists.txt @@ -18,4 +18,16 @@ if(COMPILER_GYM_ENABLE_LLVM_ENV) compiler_gym::service::service_cache tests::test_main ) + + cg_py_test( + NAME + connection_pool_test + SRCS + "connection_pool_test.py" + DEPS + compiler_gym::errors::errors + compiler_gym::service::service + tests::pytest_plugins::llvm + tests::test_main + ) endif() diff --git a/tests/service/connection_pool_test.py b/tests/service/connection_pool_test.py new file mode 100644 index 000000000..c897202bb --- /dev/null +++ b/tests/service/connection_pool_test.py @@ -0,0 +1,184 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""Unit tests for compiler_gym/service/connection_pool.py.""" + +import pytest + +import compiler_gym +from compiler_gym.envs.llvm import LLVM_SERVICE_BINARY +from compiler_gym.service import ConnectionOpts, ServiceConnectionPool +from compiler_gym.errors import ServiceError +from tests.test_main import main + +pytest_plugins = ["tests.pytest_plugins.llvm"] + + +@pytest.fixture(scope="function") +def pool() -> ServiceConnectionPool: + with ServiceConnectionPool() as pool_: + yield pool_ + + +def test_service_pool_with_statement(): + with ServiceConnectionPool() as pool: + assert not pool.closed + assert pool.closed + + +def test_service_pool_double_close(pool: ServiceConnectionPool): + assert not pool.closed + pool.close() + assert pool.closed + pool.close() + assert pool.closed + + +def test_service_pool_acquire_release(pool: ServiceConnectionPool): + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + assert service in pool + service.release() + assert service in pool + + +def test_service_pool_contains(pool: ServiceConnectionPool): + with ServiceConnectionPool() as other_pool: + with pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) as service: + assert service in pool + assert service not in other_pool + assert service not in ServiceConnectionPool.get() + + # Service remains in pool after release. + assert service in pool + + +def test_service_pool_close_frees_service(pool: ServiceConnectionPool): + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + assert not service.closed + pool.close() + assert service.closed + + +def test_service_pool_service_is_not_closed(pool: ServiceConnectionPool): + service = None + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + service.close() + assert not service.closed + + +def test_service_pool_with_service_is_not_closed(pool: ServiceConnectionPool): + service = None + with pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) as service: + assert not service.closed + assert not service.closed + + +def test_service_pool_with_env_is_not_closed(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + service = env.service + assert not service.closed + assert not service.closed + + +def test_service_pool_fork(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + env.reset() + with env.fork() as fkd: + fkd.reset() + assert env.service == fkd.service + assert not env.service.closed + assert not env.service.closed + + +def test_service_pool_release_service(pool: ServiceConnectionPool): + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + service.close() + # A released service remains alive. + assert not service.closed + + +def test_service_pool_release_dead_service(pool: ServiceConnectionPool): + service = pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()) + service.shutdown() + assert service.closed + service.close() + # A dead service cannot be reused, discard it. + assert service not in pool + + +def test_service_pool_size(pool: ServiceConnectionPool): + assert pool.size == 0 + assert len(pool) == pool.size + + with pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()): + assert pool.size == 1 + assert len(pool) == pool.size + with pool.acquire(LLVM_SERVICE_BINARY, ConnectionOpts()): + assert pool.size == 2 + assert len(pool) == pool.size + + +def test_service_pool_make_release(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as a: + assert len(pool) == 1 + with compiler_gym.make("llvm-v0", service_pool=pool) as b: + a_service = a.service + b_service = b.service + assert a_service != b_service + assert len(pool) == 2 + + with compiler_gym.make("llvm-v0", service_pool=pool) as c: + c_service = c.service + assert a_service == c_service + assert a_service != b_service + assert pool.size == 2 + + +def test_service_pool_make_release_loop(pool: ServiceConnectionPool): + for _ in range(5): + with compiler_gym.make("llvm-v0", service_pool=pool): + assert pool.size == 1 + assert pool.size == 1 + + +def test_service_pool_environment_restarts_service(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + old_service = env.service + env.service.shutdown() + env.service.close() + assert env.service.closed + + # For environment to restart service. + env.reset() + assert not env.service.closed + + new_service = env.service + assert new_service in pool + assert old_service not in pool + + +def test_service_pool_forked_service_dies(pool: ServiceConnectionPool): + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + with env.fork() as fkd: + assert env.service == fkd.service + try: + fkd.service.shutdown() + except ServiceError: + pass # shutdown() raises service error if in-episode. + fkd.service.close() + + env.reset() + fkd.reset() + assert env.service != fkd.service + assert env.service in pool + assert fkd.service in pool + + +# TODO: Test case where forked environment kills the service. + +# TODO: Service pool connection does not interfere with pool. + + +if __name__ == "__main__": + main() diff --git a/tests/service/connection_test.py b/tests/service/connection_test.py index c51e58463..63e051ba5 100644 --- a/tests/service/connection_test.py +++ b/tests/service/connection_test.py @@ -33,7 +33,7 @@ def dead_connection() -> CompilerGymServiceConnection: def test_create_invalid_options(): with pytest.raises(TypeError, match="No endpoint provided for service connection"): - CompilerGymServiceConnection("") + CompilerGymServiceConnection("", ConnectionOpts()) def test_create_channel_failed_subprocess( @@ -89,16 +89,13 @@ def test_call_stub_negative_timeout(connection: CompilerGymServiceConnection): def test_ManagedConnection_repr(connection: CompilerGymServiceConnection): cnx = connection.connection - assert ( - repr(cnx) - == f"Connection to service at {cnx.url} running on PID {cnx.process.pid}" - ) + assert repr(cnx) == f"ManagedConnection({cnx.url}, pid={cnx.process.pid})" # Kill the service. cnx.process.terminate() cnx.process.communicate() - assert repr(cnx) == f"Connection to dead service at {cnx.url}" + assert repr(cnx) == f"ManagedConnection({cnx.url}, not running)" if __name__ == "__main__": From 0bb97b174b1413c4022624dec1438be22909be6e Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Wed, 23 Feb 2022 17:05:50 +0000 Subject: [PATCH 05/41] Hardening patch for shutdown errors. --- compiler_gym/service/connection.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index fd08f4024..5a2d0b15d 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -768,7 +768,17 @@ def shutdown(self): if self.closed: return - self.connection.close() + try: + self.connection.close() + except ServiceError as e: + # close() can raise ServiceError if the service exists with a + # non-zero return code. We swallow the error here as we are + # disposing o f the service. + logger.debug( + "Ignoring service error during shutdown attempt: %s (%s)", + e, + type(e).__name__, + ) self.connection = None def close(self): @@ -786,7 +796,11 @@ def close(self): @property def owned_by_service_pool(self): - return self.owning_service_pool is not None + # Defensive hasattr() test because this property is accessed by the + # destructor, where the object could be in a partially initialized + # state. + if hasattr(self, "owning_service_pool"): + return self.owning_service_pool is not None def __del__(self): # Don't let the subprocess be orphaned if user forgot to close(), or From 781727fda4bc4a32f9db21ee0e75c8a3d2f90997 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Wed, 23 Feb 2022 17:06:11 +0000 Subject: [PATCH 06/41] [gcc] Fix connection property setter. --- compiler_gym/envs/gcc/gcc_env.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/compiler_gym/envs/gcc/gcc_env.py b/compiler_gym/envs/gcc/gcc_env.py index 52049a728..2e58d52fe 100644 --- a/compiler_gym/envs/gcc/gcc_env.py +++ b/compiler_gym/envs/gcc/gcc_env.py @@ -69,9 +69,10 @@ def __init__( if connection_settings is None: connection_settings = ConnectionOpts(script_env=frozendict({"CC": gcc_bin})) else: - connection_settings = ConnectionOpts( - script_env=frozendict({"CC": gcc_bin}, **connection_settings._asdict()) - ) + script_env = frozendict({"CC": gcc_bin}, **connection_settings.script_env) + opts = connection_settings._asdict() + opts["script_env"] = script_env + connection_settings = ConnectionOpts(**opts) # Eagerly create a GCC compiler instance now because: # From bdada20fcf3835228faa5c2731a5c77bc6f7f2aa Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Wed, 23 Feb 2022 17:19:27 +0000 Subject: [PATCH 07/41] [tests] Update tests. --- tests/llvm/llvm_env_test.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/llvm/llvm_env_test.py b/tests/llvm/llvm_env_test.py index 8402098c7..59f0365ef 100644 --- a/tests/llvm/llvm_env_test.py +++ b/tests/llvm/llvm_env_test.py @@ -20,7 +20,7 @@ from compiler_gym.envs import CompilerEnv, llvm from compiler_gym.envs.llvm.llvm_env import LlvmEnv from compiler_gym.errors import ServiceError -from compiler_gym.service.connection import CompilerGymServiceConnection +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts from tests.pytest_plugins import llvm as llvm_plugin from tests.test_main import main @@ -34,7 +34,9 @@ def env(request) -> CompilerEnv: with gym.make("llvm-v0") as env: yield env else: - service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY) + service = CompilerGymServiceConnection( + llvm.LLVM_SERVICE_BINARY, ConnectionOpts() + ) try: with LlvmEnv(service=service.connection.url) as env: yield env @@ -90,7 +92,7 @@ def test_connection_dies_default_reward(env: LlvmEnv): # with env.reset() above. For UnmanagedConnection, this error will not be # raised. try: - env.service.close() + env.service.shutdown() except ServiceError as e: assert "Service exited with returncode " in str(e) @@ -114,7 +116,7 @@ def test_connection_dies_default_reward_negated(env: LlvmEnv): # with env.reset() above. For UnmanagedConnection, this error will not be # raised. try: - env.service.close() + env.service.shutdown() except ServiceError as e: assert "Service exited with returncode " in str(e) From 4304d515a1b57b81303e531abb1adadbb30cf152 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Wed, 23 Feb 2022 17:20:39 +0000 Subject: [PATCH 08/41] [service] Make ConnectionOpts mutable again. --- compiler_gym/service/connection.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index 5a2d0b15d..0cccb6e44 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -10,12 +10,13 @@ from pathlib import Path from signal import Signals from time import sleep, time -from typing import Dict, FrozenSet, Iterable, List, NamedTuple, Optional, TypeVar, Union +from typing import Dict, FrozenSet, Iterable, List, Optional, TypeVar, Union import grpc from deprecated.sphinx import deprecated from pydantic import BaseModel from frozendict import frozendict +from pydantic import BaseModel import compiler_gym.errors from compiler_gym.service.proto import ( @@ -50,7 +51,14 @@ logger = logging.getLogger(__name__) -class ConnectionOpts(NamedTuple): +class HashableBaseModel(BaseModel): + """A pydantic model that is hashable.""" + + def __hash__(self): + return hash((type(self),) + tuple(self.__dict__.values())) + + +class ConnectionOpts(HashableBaseModel): """The options used to configure a connection to a service.""" rpc_call_max_seconds: float = 300 From 37bda6c4c845519d37629ceff5416420185d2f4e Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Wed, 23 Feb 2022 17:20:50 +0000 Subject: [PATCH 09/41] [service] Fix race in destructor. --- compiler_gym/service/connection.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index 0cccb6e44..21191d78d 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -751,7 +751,10 @@ def __repr__(self): @property def closed(self) -> bool: """Whether the connection is closed.""" - return self.connection is None + # Defensive hasattr() because this property is accessed by destructor. + if hasattr(self, "connection"): + return self.connection is None + return True def acquire(self) -> "CompilerGymServiceConnection": """Mark this connection as in-use.""" From 3e11bb10dc67fd3e28f40d4757d18abdc4dbc77d Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Wed, 23 Feb 2022 17:21:02 +0000 Subject: [PATCH 10/41] [service] Fix unmanaged service pool. --- compiler_gym/service/connection_pool.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/compiler_gym/service/connection_pool.py b/compiler_gym/service/connection_pool.py index 4b0f29a09..297e0a67a 100644 --- a/compiler_gym/service/connection_pool.py +++ b/compiler_gym/service/connection_pool.py @@ -110,9 +110,11 @@ def release(self, service: CompilerGymServiceConnection) -> None: self.allocated.remove(service) - # A dead service cannot be reused, discard it. - if service.closed or service.connection.process.poll() is not None: - return + # Only managed processes have a process attribute. + if hasattr(service.connection, "process"): + # A dead service cannot be reused, discard it. + if service.closed or service.connection.process.poll() is not None: + return self.pool[key].append(service) From 15b6248935240e1974161fb10e4ea9b239feb5b8 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 10:02:30 +0000 Subject: [PATCH 11/41] Test fixes from new connection pool. --- benchmarks/bench_test.py | 3 ++- compiler_gym/envs/gcc/gcc_env.py | 13 ++++------- compiler_gym/service/connection.py | 8 ++++++- compiler_gym/service/connection_pool.py | 14 +++++++++++- tests/llvm/custom_benchmarks_test.py | 30 ++++++++++++++++--------- tests/llvm/fork_env_test.py | 2 +- tests/service/connection_pool_test.py | 3 +-- 7 files changed, 48 insertions(+), 25 deletions(-) diff --git a/benchmarks/bench_test.py b/benchmarks/bench_test.py index 497f73e53..005581298 100644 --- a/benchmarks/bench_test.py +++ b/benchmarks/bench_test.py @@ -24,6 +24,7 @@ from compiler_gym.envs import CompilerEnv, LlvmEnv, llvm from compiler_gym.service import CompilerGymServiceConnection from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv +from compiler_gym.service import CompilerGymServiceConnection, ConnectionOpts from tests.pytest_plugins.llvm import OBSERVATION_SPACE_NAMES, REWARD_SPACE_NAMES from tests.test_main import main @@ -64,7 +65,7 @@ def test_make_local(benchmark, env_id): ) def test_make_service(benchmark, args): service_binary, env_class = args - service = CompilerGymServiceConnection(service_binary) + service = CompilerGymServiceConnection(service_binary, ConnectionOpts()) try: benchmark(lambda: env_class(service=service.connection.url).close()) finally: diff --git a/compiler_gym/envs/gcc/gcc_env.py b/compiler_gym/envs/gcc/gcc_env.py index 2e58d52fe..77bccfa93 100644 --- a/compiler_gym/envs/gcc/gcc_env.py +++ b/compiler_gym/envs/gcc/gcc_env.py @@ -9,8 +9,6 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union -from frozendict import frozendict - from compiler_gym.datasets import Benchmark from compiler_gym.envs.gcc.datasets import get_gcc_datasets from compiler_gym.envs.gcc.gcc import Gcc, GccSpec @@ -66,13 +64,10 @@ def __init__( :raises ServiceInitError: If the requested GCC version cannot be used. """ # Pass the executable path via an environment variable - if connection_settings is None: - connection_settings = ConnectionOpts(script_env=frozendict({"CC": gcc_bin})) - else: - script_env = frozendict({"CC": gcc_bin}, **connection_settings.script_env) - opts = connection_settings._asdict() - opts["script_env"] = script_env - connection_settings = ConnectionOpts(**opts) + connection_settings = connection_settings or ConnectionOpts() + connection_settings.script_env = connection_settings.script_env.set( + "CC", gcc_bin + ) # Eagerly create a GCC compiler instance now because: # diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index 21191d78d..ed677b60c 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -16,7 +16,7 @@ from deprecated.sphinx import deprecated from pydantic import BaseModel from frozendict import frozendict -from pydantic import BaseModel +from pydantic import BaseModel, root_validator import compiler_gym.errors from compiler_gym.service.proto import ( @@ -113,6 +113,12 @@ class ConnectionOpts(HashableBaseModel): """If the service is started from a local script, this set of env vars is used on the command line. No effect when used for existing sockets.""" + @root_validator + def freeze_types(cls, values): + values["script_args"] = frozenset(values["script_args"]) + values["script_env"] = frozendict(values["script_env"]) + return values + # Deprecated since v0.2.4. # This type is for backwards compatibility that will be removed in a future release. diff --git a/compiler_gym/service/connection_pool.py b/compiler_gym/service/connection_pool.py index 297e0a67a..47e85c139 100644 --- a/compiler_gym/service/connection_pool.py +++ b/compiler_gym/service/connection_pool.py @@ -108,13 +108,25 @@ def release(self, service: CompilerGymServiceConnection) -> None: if self.closed: return + if service not in self.allocated: + logger.debug( + "Ignoring attempt to release connection " + "that does not belong to pool" + ) + return + self.allocated.remove(service) # Only managed processes have a process attribute. if hasattr(service.connection, "process"): # A dead service cannot be reused, discard it. if service.closed or service.connection.process.poll() is not None: + logger.debug("Ignoring attempt to release dead connection") return + # A service that has been shutdown cannot be reused, discard it. + if not service.connection: + logger.debug("Ignoring attempt to service without connection") + return self.pool[key].append(service) @@ -143,7 +155,7 @@ def close(self) -> None: if self.closed: return - logging.debug( + logger.debug( "Closing the service connection pool with %d cached and %d live connections", self.size, len(self.allocated), diff --git a/tests/llvm/custom_benchmarks_test.py b/tests/llvm/custom_benchmarks_test.py index 15849b0f0..0339c9e58 100644 --- a/tests/llvm/custom_benchmarks_test.py +++ b/tests/llvm/custom_benchmarks_test.py @@ -39,11 +39,12 @@ def test_reset_invalid_benchmark(env: LlvmEnv): def test_invalid_benchmark_data(env: LlvmEnv): benchmark = Benchmark.from_file_contents( - "benchmark://new", "Invalid bitcode".encode("utf-8") + "benchmark://test_invalid_benchmark_data", "Invalid bitcode".encode("utf-8") ) with pytest.raises( - ValueError, match='Failed to parse LLVM bitcode: "benchmark://new"' + ValueError, + match='Failed to parse LLVM bitcode: "benchmark://test_invalid_benchmark_data"', ): env.reset(benchmark=benchmark) @@ -51,7 +52,7 @@ def test_invalid_benchmark_data(env: LlvmEnv): def test_invalid_benchmark_missing_file(env: LlvmEnv): benchmark = Benchmark( BenchmarkProto( - uri="benchmark://new", + uri="benchmark://test_invalid_benchmark_missing_file", ) ) @@ -64,7 +65,9 @@ def test_benchmark_path_empty_file(env: LlvmEnv): tmpdir = Path(tmpdir) (tmpdir / "test.bc").touch() - benchmark = Benchmark.from_file("benchmark://new", tmpdir / "test.bc") + benchmark = Benchmark.from_file( + "benchmark://test_benchmark_path_empty_file", tmpdir / "test.bc" + ) with pytest.raises(ValueError, match="Failed to parse LLVM bitcode"): env.reset(benchmark=benchmark) @@ -76,7 +79,9 @@ def test_invalid_benchmark_path_contents(env: LlvmEnv): with open(str(tmpdir / "test.bc"), "w") as f: f.write("Invalid bitcode") - benchmark = Benchmark.from_file("benchmark://new", tmpdir / "test.bc") + benchmark = Benchmark.from_file( + "benchmark://test_invalid_benchmark_path_contents", tmpdir / "test.bc" + ) with pytest.raises(ValueError, match="Failed to parse LLVM bitcode"): env.reset(benchmark=benchmark) @@ -85,7 +90,8 @@ def test_invalid_benchmark_path_contents(env: LlvmEnv): def test_benchmark_path_invalid_scheme(env: LlvmEnv): benchmark = Benchmark( BenchmarkProto( - uri="benchmark://new", program=File(uri="invalid_scheme://test") + uri="benchmark://test_benchmark_path_invalid_scheme", + program=File(uri="invalid_scheme://test"), ), ) @@ -100,16 +106,20 @@ def test_benchmark_path_invalid_scheme(env: LlvmEnv): def test_custom_benchmark(env: LlvmEnv): - benchmark = Benchmark.from_file("benchmark://new", EXAMPLE_BITCODE_FILE) + benchmark = Benchmark.from_file( + "benchmark://test_custom_benchmark", EXAMPLE_BITCODE_FILE + ) env.reset(benchmark=benchmark) - assert env.benchmark == "benchmark://new" + assert env.benchmark == "benchmark://test_custom_benchmark" def test_custom_benchmark_constructor(): - benchmark = Benchmark.from_file("benchmark://new", EXAMPLE_BITCODE_FILE) + benchmark = Benchmark.from_file( + "benchmark://test_custom_benchmark_constructor", EXAMPLE_BITCODE_FILE + ) with gym.make("llvm-v0", benchmark=benchmark) as env: env.reset() - assert env.benchmark == "benchmark://new" + assert env.benchmark == "benchmark://test_custom_benchmark_constructor" def test_make_benchmark_single_bitcode(env: LlvmEnv): diff --git a/tests/llvm/fork_env_test.py b/tests/llvm/fork_env_test.py index 09fe2263c..a02ad7436 100644 --- a/tests/llvm/fork_env_test.py +++ b/tests/llvm/fork_env_test.py @@ -271,7 +271,7 @@ def test_forked_service_dies(env: LlvmEnv): with env.fork() as fkd: assert env.service == fkd.service try: - fkd.service.shutdown() + fkd.service.connection.close() except ServiceError: pass # shutdown() raises service error if in-episode. fkd.service.close() diff --git a/tests/service/connection_pool_test.py b/tests/service/connection_pool_test.py index c897202bb..c50fdcdfb 100644 --- a/tests/service/connection_pool_test.py +++ b/tests/service/connection_pool_test.py @@ -163,10 +163,9 @@ def test_service_pool_forked_service_dies(pool: ServiceConnectionPool): with env.fork() as fkd: assert env.service == fkd.service try: - fkd.service.shutdown() + fkd.service.connection.close() except ServiceError: pass # shutdown() raises service error if in-episode. - fkd.service.close() env.reset() fkd.reset() From a2a312353196f18419e62d5d4988790552093105 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 10:28:50 +0000 Subject: [PATCH 12/41] Fix typo in docstring. --- compiler_gym/envs/compiler_env.py | 238 ++++++++++++++++++ .../service/client_service_compiler_env.py | 5 +- 2 files changed, 242 insertions(+), 1 deletion(-) diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index e6a48b284..0c1442132 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -98,6 +98,244 @@ def observation_space_spec(self) -> ObservationSpaceSpec: def observation_space_spec( self, observation_space_spec: Optional[ObservationSpaceSpec] ): +<<<<<<< HEAD +======= + """Construct and initialize a CompilerGym environment. + + In normal use you should use :code:`gym.make(...)` rather than calling + the constructor directly. + + :param service: The hostname and port of a service that implements the + CompilerGym service interface, or the path of a binary file which + provides the CompilerGym service interface when executed. See + :doc:`/compiler_gym/service` for details. + + :param rewards: The reward spaces that this environment supports. + Rewards are typically calculated based on observations generated by + the service. See :class:`Reward ` for + details. + + :param benchmark: The benchmark to use for this environment. Either a + URI string, or a :class:`Benchmark + ` instance. If not provided, the + first benchmark as returned by + :code:`next(env.datasets.benchmarks())` will be used as the default. + + :param observation_space: Compute and return observations at each + :func:`step()` from this space. Accepts a string name or an + :class:`ObservationSpaceSpec + `. If not provided, + :func:`step()` returns :code:`None` for the observation value. Can + be set later using :meth:`env.observation_space + `. For available + spaces, see :class:`env.observation.spaces + `. + + :param reward_space: Compute and return reward at each :func:`step()` + from this space. Accepts a string name or a :class:`Reward + `. If not provided, :func:`step()` + returns :code:`None` for the reward value. Can be set later using + :meth:`env.reward_space + `. For available spaces, + see :class:`env.reward.spaces `. + + :param action_space: The name of the action space to use. If not + specified, the default action space for this compiler is used. + + :param derived_observation_spaces: An optional list of arguments to be + passed to :meth:`env.observation.add_derived_space() + `. + + :param connection_settings: The settings used to establish a connection + with the remote service. + + :param service_connection: An existing compiler gym service connection + to use. + + :param service_pool: A service pool to use for acquiring a service + connection. If not specified, the :meth:`global service pool + ` is used. + + :raises FileNotFoundError: If service is a path to a file that is not + found. + + :raises TimeoutError: If the compiler service fails to initialize within + the parameters provided in :code:`connection_settings`. + """ + # NOTE(cummins): Logger argument deprecated and scheduled to be removed + # in release 0.2.3. + if logger: + warnings.warn( + "The `logger` argument is deprecated on CompilerEnv.__init__() " + "and will be removed in a future release. All CompilerEnv " + "instances share a logger named compiler_gym.envs.compiler_env", + DeprecationWarning, + ) + + self.metadata = {"render.modes": ["human", "ansi"]} + + # A compiler service supports multiple simultaneous environments. This + # session ID is used to identify this environment. + self._session_id: Optional[int] = None + + self._service_endpoint: Union[str, Path] = service + self._connection_settings = connection_settings or ConnectionOpts() + + if service_connection is None: + self._service_pool = ( + ServiceConnectionPool.get() if service_pool is None else service_pool + ) + self.service = self._service_pool.acquire( + endpoint=self._service_endpoint, + opts=self._connection_settings, + ) + else: + self._service_pool = service_pool + self.service = service_connection + + self.datasets = Datasets(datasets or []) + + self.action_space_name = action_space + + # If no reward space is specified, generate some from numeric observation spaces + rewards = rewards or [ + DefaultRewardFromObservation(obs.name) + for obs in self.service.observation_spaces + if obs.default_observation.WhichOneof("value") + and isinstance( + getattr( + obs.default_observation, obs.default_observation.WhichOneof("value") + ), + numbers.Number, + ) + ] + + # The benchmark that is currently being used, and the benchmark that + # will be used on the next call to reset(). These are equal except in + # the gap between the user setting the env.benchmark property while in + # an episode and the next call to env.reset(). + self._benchmark_in_use: Optional[Benchmark] = None + self._benchmark_in_use_proto: BenchmarkProto = BenchmarkProto() + self._next_benchmark: Optional[Benchmark] = None + # Normally when the benchmark is changed the updated value is not + # reflected until the next call to reset(). We make an exception for the + # constructor-time benchmark as otherwise the behavior of the benchmark + # property is counter-intuitive: + # + # >>> env = gym.make("example-v0", benchmark="foo") + # >>> env.benchmark + # None + # >>> env.reset() + # >>> env.benchmark + # "foo" + # + # By forcing the _benchmark_in_use URI at constructor time, the first + # env.benchmark above returns the benchmark as expected. + try: + self.benchmark = benchmark or next(self.datasets.benchmarks()) + self._benchmark_in_use = self._next_benchmark + except StopIteration: + # StopIteration raised on next(self.datasets.benchmarks()) if there + # are no benchmarks available. This is to allow CompilerEnv to be + # used without any datasets by setting a benchmark before/during the + # first reset() call. + pass + + # Process the available action, observation, and reward spaces. + self.action_spaces = [ + proto_to_action_space(space) for space in self.service.action_spaces + ] + + self.observation = self._observation_view_type( + raw_step=self.raw_step, + spaces=self.service.observation_spaces, + ) + self.reward = self._reward_view_type(rewards, self.observation) + + # Register any derived observation spaces now so that the observation + # space can be set below. + for derived_observation_space in derived_observation_spaces or []: + self.observation.add_derived_space_internal(**derived_observation_space) + + # Lazily evaluated version strings. + self._versions: Optional[GetVersionReply] = None + + self.action_space: Optional[Space] = None + self.observation_space: Optional[Space] = None + + # Mutable state initialized in reset(). + self.reward_range: Tuple[float, float] = (-np.inf, np.inf) + self.episode_reward: Optional[float] = None + self.episode_start_time: float = time() + self.actions: List[ActionType] = [] + + # Initialize the default observation/reward spaces. + self.observation_space_spec: Optional[ObservationSpaceSpec] = None + self.reward_space_spec: Optional[Reward] = None + self.observation_space = observation_space + self.reward_space = reward_space + + @property + @deprecated( + version="0.2.1", + reason=( + "The `CompilerEnv.logger` attribute is deprecated. All CompilerEnv " + "instances share a logger named compiler_gym.envs.compiler_env" + ), + ) + def logger(self): + return _logger + + @property + def versions(self) -> GetVersionReply: + """Get the version numbers from the compiler service.""" + if self._versions is None: + self._versions = self.service( + self.service.stub.GetVersion, GetVersionRequest() + ) + return self._versions + + @property + def version(self) -> str: + """The version string of the compiler service.""" + return self.versions.service_version + + @property + def compiler_version(self) -> str: + """The version string of the underlying compiler that this service supports.""" + return self.versions.compiler_version + + def commandline(self) -> str: + """Interface for :class:`CompilerEnv ` + subclasses to provide an equivalent commandline invocation to the + current environment state. + + See also :meth:`commandline_to_actions() + `. + + Calling this method on a :class:`CompilerEnv + ` instance raises + :code:`NotImplementedError`. + + :return: A string commandline invocation. + """ + raise NotImplementedError("abstract method") + + def commandline_to_actions(self, commandline: str) -> List[ActionType]: + """Interface for :class:`CompilerEnv ` + subclasses to convert from a commandline invocation to a sequence of + actions. + + See also :meth:`commandline() + `. + + Calling this method on a :class:`CompilerEnv + ` instance raises + :code:`NotImplementedError`. + + :return: A list of actions. + """ +>>>>>>> 4a874cee (Fix typo in docstring.) raise NotImplementedError("abstract method") @property diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index 23bacd946..a634e56c5 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -194,7 +194,8 @@ def __init__( to use. :param service_pool: A service pool to use for acquiring a service - connection. If not specified, the global service pool is used. + connection. If not specified, the :meth:`global service pool + ` is used. :raises FileNotFoundError: If service is a path to a file that is not found. @@ -395,6 +396,8 @@ def commandline(self) -> str: """Calling this method on a :class:`ClientServiceCompilerEnv ` instance raises :code:`NotImplementedError`. + + :return: A string commandline invocation. """ raise NotImplementedError("abstract method") From c104632d2606c0790063efc3c30d4409f78be1d9 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 10:31:44 +0000 Subject: [PATCH 13/41] Add a type annotation for env._service_pool. --- compiler_gym/envs/compiler_env.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index 0c1442132..a243d6afa 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -182,7 +182,7 @@ def observation_space_spec( self._connection_settings = connection_settings or ConnectionOpts() if service_connection is None: - self._service_pool = ( + self._service_pool: Optional[ServiceConnectionPool] = ( ServiceConnectionPool.get() if service_pool is None else service_pool ) self.service = self._service_pool.acquire( @@ -190,7 +190,7 @@ def observation_space_spec( opts=self._connection_settings, ) else: - self._service_pool = service_pool + self._service_pool: Optional[ServiceConnectionPool] = service_pool self.service = service_connection self.datasets = Datasets(datasets or []) From 3a7b745fe031ec66060375dd176d7faa2dcb62b2 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 10:33:21 +0000 Subject: [PATCH 14/41] Clarify docstring. --- compiler_gym/service/connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index ed677b60c..89bc9acdb 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -52,7 +52,7 @@ class HashableBaseModel(BaseModel): - """A pydantic model that is hashable.""" + """A pydantic model that is hashable. Requires that all fields are hashable.""" def __hash__(self): return hash((type(self),) + tuple(self.__dict__.values())) From 2010f387556bc5a961c3d1df7c9f00096911f906 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 10:45:00 +0000 Subject: [PATCH 15/41] Docstring improvements. --- compiler_gym/service/connection.py | 2 +- compiler_gym/service/connection_pool.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index 89bc9acdb..2e000f945 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -790,7 +790,7 @@ def shutdown(self): except ServiceError as e: # close() can raise ServiceError if the service exists with a # non-zero return code. We swallow the error here as we are - # disposing o f the service. + # disposing of the service. logger.debug( "Ignoring service error during shutdown attempt: %s (%s)", e, diff --git a/compiler_gym/service/connection_pool.py b/compiler_gym/service/connection_pool.py index 47e85c139..4722e0a28 100644 --- a/compiler_gym/service/connection_pool.py +++ b/compiler_gym/service/connection_pool.py @@ -42,6 +42,14 @@ class ServiceConnectionPool: `), it is automatically released back to the pool so that a future request for the same type of service will reuse the connection. + + :ivar pool: A pool of service connections that are ready for use. + + :vartype pool: Dict[ServiceConnectionCacheKey, List[CompilerGymServiceConnection]] + + :ivar allocated: The set of service connections that are currently in use. + + :vartype allocated: Set[CompilerGymServiceConnection] """ def __init__(self) -> None: From 63d3bd853bfbbaad8aacd111a9388c5a4d7b922a Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 12:00:40 +0000 Subject: [PATCH 16/41] Refactor env.action_space.flags.index(x) to env.action_space[x]. They are equivalent. --- compiler_gym/wrappers/commandline.py | 2 +- examples/tabular_q.py | 6 +++--- tests/llvm/action_space_test.py | 8 +++---- tests/llvm/datasets/cbench_validate_test.py | 4 ++-- tests/llvm/fork_env_test.py | 23 ++++++++++++--------- tests/llvm/llvm_env_test.py | 18 ++++++++-------- tests/llvm/reward_spaces_test.py | 2 +- 7 files changed, 33 insertions(+), 30 deletions(-) diff --git a/compiler_gym/wrappers/commandline.py b/compiler_gym/wrappers/commandline.py index 976b339e5..79dbcab25 100644 --- a/compiler_gym/wrappers/commandline.py +++ b/compiler_gym/wrappers/commandline.py @@ -143,7 +143,7 @@ def __init__( flag=env.action_space.flags[a], description=env.action_space.descriptions[a], ) - for a in (env.action_space.flags.index(f) for f in flags) + for a in (env.action_space[f] for f in flags) ], name=f"{type(self).__name__}<{name or env.action_space.name}, {len(flags)}>", ) diff --git a/examples/tabular_q.py b/examples/tabular_q.py index da7419c25..94124af93 100644 --- a/examples/tabular_q.py +++ b/examples/tabular_q.py @@ -120,7 +120,7 @@ def rollout(qtable, env, printout=False): for i in range(FLAGS.episode_length): a = select_action(qtable, observation, i) action_seq.append(a) - observation, reward, done, info = env.step(env.action_space.flags.index(a)) + observation, reward, done, info = env.step(env.action_space[a]) rewards.append(reward) if done: break @@ -146,10 +146,10 @@ def train(q_table, env): hashed = make_q_table_key(observation, a, current_length) if hashed not in q_table: q_table[hashed] = 0 - # Take a stap in the environment, record the reward and state transition. + # Take a step in the environment, record the reward and state transition. # Effectively we are evaluating the policy by taking a step in the # environment. - observation, reward, done, info = env.step(env.action_space.flags.index(a)) + observation, reward, done, info = env.step(env.action_space[a]) if done: break current_length += 1 diff --git a/tests/llvm/action_space_test.py b/tests/llvm/action_space_test.py index 669958052..5a377c0b0 100644 --- a/tests/llvm/action_space_test.py +++ b/tests/llvm/action_space_test.py @@ -17,12 +17,12 @@ def test_commandline_no_actions(env: LlvmEnv): def test_commandline(env: LlvmEnv): env.reset(benchmark="cbench-v1/crc32") - env.step(env.action_space.flags.index("-mem2reg")) - env.step(env.action_space.flags.index("-reg2mem")) + env.step(env.action_space["-mem2reg"]) + env.step(env.action_space["-reg2mem"]) assert env.commandline() == "opt -mem2reg -reg2mem input.bc -o output.bc" assert env.commandline_to_actions(env.commandline()) == [ - env.action_space.flags.index("-mem2reg"), - env.action_space.flags.index("-reg2mem"), + env.action_space["-mem2reg"], + env.action_space["-reg2mem"], ] diff --git a/tests/llvm/datasets/cbench_validate_test.py b/tests/llvm/datasets/cbench_validate_test.py index 377f720d3..92ce417ef 100644 --- a/tests/llvm/datasets/cbench_validate_test.py +++ b/tests/llvm/datasets/cbench_validate_test.py @@ -19,7 +19,7 @@ def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str) env.reset(benchmark=validatable_cbench_uri) # Run a single step. - env.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) # Validate the environment state. result: ValidationResult = env.validate() @@ -41,7 +41,7 @@ def test_non_validatable_benchmark_validate( env.reset(benchmark=non_validatable_cbench_uri) # Run a single step. - env.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) # Validate the environment state. result: ValidationResult = env.validate() diff --git a/tests/llvm/fork_env_test.py b/tests/llvm/fork_env_test.py index a02ad7436..8d30f3acd 100644 --- a/tests/llvm/fork_env_test.py +++ b/tests/llvm/fork_env_test.py @@ -201,7 +201,7 @@ def test_fork_modified_ir_is_the_same(env: LlvmEnv): env.reset("cbench-v1/crc32") # Apply an action that modifies the benchmark. - _, _, done, info = env.step(env.action_space.flags.index("-mem2reg")) + _, _, done, info = env.step(env.action_space["-mem2reg"]) assert not done assert not info["action_had_no_effect"] @@ -209,8 +209,8 @@ def test_fork_modified_ir_is_the_same(env: LlvmEnv): assert "\n".join(env.ir.split("\n")[1:]) == "\n".join(fkd.ir.split("\n")[1:]) # Apply another action. - _, _, done, info = env.step(env.action_space.flags.index("-gvn")) - _, _, done, info = fkd.step(fkd.action_space.flags.index("-gvn")) + _, _, done, info = env.step(env.action_space["-gvn"]) + _, _, done, info = fkd.step(fkd.action_space["-gvn"]) assert not done assert not info["action_had_no_effect"] @@ -227,7 +227,10 @@ def test_fork_rewards(env: LlvmEnv, reward_space: str): env.reward_space = reward_space env.reset("cbench-v1/dijkstra") - actions = [env.action_space.flags.index(n) for n in ["-mem2reg", "-simplifycfg"]] + actions = [ + env.action_space["-mem2reg"], + env.action_space["-simplifycfg"], + ] forked = env.fork() try: @@ -245,21 +248,21 @@ def test_fork_previous_cost_reward_update(env: LlvmEnv): env.reward_space = "IrInstructionCount" env.reset("cbench-v1/crc32") - env.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) with env.fork() as fkd: - _, a, _, _ = env.step(env.action_space.flags.index("-mem2reg")) - _, b, _, _ = fkd.step(env.action_space.flags.index("-mem2reg")) + _, a, _, _ = env.step(env.action_space["-mem2reg"]) + _, b, _, _ = fkd.step(env.action_space["-mem2reg"]) assert a == b def test_fork_previous_cost_lazy_reward_update(env: LlvmEnv): env.reset("cbench-v1/crc32") - env.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) env.reward["IrInstructionCount"] # noqa with env.fork() as fkd: - env.step(env.action_space.flags.index("-mem2reg")) - fkd.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) + fkd.step(env.action_space["-mem2reg"]) assert env.reward["IrInstructionCount"] == fkd.reward["IrInstructionCount"] diff --git a/tests/llvm/llvm_env_test.py b/tests/llvm/llvm_env_test.py index 59f0365ef..3fbf0bcec 100644 --- a/tests/llvm/llvm_env_test.py +++ b/tests/llvm/llvm_env_test.py @@ -146,7 +146,7 @@ def test_apply_state(env: LlvmEnv): """Test that apply() on a clean environment produces same state.""" env.reward_space = "IrInstructionCount" env.reset(benchmark="cbench-v1/crc32") - env.step(env.action_space.flags.index("-mem2reg")) + env.step(env.action_space["-mem2reg"]) with gym.make("llvm-v0", reward_space="IrInstructionCount") as other: other.apply(env.state) @@ -176,7 +176,7 @@ def test_same_reward_after_reset(env: LlvmEnv): env.reward_space = "IrInstructionCount" env.benchmark = "cbench-v1/dijkstra" - action = env.action_space.flags.index("-instcombine") + action = env.action_space["-instcombine"] env.reset() _, reward_a, _, _ = env.step(action) @@ -203,7 +203,7 @@ def test_ir_sha1(env: LlvmEnv, tmpwd: Path): env.reset(benchmark="cbench-v1/crc32") before = env.ir_sha1 - _, _, done, info = env.step(env.action_space.flags.index("-mem2reg")) + _, _, done, info = env.step(env.action_space["-mem2reg"]) assert not done, info assert not info["action_had_no_effect"], "sanity check failed, action had no effect" @@ -220,8 +220,8 @@ def test_step_multiple_actions_list(env: LlvmEnv): """Pass a list of actions to step().""" env.reset(benchmark="cbench-v1/crc32") actions = [ - env.action_space.flags.index("-mem2reg"), - env.action_space.flags.index("-reg2mem"), + env.action_space["-mem2reg"], + env.action_space["-reg2mem"], ] _, _, done, _ = env.multistep(actions) assert not done @@ -232,14 +232,14 @@ def test_step_multiple_actions_generator(env: LlvmEnv): """Pass an iterable of actions to step().""" env.reset(benchmark="cbench-v1/crc32") actions = ( - env.action_space.flags.index("-mem2reg"), - env.action_space.flags.index("-reg2mem"), + env.action_space["-mem2reg"], + env.action_space["-reg2mem"], ) _, _, done, _ = env.multistep(actions) assert not done assert env.actions == [ - env.action_space.flags.index("-mem2reg"), - env.action_space.flags.index("-reg2mem"), + env.action_space["-mem2reg"], + env.action_space["-reg2mem"], ] diff --git a/tests/llvm/reward_spaces_test.py b/tests/llvm/reward_spaces_test.py index e9214500b..27a7db627 100644 --- a/tests/llvm/reward_spaces_test.py +++ b/tests/llvm/reward_spaces_test.py @@ -23,7 +23,7 @@ def test_instruction_count_reward(env: LlvmEnv): env.reset(benchmark="cbench-v1/crc32") assert env.observation.IrInstructionCount() == CRC32_INSTRUCTION_COUNT - action = env.action_space.flags.index("-reg2mem") + action = env.action_space["-reg2mem"] env.step(action) assert env.observation.IrInstructionCount() == CRC32_INSTRUCTION_COUNT_AFTER_REG2MEM From eeef55fda6fe7f40f2bfd5ccc97dfcc6a2b918fc Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 12:01:29 +0000 Subject: [PATCH 17/41] [tests] Increase timeout on validation test. --- tests/llvm/datasets/cbench_validate_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/llvm/datasets/cbench_validate_test.py b/tests/llvm/datasets/cbench_validate_test.py index 92ce417ef..6b5b974d1 100644 --- a/tests/llvm/datasets/cbench_validate_test.py +++ b/tests/llvm/datasets/cbench_validate_test.py @@ -12,7 +12,7 @@ pytest_plugins = ["tests.pytest_plugins.llvm"] -@pytest.mark.timeout(600) +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str): """Run the validation routine on all benchmarks.""" env.reward_space = "IrInstructionCount" From 91c38ef56126adb436b02534cab09768acd492a0 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 12:15:40 +0000 Subject: [PATCH 18/41] Documentation improvements. --- compiler_gym/service/connection.py | 50 +++++++++++++++---------- compiler_gym/service/connection_pool.py | 12 +++--- compiler_gym/util/flags/README.md | 4 +- docs/source/compiler_gym/service.rst | 8 ++-- 4 files changed, 42 insertions(+), 32 deletions(-) diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index 2e000f945..2a74f1d98 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -42,8 +42,8 @@ # Spurious error UNAVAILABLE "Trying to connect an http1.x server". # https://putridparrot.com/blog/the-unavailable-trying-to-connect-an-http1-x-server-grpc-error/ ("grpc.enable_http_proxy", 0), - # Disable TCP port re-use to mitigate port conflict errors when starting - # many services in parallel. Context: + # Disable TCP port reuse to mitigate port conflict errors when starting many + # services in parallel. Context: # https://github.com/facebookresearch/CompilerGym/issues/572 ("grpc.so_reuseport", 0), ] @@ -100,9 +100,9 @@ class ConnectionOpts(HashableBaseModel): always_send_benchmark_on_reset: bool = False """Send the full benchmark program data to the compiler service on ever call to :meth:`env.reset() `. This is more - efficient in cases where the majority of calls to - :meth:`env.reset() ` uses a different - benchmark. In case of benchmark re-use, leave this :code:`False`. + efficient in cases where the majority of calls to :meth:`env.reset() + ` uses a different benchmark. In case + of benchmark reuse, leave this :code:`False`. """ script_args: FrozenSet[str] = frozenset([]) @@ -595,18 +595,19 @@ class CompilerGymServiceConnection: """A connection to a compiler gym service. There are two types of service connections: managed and unmanaged. The type - of connection is determined by the endpoint. If a "host:port" URL is provided, - an unmanaged connection is created. If the path of a file is provided, a - managed connection is used. The difference between a managed and unmanaged - connection is that with a managed connection, the lifecycle of the service - if controlled by the client connection. That is, when a managed connection - is created, a service subprocess is started by executing the specified path. - When the connection is closed, the subprocess is terminated. With an - unmanaged connection, if the service fails is goes offline, the client will - fail. - - This class provides a common abstraction between the two types of connection, - and provides a call method for invoking remote procedures on the service. + of connection is determined by the endpoint. If a "host:port" URL is + provided, an unmanaged connection is created. If the path of a file is + provided, a managed connection is used. The difference between a managed and + unmanaged connection is that with a managed connection, the lifecycle of the + service if controlled by the client connection. That is, when a managed + connection is created, a service subprocess is started by executing the + specified path. When the connection is closed, the subprocess is terminated. + With an unmanaged connection, if the service fails is goes offline, the + client will fail. + + This class provides a common abstraction between the two types of + connection, and provides a call method for invoking remote procedures on the + service. Example usage of an unmanaged service connection: @@ -635,7 +636,9 @@ class CompilerGymServiceConnection: :ivar stub: A CompilerGymServiceStub that can be used as the first argument to :py:meth:`__call__()` to specify an RPC method to call. + :ivar action_spaces: A list of action spaces provided by the service. + :ivar observation_spaces: A list of observation spaces provided by the service. """ @@ -648,6 +651,13 @@ def __init__( ): """Constructor. + .. note:: + + Starting new services is expensive. Consider using the + :class:`ServiceConnectionPool + ` class to manage + services rather than constructing them yourself. + :param endpoint: The connection endpoint. Either the URL of a service, e.g. "localhost:8080", or the path of a local service binary. @@ -780,7 +790,7 @@ def release(self) -> None: def shutdown(self): """Shut down the connection. - Once a connection has been shutdown, it cannot be re-used. + Once a connection has been shutdown, it cannot be reused. """ if self.closed: return @@ -802,8 +812,8 @@ def close(self): """Mark this connection as closed. If the service is managed by a :class:`ServiceConnectionPool - `, this will indicate to - the pool that the connection is safe to re-use. If the service is not + `, this will indicate to the + pool that the connection is ready to be reused. If the service is not managed by a pool, this will shut it down. """ if self.owned_by_service_pool: diff --git a/compiler_gym/service/connection_pool.py b/compiler_gym/service/connection_pool.py index 4722e0a28..494d94eff 100644 --- a/compiler_gym/service/connection_pool.py +++ b/compiler_gym/service/connection_pool.py @@ -27,9 +27,8 @@ class ServiceConnectionPool: There is a global instance of this class, available via the static :meth:`ServiceConnectionPool.get() - ` method. - - To use the pool, acquire a reference to the global instance, and call the + ` method. To use the pool, + acquire a reference to the global instance, and call the :meth:`ServiceConnectionPool.acquire() ` method to construct and return service connections: @@ -39,13 +38,14 @@ class ServiceConnectionPool: ... # Do something with the service. When a service is closed (by calling :meth:`service.close() - `), it is + `), it is automatically released back to the pool so that a future request for the same type of service will reuse the connection. :ivar pool: A pool of service connections that are ready for use. - :vartype pool: Dict[ServiceConnectionCacheKey, List[CompilerGymServiceConnection]] + :vartype pool: Dict[Tuple[Path, ConnectionOpts], + List[CompilerGymServiceConnection]] :ivar allocated: The set of service connections that are currently in use. @@ -104,7 +104,7 @@ def release(self, service: CompilerGymServiceConnection) -> None: .. note:: This method is called automatically by the :meth:`service.close() - ` method of + ` method of acquired service connections. You do not have to call this method yourself. """ diff --git a/compiler_gym/util/flags/README.md b/compiler_gym/util/flags/README.md index 38180ffd9..5751b87dd 100644 --- a/compiler_gym/util/flags/README.md +++ b/compiler_gym/util/flags/README.md @@ -2,8 +2,8 @@ This directory contains modules that define command line flags for use by `compiler_gym.bin` and other scripts. The reason for defining flags here is to -allow flag names to be re-used across scripts without causing -multiple-definition errors when the scripts are imported. +allow flag names to be reused across scripts without causing multiple-definition +errors when the scripts are imported. Using these flags requires that the absl flags library is initialized. As such they should not be used in the core library. diff --git a/docs/source/compiler_gym/service.rst b/docs/source/compiler_gym/service.rst index 1499a2877..5e0033b4d 100644 --- a/docs/source/compiler_gym/service.rst +++ b/docs/source/compiler_gym/service.rst @@ -38,7 +38,7 @@ ClientServiceCompilerEnv .. automethod:: __init__ -The connection object +The Connection Object --------------------- .. autoclass:: CompilerGymServiceConnection @@ -48,7 +48,7 @@ The connection object .. automethod:: __call__ -Configuring the connection +Configuring the Connection -------------------------- The :class:`ConnectionOpts ` object is used @@ -58,8 +58,8 @@ to configure the options used for managing a service connection. :members: -The connection pool -------------------- +Re-using Connections +-------------------- .. autoclass:: ServiceConnectionPool :members: From 623a23d6a215394f2f4d2f0a84b1187c81058595 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 12:17:19 +0000 Subject: [PATCH 19/41] [tests] Add test case for forked environment scope. --- tests/service/connection_pool_test.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/service/connection_pool_test.py b/tests/service/connection_pool_test.py index c50fdcdfb..62ee5d2f6 100644 --- a/tests/service/connection_pool_test.py +++ b/tests/service/connection_pool_test.py @@ -174,9 +174,14 @@ def test_service_pool_forked_service_dies(pool: ServiceConnectionPool): assert fkd.service in pool -# TODO: Test case where forked environment kills the service. - -# TODO: Service pool connection does not interfere with pool. +def test_service_pool_forked_environment_ends_scope(pool: ServiceConnectionPool): + """Test that the original service does not close when the forked environment + goes out of scope.""" + with compiler_gym.make("llvm-v0", service_pool=pool) as env: + with env.fork() as fkd: + assert env.service == fkd.service + assert not env.service.closed + assert not env.service.closed if __name__ == "__main__": From 35cbd5b2c9e505495aaf3962b342ebe878597654 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 12:31:46 +0000 Subject: [PATCH 20/41] [benchmarks] Merge init benchmarks and add GCC + loop_tool. This adds gcc-v0 and the loop_tool-v0 environments to the gym.make(...) benchmark, and removes the benchmark for environment initialization from an existing service, as that is now the default behavior when using ServiceConnectionPools. --- benchmarks/bench_test.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmarks/bench_test.py b/benchmarks/bench_test.py index 005581298..0cbd58e5e 100644 --- a/benchmarks/bench_test.py +++ b/benchmarks/bench_test.py @@ -25,6 +25,8 @@ from compiler_gym.service import CompilerGymServiceConnection from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv from compiler_gym.service import CompilerGymServiceConnection, ConnectionOpts +import examples.example_compiler_gym_service # noqa Environment import. +from compiler_gym.envs import CompilerEnv from tests.pytest_plugins.llvm import OBSERVATION_SPACE_NAMES, REWARD_SPACE_NAMES from tests.test_main import main @@ -47,10 +49,10 @@ def env(request) -> CompilerEnv: @pytest.mark.parametrize( "env_id", - ["llvm-v0", "example-cc-v0", "example-py-v0"], - ids=["llvm", "dummy-cc", "dummy-py"], + ["llvm-v0", "example-cc-v0", "example-py-v0", "loop_tool-v0"], + ids=["llvm", "dummy-cc", "dummy-py", "loop_tool"], ) -def test_make_local(benchmark, env_id): +def test_make_env(benchmark, env_id): benchmark(lambda: gym.make(env_id).close()) From 826b5c4f947d0a60211baa8d161d7a5aeec5113e Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 15:20:33 +0000 Subject: [PATCH 21/41] [tests] Increase timeouts for validation tests. --- tests/llvm/validate_test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/llvm/validate_test.py b/tests/llvm/validate_test.py index dc65b34e5..526c4cfd7 100644 --- a/tests/llvm/validate_test.py +++ b/tests/llvm/validate_test.py @@ -16,6 +16,7 @@ pytest_plugins = ["tests.pytest_plugins.llvm"] +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_state_no_reward(): state = CompilerEnvState( benchmark="benchmark://cbench-v1/crc32", @@ -30,6 +31,7 @@ def test_validate_state_no_reward(): assert str(result) == "✅ cbench-v1/crc32" +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_state_with_reward(): state = CompilerEnvState( benchmark="benchmark://cbench-v1/crc32", @@ -46,6 +48,7 @@ def test_validate_state_with_reward(): assert str(result) == "✅ cbench-v1/crc32 0.0000" +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_state_invalid_reward(): state = CompilerEnvState( benchmark="benchmark://cbench-v1/crc32", @@ -64,6 +67,7 @@ def test_validate_state_invalid_reward(): ) +@pytest.mark.timeout(900) # Validation can take a long time! def test_validate_state_without_state_reward(): """Validating state when state has no reward value.""" state = CompilerEnvState( @@ -102,6 +106,7 @@ def test_validate_state_without_env_reward(): assert not result.reward_validation_failed +@pytest.mark.timeout(900) # Validation can take a long time! def test_no_validation_callback_for_custom_benchmark(env: LlvmEnv): """Test that a custom benchmark has no validation callback.""" with tempfile.TemporaryDirectory() as d: From a7207a6e7d95392e31520b09ae030fde21bd3ca8 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 3 Mar 2022 12:42:32 +0000 Subject: [PATCH 22/41] [gcc] Add post-install steps to docker install instructions. --- docs/source/envs/gcc.rst | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/envs/gcc.rst b/docs/source/envs/gcc.rst index 93efd42cb..45dd8bf97 100644 --- a/docs/source/envs/gcc.rst +++ b/docs/source/envs/gcc.rst @@ -42,9 +42,11 @@ On Linux, install Docker using: "deb [arch=amd64 signed-by=/usr/share/keyrings/docker-archive-keyring.gpg] https://download.docker.com/linux/ubuntu \ $(lsb_release -cs) stable" | sudo tee /etc/apt/sources.list.d/docker.list > /dev/null sudo apt-get update && sudo apt-get install docker-ce docker-ce-cli containerd.io + sudo usermod -aG docker $USER + su - $USER See the `official documentation `_ for -alternative installation options. +more details and alternative installation options. On both Linux and macOS, use the following command to check if Docker is working: From de7db628b02f3b19126bd5d2f04196a5f3f73ccb Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Mon, 7 Mar 2022 16:25:18 +0000 Subject: [PATCH 23/41] [gcc] Move the thread lock on constructor into core library. --- compiler_gym/envs/gcc/__init__.py | 2 +- compiler_gym/envs/gcc/gcc_env.py | 17 +++++++++++++++++ examples/gcc_autotuning/tune.py | 12 +----------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/compiler_gym/envs/gcc/__init__.py b/compiler_gym/envs/gcc/__init__.py index a17c4f95b..c55fd8ccf 100644 --- a/compiler_gym/envs/gcc/__init__.py +++ b/compiler_gym/envs/gcc/__init__.py @@ -16,7 +16,7 @@ register( id="gcc-v0", - entry_point="compiler_gym.envs.gcc:GccEnv", + entry_point="compiler_gym.envs.gcc.gcc_env:make", kwargs={"service": GCC_SERVICE_BINARY}, ) diff --git a/compiler_gym/envs/gcc/gcc_env.py b/compiler_gym/envs/gcc/gcc_env.py index 77bccfa93..85adec318 100644 --- a/compiler_gym/envs/gcc/gcc_env.py +++ b/compiler_gym/envs/gcc/gcc_env.py @@ -7,6 +7,7 @@ import json import pickle from pathlib import Path +from threading import Lock from typing import Any, Dict, List, Optional, Union from compiler_gym.datasets import Benchmark @@ -218,3 +219,19 @@ def _init_kwargs(self) -> Dict[str, Any]: "gcc_bin": self.gcc_spec.gcc.bin, **super()._init_kwargs(), } + + +_GCC_ENV_DOCKER_CONSTRUCTOR_LOCK = Lock() + + +def make(*args, gcc_bin: Union[str, Path] = DEFAULT_GCC, **kwargs): + """Construct a GccEnv class using a lock to ensure thread exclusivity. + + This is to prevent multiple threads running the docker initialization + routines simultaneously as this can cause issues with the docker API. + """ + if gcc_bin.startswith("docker:"): + with _GCC_ENV_DOCKER_CONSTRUCTOR_LOCK: + return GccEnv(*args, gcc_bin=gcc_bin, **kwargs) + else: + return GccEnv(*args, gcc_bin=gcc_bin, **kwargs) diff --git a/examples/gcc_autotuning/tune.py b/examples/gcc_autotuning/tune.py index 5a702cbdc..efa90f05c 100644 --- a/examples/gcc_autotuning/tune.py +++ b/examples/gcc_autotuning/tune.py @@ -5,7 +5,6 @@ """Autotuning script for GCC command line options.""" import random from itertools import islice, product -from multiprocessing import Lock from pathlib import Path from typing import NamedTuple @@ -64,10 +63,6 @@ "objective", "obj_size", ["asm_size", "obj_size"], "Which objective to use" ) -# Lock to prevent multiple processes all calling compiler_gym.make("gcc-v0") -# simultaneously as this can cause issues with the docker API. -GCC_ENV_CONSTRUCTOR_LOCK = Lock() - def random_search(env: CompilerEnv): best = float("inf") @@ -160,10 +155,7 @@ def scaled_best(self) -> float: def run_search(search: str, benchmark: str, seed: int) -> SearchResult: """Run a search and return the search class instance.""" - with GCC_ENV_CONSTRUCTOR_LOCK: - env = compiler_gym.make("gcc-v0", gcc_bin=FLAGS.gcc_bin) - - try: + with compiler_gym.make("gcc-v0", gcc_bin=FLAGS.gcc_bin) as env: random.seed(seed) np.random.seed(seed) @@ -172,8 +164,6 @@ def run_search(search: str, benchmark: str, seed: int) -> SearchResult: baseline_size = objective(env) env.reset(benchmark=benchmark) best_size = _SEARCH_FUNCTIONS[search](env) - finally: - env.close() return SearchResult( search=search, From adad70c911d1757d8113d4d2c8446aaca506f23a Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Mon, 7 Mar 2022 17:56:45 +0000 Subject: [PATCH 24/41] [service] Add a base class for the connection pool. Add a base class for the ServiceConnectionPool that provides the same interface but does no caching. --- .../service/client_service_compiler_env.py | 4 ++-- compiler_gym/service/connection_pool.py | 18 ++++++++++++++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index a634e56c5..9bddf1eb1 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -223,7 +223,7 @@ def __init__( self._connection_settings = connection_settings or ConnectionOpts() if service_connection is None: - self._service_pool = ( + self._service_pool: Optional[ServiceConnectionPoolBase] = ( ServiceConnectionPool.get() if service_pool is None else service_pool ) self.service = self._service_pool.acquire( @@ -231,7 +231,7 @@ def __init__( opts=self._connection_settings, ) else: - self._service_pool = service_pool + self._service_pool: Optional[ServiceConnectionPoolBase] = service_pool self.service = service_connection self.datasets = Datasets(datasets or []) diff --git a/compiler_gym/service/connection_pool.py b/compiler_gym/service/connection_pool.py index 494d94eff..6630f191d 100644 --- a/compiler_gym/service/connection_pool.py +++ b/compiler_gym/service/connection_pool.py @@ -18,7 +18,21 @@ ServiceConnectionCacheKey = Tuple[Path, ConnectionOpts] -class ServiceConnectionPool: +class ServiceConnectionPoolBase: + """A class that provides the base interface for service connection pools.""" + + def acquire( + self, endpoint: Path, opts: ConnectionOpts + ) -> CompilerGymServiceConnection: + return CompilerGymServiceConnection( + endpoint=endpoint, opts=opts, owning_service_pool=self + ) + + def release(self, service: CompilerGymServiceConnection) -> None: + pass + + +class ServiceConnectionPool(ServiceConnectionPoolBase): """An object pool for compiler service connections. This class implements a thread-safe pool for compiler service connections. @@ -52,7 +66,7 @@ class ServiceConnectionPool: :vartype allocated: Set[CompilerGymServiceConnection] """ - def __init__(self) -> None: + def __init__(self): """""" self._lock = Lock() self.pool: Dict[ From b1781ea9d1e531067011cf3c999032764ff4a020 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Mon, 7 Mar 2022 17:57:13 +0000 Subject: [PATCH 25/41] [service] Tweak logging messages. --- compiler_gym/service/connection_pool.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/compiler_gym/service/connection_pool.py b/compiler_gym/service/connection_pool.py index 6630f191d..37d9ce403 100644 --- a/compiler_gym/service/connection_pool.py +++ b/compiler_gym/service/connection_pool.py @@ -131,10 +131,7 @@ def release(self, service: CompilerGymServiceConnection) -> None: return if service not in self.allocated: - logger.debug( - "Ignoring attempt to release connection " - "that does not belong to pool" - ) + logger.debug("Discarding service that does not belong to pool") return self.allocated.remove(service) @@ -143,11 +140,11 @@ def release(self, service: CompilerGymServiceConnection) -> None: if hasattr(service.connection, "process"): # A dead service cannot be reused, discard it. if service.closed or service.connection.process.poll() is not None: - logger.debug("Ignoring attempt to release dead connection") + logger.debug("Discarding service with dead process") return # A service that has been shutdown cannot be reused, discard it. if not service.connection: - logger.debug("Ignoring attempt to service without connection") + logger.debug("Discarding service that has no connection") return self.pool[key].append(service) @@ -209,7 +206,7 @@ def get() -> "ServiceConnectionPool": return _SERVICE_CONNECTION_POOL def __repr__(self) -> str: - return f"ServiceConnectionPool(size={self.size})" + return f"{type(self).__name__}(size={self.size})" _SERVICE_CONNECTION_POOL = ServiceConnectionPool() From 9ec10e1da8d1ff5d7b3d494d648e60f4d0ee1ba1 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Mon, 7 Mar 2022 17:57:46 +0000 Subject: [PATCH 26/41] [gcc] Disable service connection pool for GCC. Issue #583. --- compiler_gym/envs/gcc/gcc_env.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/compiler_gym/envs/gcc/gcc_env.py b/compiler_gym/envs/gcc/gcc_env.py index 85adec318..3ca0b6246 100644 --- a/compiler_gym/envs/gcc/gcc_env.py +++ b/compiler_gym/envs/gcc/gcc_env.py @@ -17,6 +17,7 @@ from compiler_gym.service import ConnectionOpts from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv from compiler_gym.spaces import Reward +from compiler_gym.service.connection_pool import ServiceConnectionPoolBase from compiler_gym.util.decorators import memoized_property from compiler_gym.util.gym_type_hints import ObservationType, OptionalArgumentValue from compiler_gym.views import ObservationSpaceSpec @@ -79,6 +80,13 @@ def __init__( # initialization may time out. Gcc(bin=gcc_bin) + # NOTE(github.com/facebookresearch/CompilerGym/pull/583): The GCC + # environment stalls on the StartSession() RPC call when service + # connection caching is enabled. I believe this has something to do with + # the runtime code generation, but have not been able to diagnose it + # yet. For now, disable service connection caching for GCC environments. + kwargs["service_pool"] = ServiceConnectionPoolBase() + super().__init__( *args, **kwargs, From 8c8ba2f73ed4f3c8e8e1d149c453f4e4aa4d838f Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Mon, 7 Mar 2022 18:00:58 +0000 Subject: [PATCH 27/41] [tests] Better GCC fixture enumeration. --- examples/gcc_autotuning/tune_test.py | 28 ++++++++++---------- tests/pytest_plugins/gcc.py | 39 ++++++++++++++++------------ 2 files changed, 38 insertions(+), 29 deletions(-) diff --git a/examples/gcc_autotuning/tune_test.py b/examples/gcc_autotuning/tune_test.py index d6a00a480..1fd9900ce 100644 --- a/examples/gcc_autotuning/tune_test.py +++ b/examples/gcc_autotuning/tune_test.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import shutil import subprocess import sys from functools import lru_cache @@ -24,33 +25,34 @@ def docker_is_available() -> bool: return False -@lru_cache(maxsize=2) -def system_gcc_is_available() -> bool: +def system_has_functional_gcc(gcc_path: str) -> bool: """Return whether there is a system GCC available.""" try: stdout = subprocess.check_output( - ["gcc", "--version"], universal_newlines=True, stderr=subprocess.DEVNULL + [gcc_path, "--version"], + universal_newlines=True, + stderr=subprocess.DEVNULL, + timeout=30, ) # On some systems "gcc" may alias to a different compiler, so check for # the presence of the name "gcc" in the first line of output. return "gcc" in stdout.split("\n")[0].lower() - except (subprocess.CalledProcessError, FileNotFoundError): + except ( + subprocess.CalledProcessError, + FileNotFoundError, + subprocess.TimeoutExpired, + ): return False -def system_gcc_path() -> str: - """Return the path of the system GCC as a string.""" - return subprocess.check_output( - ["which", "gcc"], universal_newlines=True, stderr=subprocess.DEVNULL - ).strip() - - +@lru_cache def gcc_bins() -> Iterable[str]: """Return a list of available GCCs.""" if docker_is_available(): yield "docker:gcc:11.2.0" - if system_gcc_is_available(): - yield system_gcc_path() + system_gcc = shutil.which("gcc") + if system_gcc and system_has_functional_gcc(system_gcc): + yield system_gcc @pytest.fixture(scope="module", params=gcc_bins()) diff --git a/tests/pytest_plugins/gcc.py b/tests/pytest_plugins/gcc.py index 730364bd7..e419edfc8 100644 --- a/tests/pytest_plugins/gcc.py +++ b/tests/pytest_plugins/gcc.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. """Pytest fixtures for the GCC CompilerGym environments.""" +import shutil import subprocess from functools import lru_cache from typing import Iterable @@ -13,38 +14,44 @@ from tests.pytest_plugins.common import docker_is_available -@lru_cache(maxsize=2) -def system_gcc_is_available() -> bool: +def system_has_functional_gcc(gcc_path: str) -> bool: """Return whether there is a system GCC available.""" try: stdout = subprocess.check_output( - ["gcc", "--version"], universal_newlines=True, stderr=subprocess.DEVNULL + [gcc_path, "--version"], + universal_newlines=True, + stderr=subprocess.DEVNULL, + timeout=30, ) # On some systems "gcc" may alias to a different compiler, so check for # the presence of the name "gcc" in the first line of output. return "gcc" in stdout.split("\n")[0].lower() - except (subprocess.CalledProcessError, FileNotFoundError): + except ( + subprocess.CalledProcessError, + FileNotFoundError, + subprocess.TimeoutExpired, + ): return False -def system_gcc_path() -> str: - """Return the path of the system GCC as a string.""" - return subprocess.check_output( - ["which", "gcc"], universal_newlines=True, stderr=subprocess.DEVNULL - ).strip() - - -def gcc_environment_is_supported() -> bool: - """Return whether the requirements for the GCC environment are met.""" - return docker_is_available() or system_gcc_is_available() +@lru_cache +def system_gcc_is_available(): + return system_has_functional_gcc(shutil.which("gcc")) +@lru_cache def gcc_bins() -> Iterable[str]: """Return a list of available GCCs.""" if docker_is_available(): yield "docker:gcc:11.2.0" - if system_gcc_is_available(): - yield system_gcc_path() + system_gcc = shutil.which("gcc") + if system_gcc and system_has_functional_gcc(system_gcc): + yield system_gcc + + +def gcc_environment_is_supported() -> bool: + """Return whether the requirements for the GCC environment are met.""" + return len(list(gcc_bins())) > 0 @pytest.fixture(scope="module", params=gcc_bins()) From 9c6e9fe16b870a94de8cfe340a098897a0e0cc1f Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Mon, 7 Mar 2022 18:02:38 +0000 Subject: [PATCH 28/41] [examples] Fix missing return code. --- examples/gcc_autotuning/info.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/gcc_autotuning/info.py b/examples/gcc_autotuning/info.py index 9b56b94a4..72ebc38c5 100644 --- a/examples/gcc_autotuning/info.py +++ b/examples/gcc_autotuning/info.py @@ -52,6 +52,7 @@ def info( if not dfs: print("No results") + return df = pd.concat(dfs) df = df.groupby(["timestamp", "search"])[["scaled_size"]].agg(geometric_mean) From 09188f718db4f5430118668bd6daa0850fb35209 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Mon, 7 Mar 2022 19:32:19 +0000 Subject: [PATCH 29/41] [service] Defend against logging error on close(). --- compiler_gym/service/connection_pool.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/compiler_gym/service/connection_pool.py b/compiler_gym/service/connection_pool.py index 37d9ce403..96ac5642b 100644 --- a/compiler_gym/service/connection_pool.py +++ b/compiler_gym/service/connection_pool.py @@ -174,11 +174,18 @@ def close(self) -> None: if self.closed: return - logger.debug( - "Closing the service connection pool with %d cached and %d live connections", - self.size, - len(self.allocated), - ) + try: + logger.debug( + "Closing the service connection pool with %d cached and %d live connections", + self.size, + len(self.allocated), + ) + except ValueError: + # As this method is invoked by the atexit callback, the logger + # may already have closed its streams, in which case a + # ValueError is raised. + pass + for connections in self.pool.values(): for connection in connections: connection.shutdown() From aedcb0006231805767fdb234db675dd94d99a922 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 30/41] Rebase fixes for service pool PR. B --- compiler_gym/envs/compiler_env.py | 238 ------------------ compiler_gym/service/BUILD | 3 + .../service/client_service_compiler_env.py | 73 ++++-- 3 files changed, 52 insertions(+), 262 deletions(-) diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index a243d6afa..e6a48b284 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -98,244 +98,6 @@ def observation_space_spec(self) -> ObservationSpaceSpec: def observation_space_spec( self, observation_space_spec: Optional[ObservationSpaceSpec] ): -<<<<<<< HEAD -======= - """Construct and initialize a CompilerGym environment. - - In normal use you should use :code:`gym.make(...)` rather than calling - the constructor directly. - - :param service: The hostname and port of a service that implements the - CompilerGym service interface, or the path of a binary file which - provides the CompilerGym service interface when executed. See - :doc:`/compiler_gym/service` for details. - - :param rewards: The reward spaces that this environment supports. - Rewards are typically calculated based on observations generated by - the service. See :class:`Reward ` for - details. - - :param benchmark: The benchmark to use for this environment. Either a - URI string, or a :class:`Benchmark - ` instance. If not provided, the - first benchmark as returned by - :code:`next(env.datasets.benchmarks())` will be used as the default. - - :param observation_space: Compute and return observations at each - :func:`step()` from this space. Accepts a string name or an - :class:`ObservationSpaceSpec - `. If not provided, - :func:`step()` returns :code:`None` for the observation value. Can - be set later using :meth:`env.observation_space - `. For available - spaces, see :class:`env.observation.spaces - `. - - :param reward_space: Compute and return reward at each :func:`step()` - from this space. Accepts a string name or a :class:`Reward - `. If not provided, :func:`step()` - returns :code:`None` for the reward value. Can be set later using - :meth:`env.reward_space - `. For available spaces, - see :class:`env.reward.spaces `. - - :param action_space: The name of the action space to use. If not - specified, the default action space for this compiler is used. - - :param derived_observation_spaces: An optional list of arguments to be - passed to :meth:`env.observation.add_derived_space() - `. - - :param connection_settings: The settings used to establish a connection - with the remote service. - - :param service_connection: An existing compiler gym service connection - to use. - - :param service_pool: A service pool to use for acquiring a service - connection. If not specified, the :meth:`global service pool - ` is used. - - :raises FileNotFoundError: If service is a path to a file that is not - found. - - :raises TimeoutError: If the compiler service fails to initialize within - the parameters provided in :code:`connection_settings`. - """ - # NOTE(cummins): Logger argument deprecated and scheduled to be removed - # in release 0.2.3. - if logger: - warnings.warn( - "The `logger` argument is deprecated on CompilerEnv.__init__() " - "and will be removed in a future release. All CompilerEnv " - "instances share a logger named compiler_gym.envs.compiler_env", - DeprecationWarning, - ) - - self.metadata = {"render.modes": ["human", "ansi"]} - - # A compiler service supports multiple simultaneous environments. This - # session ID is used to identify this environment. - self._session_id: Optional[int] = None - - self._service_endpoint: Union[str, Path] = service - self._connection_settings = connection_settings or ConnectionOpts() - - if service_connection is None: - self._service_pool: Optional[ServiceConnectionPool] = ( - ServiceConnectionPool.get() if service_pool is None else service_pool - ) - self.service = self._service_pool.acquire( - endpoint=self._service_endpoint, - opts=self._connection_settings, - ) - else: - self._service_pool: Optional[ServiceConnectionPool] = service_pool - self.service = service_connection - - self.datasets = Datasets(datasets or []) - - self.action_space_name = action_space - - # If no reward space is specified, generate some from numeric observation spaces - rewards = rewards or [ - DefaultRewardFromObservation(obs.name) - for obs in self.service.observation_spaces - if obs.default_observation.WhichOneof("value") - and isinstance( - getattr( - obs.default_observation, obs.default_observation.WhichOneof("value") - ), - numbers.Number, - ) - ] - - # The benchmark that is currently being used, and the benchmark that - # will be used on the next call to reset(). These are equal except in - # the gap between the user setting the env.benchmark property while in - # an episode and the next call to env.reset(). - self._benchmark_in_use: Optional[Benchmark] = None - self._benchmark_in_use_proto: BenchmarkProto = BenchmarkProto() - self._next_benchmark: Optional[Benchmark] = None - # Normally when the benchmark is changed the updated value is not - # reflected until the next call to reset(). We make an exception for the - # constructor-time benchmark as otherwise the behavior of the benchmark - # property is counter-intuitive: - # - # >>> env = gym.make("example-v0", benchmark="foo") - # >>> env.benchmark - # None - # >>> env.reset() - # >>> env.benchmark - # "foo" - # - # By forcing the _benchmark_in_use URI at constructor time, the first - # env.benchmark above returns the benchmark as expected. - try: - self.benchmark = benchmark or next(self.datasets.benchmarks()) - self._benchmark_in_use = self._next_benchmark - except StopIteration: - # StopIteration raised on next(self.datasets.benchmarks()) if there - # are no benchmarks available. This is to allow CompilerEnv to be - # used without any datasets by setting a benchmark before/during the - # first reset() call. - pass - - # Process the available action, observation, and reward spaces. - self.action_spaces = [ - proto_to_action_space(space) for space in self.service.action_spaces - ] - - self.observation = self._observation_view_type( - raw_step=self.raw_step, - spaces=self.service.observation_spaces, - ) - self.reward = self._reward_view_type(rewards, self.observation) - - # Register any derived observation spaces now so that the observation - # space can be set below. - for derived_observation_space in derived_observation_spaces or []: - self.observation.add_derived_space_internal(**derived_observation_space) - - # Lazily evaluated version strings. - self._versions: Optional[GetVersionReply] = None - - self.action_space: Optional[Space] = None - self.observation_space: Optional[Space] = None - - # Mutable state initialized in reset(). - self.reward_range: Tuple[float, float] = (-np.inf, np.inf) - self.episode_reward: Optional[float] = None - self.episode_start_time: float = time() - self.actions: List[ActionType] = [] - - # Initialize the default observation/reward spaces. - self.observation_space_spec: Optional[ObservationSpaceSpec] = None - self.reward_space_spec: Optional[Reward] = None - self.observation_space = observation_space - self.reward_space = reward_space - - @property - @deprecated( - version="0.2.1", - reason=( - "The `CompilerEnv.logger` attribute is deprecated. All CompilerEnv " - "instances share a logger named compiler_gym.envs.compiler_env" - ), - ) - def logger(self): - return _logger - - @property - def versions(self) -> GetVersionReply: - """Get the version numbers from the compiler service.""" - if self._versions is None: - self._versions = self.service( - self.service.stub.GetVersion, GetVersionRequest() - ) - return self._versions - - @property - def version(self) -> str: - """The version string of the compiler service.""" - return self.versions.service_version - - @property - def compiler_version(self) -> str: - """The version string of the underlying compiler that this service supports.""" - return self.versions.compiler_version - - def commandline(self) -> str: - """Interface for :class:`CompilerEnv ` - subclasses to provide an equivalent commandline invocation to the - current environment state. - - See also :meth:`commandline_to_actions() - `. - - Calling this method on a :class:`CompilerEnv - ` instance raises - :code:`NotImplementedError`. - - :return: A string commandline invocation. - """ - raise NotImplementedError("abstract method") - - def commandline_to_actions(self, commandline: str) -> List[ActionType]: - """Interface for :class:`CompilerEnv ` - subclasses to convert from a commandline invocation to a sequence of - actions. - - See also :meth:`commandline() - `. - - Calling this method on a :class:`CompilerEnv - ` instance raises - :code:`NotImplementedError`. - - :return: A list of actions. - """ ->>>>>>> 4a874cee (Fix typo in docstring.) raise NotImplementedError("abstract method") @property diff --git a/compiler_gym/service/BUILD b/compiler_gym/service/BUILD index 958624d44..717f528b2 100644 --- a/compiler_gym/service/BUILD +++ b/compiler_gym/service/BUILD @@ -14,6 +14,9 @@ py_library( deps = [ ":compilation_session", ":connection", + # TODO(github.com/facebookresearch/CompilerGym/pull/633): + # add this after circular dependencies are resolved + # ":client_service_compiler_env", ":connection_pool", # TODO(github.com/facebookresearch/CompilerGym/pull/633): # add this after circular dependencies are resolved diff --git a/compiler_gym/service/client_service_compiler_env.py b/compiler_gym/service/client_service_compiler_env.py index 9bddf1eb1..2c96e7a82 100644 --- a/compiler_gym/service/client_service_compiler_env.py +++ b/compiler_gym/service/client_service_compiler_env.py @@ -14,7 +14,6 @@ from time import time from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -from compiler_gym.service.connection_pool import ServiceConnectionPool import numpy as np from deprecated.sphinx import deprecated from gym.spaces import Space @@ -32,6 +31,10 @@ ValidationError, ) from compiler_gym.service import CompilerGymServiceConnection, ConnectionOpts +from compiler_gym.service.connection_pool import ( + ServiceConnectionPool, + ServiceConnectionPoolBase, +) from compiler_gym.service.proto import ActionSpace, AddBenchmarkRequest from compiler_gym.service.proto import Benchmark as BenchmarkProto from compiler_gym.service.proto import ( @@ -136,6 +139,7 @@ def __init__( reward_space: Optional[Union[str, Reward]] = None, action_space: Optional[str] = None, derived_observation_spaces: Optional[List[Dict[str, Any]]] = None, + service_message_converters: ServiceMessageConverters = None, connection_settings: Optional[ConnectionOpts] = None, service_connection: Optional[CompilerGymServiceConnection] = None, service_pool: Optional[ServiceConnectionPool] = None, @@ -187,6 +191,9 @@ def __init__( passed to :meth:`env.observation.add_derived_space() `. + :param service_message_converters: Custom converters for action spaces + and actions. + :param connection_settings: The settings used to establish a connection with the remote service. @@ -207,9 +214,10 @@ def __init__( # in release 0.2.3. if logger: warnings.warn( - "The `logger` argument is deprecated on CompilerEnv.__init__() " - "and will be removed in a future release. All CompilerEnv " - "instances share a logger named compiler_gym.envs.compiler_env", + "The `logger` argument is deprecated on " + "ClientServiceCompilerEnv.__init__() and will be removed in a " + "future release. All ClientServiceCompilerEnv instances share " + "a logger named compiler_gym.envs.compiler_env", DeprecationWarning, ) @@ -234,7 +242,7 @@ def __init__( self._service_pool: Optional[ServiceConnectionPoolBase] = service_pool self.service = service_connection - self.datasets = Datasets(datasets or []) + self._datasets = Datasets(datasets or []) self.action_space_name = action_space @@ -277,14 +285,21 @@ def __init__( self._benchmark_in_use = self._next_benchmark except StopIteration: # StopIteration raised on next(self.datasets.benchmarks()) if there - # are no benchmarks available. This is to allow CompilerEnv to be - # used without any datasets by setting a benchmark before/during the - # first reset() call. + # are no benchmarks available. This is to allow + # ClientServiceCompilerEnv to be used without any datasets by + # setting a benchmark before/during the first reset() call. pass + self.service_message_converters = ( + ServiceMessageConverters() + if service_message_converters is None + else service_message_converters + ) + # Process the available action, observation, and reward spaces. self.action_spaces = [ - proto_to_action_space(space) for space in self.service.action_spaces + self.service_message_converters.action_space_converter(space) + for space in self.service.action_spaces ] self.observation = self._observation_view_type( @@ -308,7 +323,7 @@ def __init__( self._reward_range: Tuple[float, float] = (-np.inf, np.inf) self.episode_reward: Optional[float] = None self.episode_start_time: float = time() - self.actions: List[ActionType] = [] + self._actions: List[ActionType] = [] # Initialize the default observation/reward spaces. self.observation_space_spec: Optional[ObservationSpaceSpec] = None @@ -548,10 +563,11 @@ def _init_kwargs(self) -> Dict[str, Any]: "benchmark": self.benchmark, "connection_settings": self._connection_settings, "service": self._service_endpoint, + "service_pool": self._service_pool, } def fork(self) -> "ClientServiceCompilerEnv": - if not self.in_episode: + if not self.in_episode: actions = self.actions.copy() self.reset() if actions: @@ -607,7 +623,7 @@ def fork(self) -> "ClientServiceCompilerEnv": # Copy over the mutable episode state. new_env.episode_reward = self.episode_reward new_env.episode_start_time = self.episode_start_time - new_env.actions = self.actions.copy() + new_env._actions = self.actions.copy() # pylint: disable=protected-access return new_env @@ -824,11 +840,13 @@ def _call_with_error( self.observation.session_id = reply.session_id self.reward.get_cost = self.observation.__getitem__ self.episode_start_time = time() - self.actions = [] + self._actions: List[ActionType] = [] # If the action space has changed, update it. if reply.HasField("new_action_space"): - self.action_space = proto_to_action_space(reply.new_action_space) + self.action_space = self.service_message_converters.action_space_converter( + reply.new_action_space + ) self.reward.reset(benchmark=self.benchmark, observation_view=self.observation) if self.reward_space: @@ -864,15 +882,17 @@ def raw_step( and rewards are lists. :raises SessionNotFound: If :meth:`reset() - ` has not been called. + ` has not been + called. .. warning:: Don't call this method directly, use :meth:`step() - ` or :meth:`multistep() - ` instead. The - :meth:`raw_step() ` method is an - implementation detail. + ` or + :meth:`multistep() + ` instead. The + :meth:`raw_step() ` + method is an implementation detail. """ if not self.in_episode: raise SessionNotFound("Must call reset() before step()") @@ -892,12 +912,14 @@ def raw_step( } # Record the actions. - self.actions += actions + self._actions += actions # Send the request to the backend service. request = StepRequest( session_id=self._session_id, - action=[Event(int64_value=a) for a in actions], + action=[ + self.service_message_converters.action_converter(a) for a in actions + ], observation_space=[ observation_space.index for observation_space in observations_to_compute ], @@ -941,7 +963,9 @@ def raw_step( # If the action space has changed, update it. if reply.HasField("new_action_space"): - self.action_space = proto_to_action_space(reply.new_action_space) + self.action_space = self.service_message_converters.action_space_converter( + reply.new_action_space + ) # Translate observations to python representations. if len(reply.observation) != len(observations_to_compute): @@ -1266,8 +1290,9 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult def send_param(self, key: str, value: str) -> str: """Send a single parameter to the compiler service. - See :meth:`send_params() ` - for more information. + See :meth:`send_params() + ` for more + information. :param key: The parameter key. From 0674661a4863c5c8865b33ec51c8d3de137b1a46 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Wed, 20 Apr 2022 21:05:57 -0700 Subject: [PATCH 31/41] Run pre-commit formatters on sources. --- benchmarks/bench_test.py | 8 +++----- tests/llvm/service_connection_test.py | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/benchmarks/bench_test.py b/benchmarks/bench_test.py index 0cbd58e5e..373e30491 100644 --- a/benchmarks/bench_test.py +++ b/benchmarks/bench_test.py @@ -20,13 +20,11 @@ import gym import pytest -import examples.example_compiler_gym_service as dummy +import examples.example_compiler_gym_service # noqa Environment import. +import examples.example_compiler_gym_service as dummy # noqa Environment import. from compiler_gym.envs import CompilerEnv, LlvmEnv, llvm -from compiler_gym.service import CompilerGymServiceConnection -from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv from compiler_gym.service import CompilerGymServiceConnection, ConnectionOpts -import examples.example_compiler_gym_service # noqa Environment import. -from compiler_gym.envs import CompilerEnv +from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv from tests.pytest_plugins.llvm import OBSERVATION_SPACE_NAMES, REWARD_SPACE_NAMES from tests.test_main import main diff --git a/tests/llvm/service_connection_test.py b/tests/llvm/service_connection_test.py index 581671253..6a37e40b3 100644 --- a/tests/llvm/service_connection_test.py +++ b/tests/llvm/service_connection_test.py @@ -12,7 +12,6 @@ from compiler_gym.envs.llvm.llvm_env import LlvmEnv from compiler_gym.errors import ServiceError from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv -from compiler_gym.service.connection import CompilerGymServiceConnection from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts from compiler_gym.third_party.autophase import AUTOPHASE_FEATURE_DIM from tests.test_main import main From 712276f7597bb8ce0b73979b6c9587f501aacefb Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 32/41] [tests] Add missing maxsize to lru_cache. --- examples/gcc_autotuning/tune_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/gcc_autotuning/tune_test.py b/examples/gcc_autotuning/tune_test.py index 1fd9900ce..30d592c4b 100644 --- a/examples/gcc_autotuning/tune_test.py +++ b/examples/gcc_autotuning/tune_test.py @@ -45,7 +45,7 @@ def system_has_functional_gcc(gcc_path: str) -> bool: return False -@lru_cache +@lru_cache(maxsize=1) def gcc_bins() -> Iterable[str]: """Return a list of available GCCs.""" if docker_is_available(): From c1f89eb54b2f9401faa44faa61f31ca06e80b054 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 33/41] Fix cmake dependency name. --- tests/llvm/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/llvm/CMakeLists.txt b/tests/llvm/CMakeLists.txt index a64199fee..e47939754 100644 --- a/tests/llvm/CMakeLists.txt +++ b/tests/llvm/CMakeLists.txt @@ -108,7 +108,7 @@ cg_py_test( DATA compiler_gym::third_party::cbench::crc32 DEPS - compiler_gym + compiler_gym::compiler_gym tests::pytest_plugins::llvm tests::test_main ) From 077b3bfba1cbf419d9bd147fd597eca964dfb79f Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 34/41] Add missing cmake dependencies. --- compiler_gym/service/BUILD | 2 +- compiler_gym/service/CMakeLists.txt | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/compiler_gym/service/BUILD b/compiler_gym/service/BUILD index 717f528b2..541200807 100644 --- a/compiler_gym/service/BUILD +++ b/compiler_gym/service/BUILD @@ -15,7 +15,7 @@ py_library( ":compilation_session", ":connection", # TODO(github.com/facebookresearch/CompilerGym/pull/633): - # add this after circular dependencies are resolved + # add this after circular dependencies are resolved: # ":client_service_compiler_env", ":connection_pool", # TODO(github.com/facebookresearch/CompilerGym/pull/633): diff --git a/compiler_gym/service/CMakeLists.txt b/compiler_gym/service/CMakeLists.txt index 3e129fe76..f35b962c4 100644 --- a/compiler_gym/service/CMakeLists.txt +++ b/compiler_gym/service/CMakeLists.txt @@ -14,8 +14,9 @@ cg_py_library( ::compilation_session ::connection # TODO(github.com/facebookresearch/CompilerGym/pull/633): - # add this after circular dependencies are resolved - #::client_service_compiler_env + # add this after circular dependencies are resolved: + # ::client_service_compiler_env + ::connection_pool ::service_cache compiler_gym::errors::errors compiler_gym::service::proto::proto @@ -79,6 +80,15 @@ cg_py_library( PUBLIC ) +cg_py_library( + connection_pool + SRCS + "connection_pool.py" + DEPS + ::connection + PUBLIC +) + cg_py_library( NAME service_cache From 31a94302eb6cadad2947f23423fdcae660e19b6e Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 35/41] Auto-formatted sources. --- compiler_gym/envs/gcc/gcc_env.py | 2 +- tests/service/CMakeLists.txt | 1 + tests/service/connection_pool_test.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/compiler_gym/envs/gcc/gcc_env.py b/compiler_gym/envs/gcc/gcc_env.py index 3ca0b6246..245446659 100644 --- a/compiler_gym/envs/gcc/gcc_env.py +++ b/compiler_gym/envs/gcc/gcc_env.py @@ -16,8 +16,8 @@ from compiler_gym.envs.gcc.gcc_rewards import AsmSizeReward, ObjSizeReward from compiler_gym.service import ConnectionOpts from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv -from compiler_gym.spaces import Reward from compiler_gym.service.connection_pool import ServiceConnectionPoolBase +from compiler_gym.spaces import Reward from compiler_gym.util.decorators import memoized_property from compiler_gym.util.gym_type_hints import ObservationType, OptionalArgumentValue from compiler_gym.views import ObservationSpaceSpec diff --git a/tests/service/CMakeLists.txt b/tests/service/CMakeLists.txt index 6d1b77a5c..22d41faa9 100644 --- a/tests/service/CMakeLists.txt +++ b/tests/service/CMakeLists.txt @@ -15,6 +15,7 @@ if(COMPILER_GYM_ENABLE_LLVM_ENV) compiler_gym::compiler_gym compiler_gym::envs::envs compiler_gym::errors::errors + compiler_gym::service::service compiler_gym::service::service_cache tests::test_main ) diff --git a/tests/service/connection_pool_test.py b/tests/service/connection_pool_test.py index 62ee5d2f6..cb6773d39 100644 --- a/tests/service/connection_pool_test.py +++ b/tests/service/connection_pool_test.py @@ -8,8 +8,8 @@ import compiler_gym from compiler_gym.envs.llvm import LLVM_SERVICE_BINARY -from compiler_gym.service import ConnectionOpts, ServiceConnectionPool from compiler_gym.errors import ServiceError +from compiler_gym.service import ConnectionOpts, ServiceConnectionPool from tests.test_main import main pytest_plugins = ["tests.pytest_plugins.llvm"] From 8735fce0a4405a5ef42ecba83f927f7a2b38f862 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 36/41] Update tests to add ConnectionOpts argument. --- tests/compiler_env_test.py | 4 ++-- tests/mlir/mlir_env_test.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/compiler_env_test.py b/tests/compiler_env_test.py index 0049dbaa9..7d73c157c 100644 --- a/tests/compiler_env_test.py +++ b/tests/compiler_env_test.py @@ -9,7 +9,7 @@ from compiler_gym.envs import llvm from compiler_gym.envs.llvm import LlvmEnv -from compiler_gym.service.connection import CompilerGymServiceConnection +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts from tests.test_main import main pytest_plugins = ["tests.pytest_plugins.llvm"] @@ -174,7 +174,7 @@ def test_step_session_id_not_found(env: LlvmEnv): @pytest.fixture(scope="function") def remote_env() -> LlvmEnv: """A test fixture that yields a connection to a remote service.""" - service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY) + service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY, ConnectionOpts()) try: with LlvmEnv(service=service.connection.url) as env: yield env diff --git a/tests/mlir/mlir_env_test.py b/tests/mlir/mlir_env_test.py index 35294d7ad..c03958cfb 100644 --- a/tests/mlir/mlir_env_test.py +++ b/tests/mlir/mlir_env_test.py @@ -12,7 +12,7 @@ import compiler_gym from compiler_gym.envs import CompilerEnv, mlir from compiler_gym.envs.mlir import MlirEnv -from compiler_gym.service.connection import CompilerGymServiceConnection +from compiler_gym.service.connection import CompilerGymServiceConnection, ConnectionOpts from compiler_gym.spaces import ( Box, Dict, @@ -36,7 +36,9 @@ def env(request) -> CompilerEnv: with gym.make("mlir-v0") as env: yield env else: - service = CompilerGymServiceConnection(mlir.MLIR_SERVICE_BINARY) + service = CompilerGymServiceConnection( + mlir.MLIR_SERVICE_BINARY, ConnectionOpts() + ) try: with MlirEnv(service=service.connection.url) as env: yield env From 7228951fc86a794ab806671e1d0fc34cc5f99418 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 37/41] Bump timeout on GCC smoke test to 10 min. --- examples/gcc_autotuning/tune_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/gcc_autotuning/tune_test.py b/examples/gcc_autotuning/tune_test.py index 30d592c4b..cadab187b 100644 --- a/examples/gcc_autotuning/tune_test.py +++ b/examples/gcc_autotuning/tune_test.py @@ -60,6 +60,7 @@ def gcc_bin(request) -> str: return request.param +@pytest.mark.timeout(600) @pytest.mark.parametrize("search", ["random", "hillclimb", "genetic"]) def test_tune_smoke_test(search: str, gcc_bin: str, capsys, tmpdir: Path): tmpdir = Path(tmpdir) From fa2b50ef43b21be5c2248ba9ac8132c063f06ee9 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 38/41] [ci] Remove redundant comment. --- .github/workflows/ci.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ac836fe9e..8bd8db03c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -59,7 +59,6 @@ jobs: else wget https://github.com/llvm/llvm-project/releases/download/llvmorg-10.0.0/clang+llvm-10.0.0-x86_64-linux-gnu-ubuntu-18.04.tar.xz -O ~/llvm.tar.xz fi - # TODO(cummins): Remove 'v' debugging flag: mkdir ~/llvm && tar xvf ~/llvm.tar.xz --strip-components 1 -C ~/llvm rm ~/llvm.tar.xz echo "Unpacked, testing for expected file:" From b6bb7fbb13192def37a1f3d291f4c420a5cd4afd Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 39/41] Tidy up build dependencies. --- compiler_gym/service/BUILD | 9 +++------ compiler_gym/service/CMakeLists.txt | 4 ++-- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/compiler_gym/service/BUILD b/compiler_gym/service/BUILD index 541200807..314e33e05 100644 --- a/compiler_gym/service/BUILD +++ b/compiler_gym/service/BUILD @@ -12,15 +12,12 @@ py_library( ], visibility = ["//visibility:public"], deps = [ - ":compilation_session", - ":connection", # TODO(github.com/facebookresearch/CompilerGym/pull/633): # add this after circular dependencies are resolved: # ":client_service_compiler_env", + ":compilation_session", + ":connection", ":connection_pool", - # TODO(github.com/facebookresearch/CompilerGym/pull/633): - # add this after circular dependencies are resolved - # ":client_service_compiler_env", ":service_cache", "//compiler_gym/errors", "//compiler_gym/service/proto", @@ -86,7 +83,7 @@ py_library( visibility = ["//visibility:public"], deps = [ ":connection", - ] + ], ) py_library( diff --git a/compiler_gym/service/CMakeLists.txt b/compiler_gym/service/CMakeLists.txt index f35b962c4..9dd35668d 100644 --- a/compiler_gym/service/CMakeLists.txt +++ b/compiler_gym/service/CMakeLists.txt @@ -11,11 +11,11 @@ cg_py_library( SRCS "__init__.py" DEPS - ::compilation_session - ::connection # TODO(github.com/facebookresearch/CompilerGym/pull/633): # add this after circular dependencies are resolved: # ::client_service_compiler_env + ::compilation_session + ::connection ::connection_pool ::service_cache compiler_gym::errors::errors From 884fa10abbf3882f998a7d6dbeee5b1ac91e709f Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 40/41] Remove unused import. --- compiler_gym/service/connection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/compiler_gym/service/connection.py b/compiler_gym/service/connection.py index 2a74f1d98..38b760b57 100644 --- a/compiler_gym/service/connection.py +++ b/compiler_gym/service/connection.py @@ -14,7 +14,6 @@ import grpc from deprecated.sphinx import deprecated -from pydantic import BaseModel from frozendict import frozendict from pydantic import BaseModel, root_validator From 8847e93e58676a863232748a84afa9843a871579 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Tue, 19 Apr 2022 15:50:51 -0700 Subject: [PATCH 41/41] Merge connection_pool target to fix CMake build. This removes the connection_pool target by merging it into the main package. This is needed to fix the CMake build. --- compiler_gym/service/BUILD | 11 +---------- compiler_gym/service/CMakeLists.txt | 11 +---------- tests/service/BUILD | 2 +- 3 files changed, 3 insertions(+), 21 deletions(-) diff --git a/compiler_gym/service/BUILD b/compiler_gym/service/BUILD index 314e33e05..1ce2c4cd7 100644 --- a/compiler_gym/service/BUILD +++ b/compiler_gym/service/BUILD @@ -9,6 +9,7 @@ py_library( name = "service", srcs = [ "__init__.py", + "connection_pool.py", ], visibility = ["//visibility:public"], deps = [ @@ -17,7 +18,6 @@ py_library( # ":client_service_compiler_env", ":compilation_session", ":connection", - ":connection_pool", ":service_cache", "//compiler_gym/errors", "//compiler_gym/service/proto", @@ -77,15 +77,6 @@ py_library( ], ) -py_library( - name = "connection_pool", - srcs = ["connection_pool.py"], - visibility = ["//visibility:public"], - deps = [ - ":connection", - ], -) - py_library( name = "service_cache", srcs = ["service_cache.py"], diff --git a/compiler_gym/service/CMakeLists.txt b/compiler_gym/service/CMakeLists.txt index 9dd35668d..eb6cf5d26 100644 --- a/compiler_gym/service/CMakeLists.txt +++ b/compiler_gym/service/CMakeLists.txt @@ -10,13 +10,13 @@ cg_py_library( service SRCS "__init__.py" + "connection_pool.py" DEPS # TODO(github.com/facebookresearch/CompilerGym/pull/633): # add this after circular dependencies are resolved: # ::client_service_compiler_env ::compilation_session ::connection - ::connection_pool ::service_cache compiler_gym::errors::errors compiler_gym::service::proto::proto @@ -80,15 +80,6 @@ cg_py_library( PUBLIC ) -cg_py_library( - connection_pool - SRCS - "connection_pool.py" - DEPS - ::connection - PUBLIC -) - cg_py_library( NAME service_cache diff --git a/tests/service/BUILD b/tests/service/BUILD index d8fb8117a..86e6e52ae 100644 --- a/tests/service/BUILD +++ b/tests/service/BUILD @@ -33,7 +33,7 @@ py_test( timeout = "short", srcs = ["service_cache_test.py"], deps = [ - "//compiler_gym/service:service_cache", + "//compiler_gym/service", "//tests:test_main", ], )