Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add an in-process CompilerGymService for python backends #727

Draft
wants to merge 9 commits into
base: development
Choose a base branch
from
12 changes: 0 additions & 12 deletions compiler_gym/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ py_library(
srcs = ["__init__.py"],
visibility = ["//visibility:public"],
deps = [
":random_replay",
":random_search",
":validate",
"//compiler_gym/bin",
Expand All @@ -38,17 +37,6 @@ py_library(
],
)

py_library(
name = "random_replay",
srcs = ["random_replay.py"],
visibility = ["//visibility:public"],
deps = [
":random_search",
"//compiler_gym/envs",
"//compiler_gym/util",
],
)

py_library(
name = "random_search",
srcs = ["random_search.py"],
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/bin/manual_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def do_set_default_reward(self, arg):

def do_commandline(self, arg):
"""Show the command line equivalent of the actions taken so far"""
print("$", self.env.commandline(), flush=True)
print("$", self.env.action_space.to_string(self.env.actions), flush=True)

def do_stack(self, arg):
"""Show the environments on the stack. The current environment is the first shown."""
Expand Down
20 changes: 0 additions & 20 deletions compiler_gym/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,6 @@

logger = logging.getLogger(__name__)

# NOTE(cummins): This is only required to prevent a name conflict with the now
# deprecated Dataset.logger attribute. This can be removed once the logger
# attribute is removed, scheduled for release 0.2.3.
_logger = logger

# Matches a dataset name of the form "<name>-v<N>" and captures the integer N
# in the "version" group. The name may contain ASCII letters, digits, "-" and
# "_". NOTE: the upper-case range must be "A-Z"; writing "A-z" would also admit
# the ASCII punctuation that sits between "Z" and "a" ("[", "\", "]", "^", "`").
_DATASET_VERSION_PATTERN = r"[a-zA-Z0-9-_]+-v(?P<version>[0-9]+)"
_DATASET_VERSION_RE = re.compile(_DATASET_VERSION_PATTERN)

Expand Down Expand Up @@ -131,21 +126,6 @@ def __init__(
def __repr__(self):
return self.name

@property
@mark_deprecated(
version="0.2.1",
reason=(
"The `Dataset.logger` attribute is deprecated. All Dataset "
"instances share a logger named compiler_gym.datasets"
),
)
def logger(self) -> logging.Logger:
"""The logger for this dataset.

Deprecated since version 0.2.1. This returns the module-level
``_logger`` object, not a logger owned by the instance.

:type: logging.Logger
"""
return _logger

@property
def name(self) -> str:
"""The name of the dataset.
Expand Down
38 changes: 0 additions & 38 deletions compiler_gym/datasets/uri.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,49 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module contains utility code for working with URIs."""
import re
from typing import Dict, List, Union
from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse

from deprecated.sphinx import deprecated
from pydantic import BaseModel

# === BEGIN DEPRECATED DECLARATIONS ===
#
# The following regular expression definitions have been deprecated and will be
# removed in a future release! Please update your code to use the new
# BenchmarkUri class defined in this file.

# Regular expression that matches the full two-part URI prefix of a dataset:
# {{scheme}}://{{dataset}}
#
# An optional trailing slash is permitted.
#
# Example matches: "benchmark://foo-v0", "generator://bar-v0/".
DATASET_NAME_PATTERN = r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))/?"
DATASET_NAME_RE = re.compile(DATASET_NAME_PATTERN)

# Regular expression that matches the full three-part format of a benchmark URI:
# {{scheme}}://{{dataset}}/{{id}}
#
# Example matches: "benchmark://foo-v0/foo" or "generator://bar-v1/foo/bar.txt".
BENCHMARK_URI_PATTERN = r"(?P<dataset>(?P<dataset_protocol>[a-zA-z0-9-_]+)://(?P<dataset_name>[a-zA-z0-9-_]+-v(?P<dataset_version>[0-9]+)))/(?P<benchmark_name>.+)$"
BENCHMARK_URI_RE = re.compile(BENCHMARK_URI_PATTERN)

# === END DEPRECATED DECLARATIONS ===


@deprecated(
    version="0.2.2",
    reason=("Use compiler_gym.datasets.BenchmarkUri.canonicalize()"),
)
def resolve_uri_protocol(uri: str) -> str:
    """Return ``uri`` with a scheme, defaulting to ``benchmark://``.

    :param uri: A URI string, with or without a ``://`` scheme separator.

    :return: ``uri`` unchanged if it already contains ``://``, otherwise
        ``uri`` prefixed with ``benchmark://``.
    """
    has_scheme = "://" in uri
    return uri if has_scheme else f"benchmark://{uri}"


class BenchmarkUri(BaseModel):
"""A URI used to identify a benchmark, and optionally a set of parameters
Expand Down
3 changes: 1 addition & 2 deletions compiler_gym/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import compiler_gym.envs.loop_tool # noqa
from compiler_gym import config
from compiler_gym.envs.compiler_env import CompilerEnv
from compiler_gym.envs.gcc import GccEnv
Expand All @@ -11,14 +12,12 @@
if config.enable_mlir_env:
from compiler_gym.envs.mlir.mlir_env import MlirEnv # noqa: F401

from compiler_gym.envs.loop_tool.loop_tool_env import LoopToolEnv
from compiler_gym.util.registration import COMPILER_GYM_ENVS

__all__ = [
"COMPILER_GYM_ENVS",
"CompilerEnv",
"GccEnv",
"LoopToolEnv",
]

if config.enable_llvm_env:
Expand Down
22 changes: 9 additions & 13 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from compiler_gym.compiler_env_state import CompilerEnvState
from compiler_gym.datasets import Benchmark, BenchmarkUri, Dataset
from compiler_gym.spaces import Reward
from compiler_gym.spaces import ActionSpace, Reward
from compiler_gym.util.gym_type_hints import (
ActionType,
ObservationType,
Expand Down Expand Up @@ -182,14 +182,6 @@ def episode_reward(self, episode_reward: Optional[float]):
def actions(self) -> List[ActionType]:
raise NotImplementedError("abstract method")

@property
@abstractmethod
@deprecated(
version="0.2.1",
)
def logger(self):
raise NotImplementedError("abstract method")

@property
@abstractmethod
def version(self) -> str:
Expand All @@ -210,7 +202,7 @@ def state(self) -> CompilerEnvState:

@property
@abstractmethod
def action_space(self) -> Space:
def action_space(self) -> ActionSpace:
"""The current action space.

:getter: Get the current action space.
Expand All @@ -227,7 +219,7 @@ def action_space(self, action_space: Optional[str]):

@property
@abstractmethod
def action_spaces(self) -> List[str]:
def action_spaces(self) -> List[ActionSpace]:
"""A list of supported action space names."""
raise NotImplementedError("abstract method")

Expand Down Expand Up @@ -481,7 +473,9 @@ def render(
"""
raise NotImplementedError("abstract method")

@abstractmethod
@deprecated(
version="0.2.5", reason="Use env.action_space.to_string(env.actions) instead"
)
def commandline(self) -> str:
"""Interface for :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>`
subclasses to provide an equivalent commandline invocation to the
Expand All @@ -494,7 +488,9 @@ def commandline(self) -> str:
"""
raise NotImplementedError("abstract method")

@abstractmethod
@deprecated(
version="0.2.5", reason='Use env.action_space.from_string("...") instead'
)
def commandline_to_actions(self, commandline: str) -> List[ActionType]:
"""Interface for :class:`CompilerEnv <compiler_gym.envs.CompilerEnv>`
subclasses to convert from a commandline invocation to a sequence of
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/envs/gcc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ py_library(
deps = [
"//compiler_gym/envs/gcc/datasets",
"//compiler_gym/errors",
"//compiler_gym/service:client_service_compiler_env",
"//compiler_gym/service:in_process_client_compiler_env",
"//compiler_gym/service/runtime", # Implicit dependency of service.
"//compiler_gym/util",
],
Expand Down
2 changes: 1 addition & 1 deletion compiler_gym/envs/gcc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ cg_py_library(
DATA
compiler_gym::envs::gcc::service::service
DEPS
compiler_gym::service::client_service_compiler_env
compiler_gym::service::in_process_client_compiler_env
compiler_gym::envs::gcc::datasets::datasets
compiler_gym::errors::errors
compiler_gym::service::runtime::runtime
Expand Down
14 changes: 12 additions & 2 deletions compiler_gym/envs/gcc/gcc_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from deprecated.sphinx import deprecated

from compiler_gym.datasets import Benchmark
from compiler_gym.envs.gcc.datasets import get_gcc_datasets
from compiler_gym.envs.gcc.gcc import Gcc, GccSpec
from compiler_gym.envs.gcc.gcc_rewards import AsmSizeReward, ObjSizeReward
from compiler_gym.envs.gcc.service.gcc_service import make_gcc_compilation_session
from compiler_gym.service import ConnectionOpts
from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv
from compiler_gym.service.in_process_client_compiler_env import (
InProcessClientCompilerEnv,
)
from compiler_gym.spaces import Reward
from compiler_gym.util.decorators import memoized_property
from compiler_gym.util.gym_type_hints import ObservationType, OptionalArgumentValue
Expand All @@ -24,7 +29,7 @@
DEFAULT_GCC: str = "docker:gcc:11.2.0"


class GccEnv(ClientServiceCompilerEnv):
class GccEnv(InProcessClientCompilerEnv):
"""A specialized ClientServiceCompilerEnv for GCC.

This class exposes the optimization space of GCC's command line flags
Expand Down Expand Up @@ -79,6 +84,7 @@ def __init__(
super().__init__(
*args,
**kwargs,
make_session=make_gcc_compilation_session(),
benchmark=benchmark,
datasets=get_gcc_datasets(
gcc_bin=gcc_bin, site_data_base=datasets_site_path
Expand Down Expand Up @@ -109,6 +115,10 @@ def reset(
self.send_param("timeout", str(self._timeout))
return observation

@deprecated(
version="0.2.1",
reason="Use `env.observation.command_line()` instead",
)
def commandline(self) -> str:
"""Return a string representing the command line options.

Expand Down
11 changes: 11 additions & 0 deletions compiler_gym/envs/llvm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ py_library(
":benchmark_from_command_line",
":compute_observation",
":llvm_benchmark",
":llvm_command_line",
":llvm_env",
"//compiler_gym/util",
],
Expand Down Expand Up @@ -54,12 +55,22 @@ py_library(
],
)

py_library(
name = "llvm_command_line",
srcs = ["llvm_command_line.py"],
deps = [
"//compiler_gym/spaces",
"//compiler_gym/util",
],
)

py_library(
name = "llvm_env",
srcs = ["llvm_env.py"],
deps = [
":benchmark_from_command_line",
":llvm_benchmark",
":llvm_command_line",
":llvm_rewards",
"//compiler_gym/datasets",
"//compiler_gym/envs/llvm/datasets",
Expand Down
12 changes: 12 additions & 0 deletions compiler_gym/envs/llvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ cg_py_library(
::benchmark_from_command_line
::compute_observation
::llvm_benchmark
::llvm_command_line
::llvm_env
compiler_gym::util::util
PUBLIC
Expand Down Expand Up @@ -56,13 +57,24 @@ cg_py_library(
PUBLIC
)

cg_py_library(
NAME
llvm_command_line
SRCS
llvm_command_line.py
DEPS
compiler_gym::spaces::spaces
compiler_gym::util::util
)

cg_py_library(
NAME
llvm_env
SRCS
"llvm_env.py"
DEPS
::llvm_benchmark
::llvm_command_line
::llvm_rewards
compiler_gym::datasets::datasets
compiler_gym::errors::errors
Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/envs/llvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
get_system_library_flags,
make_benchmark,
)
from compiler_gym.envs.llvm.llvm_command_line import LlvmCommandLine
from compiler_gym.envs.llvm.llvm_env import LlvmEnv

# TODO(github.com/facebookresearch/CompilerGym/issues/506): Tidy up.
Expand All @@ -24,6 +25,7 @@

__all__ = [
"BenchmarkFromCommandLine",
"LlvmCommandLine",
"ClangInvocation",
"compute_observation",
"get_system_library_flags",
Expand Down
40 changes: 40 additions & 0 deletions compiler_gym/envs/llvm/llvm_command_line.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List

from compiler_gym.spaces import ActionSpace
from compiler_gym.util.gym_type_hints import ActionType


class LlvmCommandLine(ActionSpace):
    """An LLVM action space whose action sequences can be written to, and
    read back from, an :code:`opt` command line string.

    Serialization of the individual actions is delegated to
    :code:`self.wrapped`; this class only adds and strips the surrounding
    :code:`opt ... input.bc -o output.bc` text.
    """

    def to_string(self, actions: List[ActionType]) -> str:
        """Serialize actions to an LLVM :code:`opt` command line invocation.

        :param actions: The actions to serialize.

        :returns: A command line string with the wrapped space's
            serialization of :code:`actions` placed between :code:`opt` and
            :code:`input.bc -o output.bc`.
        """
        flags = self.wrapped.to_string(actions)
        return f"opt {flags} input.bc -o output.bc"

    def from_string(self, string: str) -> List[ActionType]:
        """Parse a command line string into a list of actions.

        A leading :code:`opt` and a trailing :code:`input.bc -o output.bc`
        are each removed if present; the remainder is handed to the wrapped
        space for parsing.

        :param string: A command line invocation.

        :return: A list of actions.

        :raises ValueError: In case the command line string is malformed.
            NOTE(review): no validation happens in this method, so this is
            presumably raised by the wrapped space's parser. Confirm.
        """
        prefix = "opt "
        suffix = " input.bc -o output.bc"

        # Both decorations are optional and are stripped independently.
        flags = string[len(prefix) :] if string.startswith(prefix) else string
        if flags.endswith(suffix):
            flags = flags[: len(flags) - len(suffix)]

        return self.wrapped.from_string(flags)
Loading