Skip to content

Commit

Permalink
Move database connection to singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
marrobi committed Dec 22, 2023
1 parent f87f2bc commit 69c8568
Show file tree
Hide file tree
Showing 27 changed files with 278 additions and 278 deletions.
4 changes: 2 additions & 2 deletions api_app/api/dependencies/airlock.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import Depends, HTTPException, Path, status
from pydantic import UUID4

from api.dependencies.database import get_repository
from api.dependencies.database import Database
from db.repositories.airlock_requests import AirlockRequestRepository
from models.domain.airlock_request import AirlockRequest
from db.errors import EntityDoesNotExist, UnableToAccessDatabase
Expand All @@ -17,5 +17,5 @@ async def get_airlock_request_by_id(airlock_request_id: UUID4, airlock_request_r
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING)


async def get_airlock_request_by_id_from_path(airlock_request_id: UUID4 = Path(...), airlock_request_repo=Depends(get_repository(AirlockRequestRepository))) -> AirlockRequest:
async def get_airlock_request_by_id_from_path(airlock_request_id: UUID4 = Path(...), airlock_request_repo=Depends(Database().get_repository(AirlockRequestRepository))) -> AirlockRequest:
return await get_airlock_request_by_id(airlock_request_id, airlock_request_repo)
165 changes: 82 additions & 83 deletions api_app/api/dependencies/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,91 +11,90 @@
from services.logging import logger


async def connect_to_db() -> CosmosClient:
logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}")

async with get_credential_async() as credential:
if MANAGED_IDENTITY_CLIENT_ID:
logger.debug("Connecting with managed identity")
cosmos_client = CosmosClient(
url=STATE_STORE_ENDPOINT,
credential=credential
)
else:
logger.debug("Connecting with key")
primary_master_key = await get_store_key(credential)
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]


class Database(metaclass=Singleton):
cosmos_client = None

def __init__(self):
pass

@classmethod
async def _connect_to_db(self) -> CosmosClient:
logger.debug(f"Connecting to {STATE_STORE_ENDPOINT}")

if STATE_STORE_SSL_VERIFY:
logger.debug("Connecting with SSL verification")
async with get_credential_async() as credential:
if MANAGED_IDENTITY_CLIENT_ID:
logger.debug("Connecting with managed identity")
cosmos_client = CosmosClient(
url=STATE_STORE_ENDPOINT, credential=primary_master_key
url=STATE_STORE_ENDPOINT,
credential=credential
)
else:
logger.debug("Connecting without SSL verification")
# ignore TLS (setup is a pain) when using local Cosmos emulator.
cosmos_client = CosmosClient(
STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False
logger.debug("Connecting with key")
primary_master_key = await self._get_store_key(credential)

if STATE_STORE_SSL_VERIFY:
logger.debug("Connecting with SSL verification")
cosmos_client = CosmosClient(
url=STATE_STORE_ENDPOINT, credential=primary_master_key
)
else:
logger.debug("Connecting without SSL verification")
# ignore TLS (setup is a pain) when using local Cosmos emulator.
cosmos_client = CosmosClient(
STATE_STORE_ENDPOINT, primary_master_key, connection_verify=False
)
logger.debug("Connection established")
return cosmos_client

@classmethod
async def _get_store_key(self, credential) -> str:
logger.debug("Getting store key")
if STATE_STORE_KEY:
primary_master_key = STATE_STORE_KEY
else:
async with CosmosDBManagementClient(
credential,
subscription_id=SUBSCRIPTION_ID,
base_url=RESOURCE_MANAGER_ENDPOINT,
credential_scopes=CREDENTIAL_SCOPES
) as cosmosdb_mng_client:
database_keys = await cosmosdb_mng_client.database_accounts.list_keys(
resource_group_name=RESOURCE_GROUP_NAME,
account_name=COSMOSDB_ACCOUNT_NAME,
)
logger.debug("Connection established")
return cosmos_client


async def get_store_key(credential) -> str:
logger.debug("Getting store key")
if STATE_STORE_KEY:
primary_master_key = STATE_STORE_KEY
else:
async with CosmosDBManagementClient(
credential,
subscription_id=SUBSCRIPTION_ID,
base_url=RESOURCE_MANAGER_ENDPOINT,
credential_scopes=CREDENTIAL_SCOPES
) as cosmosdb_mng_client:
database_keys = await cosmosdb_mng_client.database_accounts.list_keys(
resource_group_name=RESOURCE_GROUP_NAME,
account_name=COSMOSDB_ACCOUNT_NAME,
)
primary_master_key = database_keys.primary_master_key

return primary_master_key


async def get_db_client(app: FastAPI) -> CosmosClient:
logger.debug("Getting cosmos client")
cosmos_client = None
if hasattr(app.state, 'cosmos_client') and app.state.cosmos_client:
logger.debug("Cosmos client found in state")
cosmos_client = app.state.cosmos_client
# TODO: if session is closed recreate - need to investigate why this is happening
# https://github.com/Azure/azure-sdk-for-python/issues/32309
if hasattr(cosmos_client.client_connection, "session") and not cosmos_client.client_connection.session:
logger.debug("Cosmos client session is None")
cosmos_client = await connect_to_db()
else:
logger.debug("No cosmos client found, creating one")
cosmos_client = await connect_to_db()

app.state.cosmos_client = cosmos_client
return app.state.cosmos_client


async def get_db_client_from_request(request: Request) -> CosmosClient:
return await get_db_client(request.app)


def get_repository(
repo_type: Type[BaseRepository],
) -> Callable[[CosmosClient], BaseRepository]:
async def _get_repo(
client: CosmosClient = Depends(get_db_client_from_request),
) -> BaseRepository:
try:
return await repo_type.create(client)
except UnableToAccessDatabase:
logger.exception(strings.STATE_STORE_ENDPOINT_NOT_RESPONDING)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING,
)

return _get_repo
primary_master_key = database_keys.primary_master_key

return primary_master_key

@classmethod
async def get_db_client(self) -> CosmosClient:
logger.debug("Getting cosmos client")
if not Database.cosmos_client:
Database.cosmos_client = await self._connect_to_db()
return self.cosmos_client


@classmethod
def get_repository(self,
repo_type: Type[BaseRepository],
) -> Callable[[CosmosClient], BaseRepository]:

async def _get_repo() -> BaseRepository:
try:
return await repo_type.create(self.cosmos_client)
except UnableToAccessDatabase:
logger.exception(strings.STATE_STORE_ENDPOINT_NOT_RESPONDING)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=strings.STATE_STORE_ENDPOINT_NOT_RESPONDING,
)

return _get_repo
6 changes: 3 additions & 3 deletions api_app/api/dependencies/shared_services.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import Depends, HTTPException, Path, status
from pydantic import UUID4

from api.dependencies.database import get_repository
from api.dependencies.database import Database
from db.errors import EntityDoesNotExist
from resources import strings
from models.domain.shared_service import SharedService
Expand All @@ -17,11 +17,11 @@ async def get_shared_service_by_id(shared_service_id: UUID4, shared_services_rep
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.SHARED_SERVICE_DOES_NOT_EXIST)


async def get_shared_service_by_id_from_path(shared_service_id: UUID4 = Path(...), shared_service_repo=Depends(get_repository(SharedServiceRepository))) -> SharedService:
async def get_shared_service_by_id_from_path(shared_service_id: UUID4 = Path(...), shared_service_repo=Depends(Database().get_repository(SharedServiceRepository))) -> SharedService:
return await get_shared_service_by_id(shared_service_id, shared_service_repo)


async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(get_repository(OperationRepository))) -> Operation:
async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(Database().get_repository(OperationRepository))) -> Operation:
try:
return await operations_repo.get_operation_by_id(operation_id=operation_id)
except EntityDoesNotExist:
Expand Down
4 changes: 2 additions & 2 deletions api_app/api/dependencies/workspace_service_templates.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from fastapi import Depends, HTTPException, Path, status

from api.dependencies.database import get_repository
from api.dependencies.database import Database
from db.errors import EntityDoesNotExist
from db.repositories.resource_templates import ResourceTemplateRepository
from models.domain.resource import ResourceType
from models.domain.resource_template import ResourceTemplate
from resources import strings


async def get_workspace_service_template_by_name_from_path(service_template_name: str = Path(...), template_repo=Depends(get_repository(ResourceTemplateRepository))) -> ResourceTemplate:
async def get_workspace_service_template_by_name_from_path(service_template_name: str = Path(...), template_repo=Depends(Database().get_repository(ResourceTemplateRepository))) -> ResourceTemplate:
try:
return await template_repo.get_current_template(service_template_name, ResourceType.WorkspaceService)
except EntityDoesNotExist:
Expand Down
14 changes: 7 additions & 7 deletions api_app/api/dependencies/workspaces.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi import Depends, HTTPException, Path, status
from pydantic import UUID4

from api.dependencies.database import get_repository
from api.dependencies.database import Database
from db.errors import EntityDoesNotExist, ResourceIsNotDeployed
from db.repositories.operations import OperationRepository
from db.repositories.user_resources import UserResourceRepository
Expand All @@ -22,11 +22,11 @@ async def get_workspace_by_id(workspace_id: UUID4, workspaces_repo) -> Workspace
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.WORKSPACE_DOES_NOT_EXIST)


async def get_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(get_repository(WorkspaceRepository))) -> Workspace:
async def get_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(Database().get_repository(WorkspaceRepository))) -> Workspace:
return await get_workspace_by_id(workspace_id, workspaces_repo)


async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(get_repository(WorkspaceRepository)), operations_repo=Depends(get_repository(OperationRepository))) -> Workspace:
async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...), workspaces_repo=Depends(Database().get_repository(WorkspaceRepository)), operations_repo=Depends(Database().get_repository(OperationRepository))) -> Workspace:
try:
return await workspaces_repo.get_deployed_workspace_by_id(workspace_id, operations_repo)
except EntityDoesNotExist:
Expand All @@ -35,14 +35,14 @@ async def get_deployed_workspace_by_id_from_path(workspace_id: UUID4 = Path(...)
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=strings.WORKSPACE_IS_NOT_DEPLOYED)


async def get_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository))) -> WorkspaceService:
async def get_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository))) -> WorkspaceService:
try:
return await workspace_services_repo.get_workspace_service_by_id(workspace_id, service_id)
except EntityDoesNotExist:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.WORKSPACE_SERVICE_DOES_NOT_EXIST)


async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(get_repository(WorkspaceServiceRepository)), operations_repo=Depends(get_repository(OperationRepository))) -> WorkspaceService:
async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), workspace_services_repo=Depends(Database().get_repository(WorkspaceServiceRepository)), operations_repo=Depends(Database().get_repository(OperationRepository))) -> WorkspaceService:
try:
return await workspace_services_repo.get_deployed_workspace_service_by_id(workspace_id, service_id, operations_repo)
except EntityDoesNotExist:
Expand All @@ -51,14 +51,14 @@ async def get_deployed_workspace_service_by_id_from_path(workspace_id: UUID4 = P
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=strings.WORKSPACE_SERVICE_IS_NOT_DEPLOYED)


async def get_user_resource_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), resource_id: UUID4 = Path(...), user_resource_repo=Depends(get_repository(UserResourceRepository))) -> UserResource:
async def get_user_resource_by_id_from_path(workspace_id: UUID4 = Path(...), service_id: UUID4 = Path(...), resource_id: UUID4 = Path(...), user_resource_repo=Depends(Database().get_repository(UserResourceRepository))) -> UserResource:
try:
return await user_resource_repo.get_user_resource_by_id(workspace_id, service_id, resource_id)
except EntityDoesNotExist:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=strings.USER_RESOURCE_DOES_NOT_EXIST)


async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(get_repository(OperationRepository))) -> Operation:
async def get_operation_by_id_from_path(operation_id: UUID4 = Path(...), operations_repo=Depends(Database().get_repository(OperationRepository))) -> Operation:
try:
return await operations_repo.get_operation_by_id(operation_id=operation_id)
except EntityDoesNotExist:
Expand Down
Loading

0 comments on commit 69c8568

Please sign in to comment.