Skip to content

Commit

Permalink
[service] Add a base class for the connection pool.
Browse files Browse the repository at this point in the history
Add a base class for the ServiceConnectionPool that provides the same
interface but does no caching.
  • Loading branch information
ChrisCummins committed Mar 7, 2022
1 parent a5a551c commit 4cd83f7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
11 changes: 7 additions & 4 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@
ServiceTransportError,
SessionNotFound,
)
from compiler_gym.service.connection_pool import ServiceConnectionPool
from compiler_gym.service.connection_pool import (
ServiceConnectionPool,
ServiceConnectionPoolBase,
)
from compiler_gym.service.proto import AddBenchmarkRequest
from compiler_gym.service.proto import Benchmark as BenchmarkProto
from compiler_gym.service.proto import (
Expand Down Expand Up @@ -165,7 +168,7 @@ def __init__(
derived_observation_spaces: Optional[List[Dict[str, Any]]] = None,
connection_settings: Optional[ConnectionOpts] = None,
service_connection: Optional[CompilerGymServiceConnection] = None,
service_pool: Optional[ServiceConnectionPool] = None,
service_pool: Optional[ServiceConnectionPoolBase] = None,
logger: Optional[logging.Logger] = None,
):
"""Construct and initialize a CompilerGym environment.
Expand Down Expand Up @@ -250,15 +253,15 @@ def __init__(
self._connection_settings = connection_settings or ConnectionOpts()

if service_connection is None:
self._service_pool: Optional[ServiceConnectionPool] = (
self._service_pool: Optional[ServiceConnectionPoolBase] = (
ServiceConnectionPool.get() if service_pool is None else service_pool
)
self.service = self._service_pool.acquire(
endpoint=self._service_endpoint,
opts=self._connection_settings,
)
else:
self._service_pool: Optional[ServiceConnectionPool] = service_pool
self._service_pool: Optional[ServiceConnectionPoolBase] = service_pool
self.service = service_connection

self.datasets = Datasets(datasets or [])
Expand Down
18 changes: 16 additions & 2 deletions compiler_gym/service/connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,21 @@
ServiceConnectionCacheKey = Tuple[Path, ConnectionOpts]


class ServiceConnectionPool:
class ServiceConnectionPoolBase:
"""A class that provides the base interface for service connection pools."""

def acquire(
self, endpoint: Path, opts: ConnectionOpts
) -> CompilerGymServiceConnection:
return CompilerGymServiceConnection(
endpoint=endpoint, opts=opts, owning_service_pool=self
)

def release(self, service: CompilerGymServiceConnection) -> None:
pass


class ServiceConnectionPool(ServiceConnectionPoolBase):
"""An object pool for compiler service connections.
This class implements a thread-safe pool for compiler service connections.
Expand Down Expand Up @@ -52,7 +66,7 @@ class ServiceConnectionPool:
:vartype allocated: Set[CompilerGymServiceConnection]
"""

def __init__(self) -> None:
def __init__(self):
""""""
self._lock = Lock()
self.pool: Dict[
Expand Down

0 comments on commit 4cd83f7

Please sign in to comment.