Skip to content

Commit

Permalink
Introduce abstraction for broadway socket communication
Browse files Browse the repository at this point in the history
  • Loading branch information
SvenMarcus committed Feb 10, 2023
1 parent 4a8b587 commit 66c4b1a
Show file tree
Hide file tree
Showing 16 changed files with 273 additions and 195 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
*~
authorized_keys
__pycache__/
.python-version
4 changes: 4 additions & 0 deletions ocrd_monitor/ocrdbrowser/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from . import _workspace as workspace
from ._browser import (
Channel,
ChannelClosed,
OcrdBrowser,
OcrdBrowserFactory,
filter_owned,
Expand All @@ -14,6 +16,8 @@
from ._subprocess import SubProcessOcrdBrowserFactory

__all__ = [
"Channel",
"ChannelClosed",
"DockerOcrdBrowserFactory",
"NoPortsAvailableError",
"OcrdBrowser",
Expand Down
17 changes: 16 additions & 1 deletion ocrd_monitor/ocrdbrowser/_browser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,19 @@
from __future__ import annotations

from os import path
from typing import Protocol
from typing import AsyncContextManager, Protocol


class ChannelClosed(RuntimeError):
pass


class Channel(Protocol):
async def receive_bytes(self) -> bytes:
...

async def send_bytes(self, data: bytes) -> None:
...


class OcrdBrowser(Protocol):
Expand All @@ -20,6 +32,9 @@ def start(self) -> None:
def stop(self) -> None:
...

def open_channel(self) -> AsyncContextManager[Channel]:
...


class OcrdBrowserFactory(Protocol):
def __call__(self, owner: str, workspace_path: str) -> OcrdBrowser:
Expand Down
8 changes: 6 additions & 2 deletions ocrd_monitor/ocrdbrowser/_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

import os.path as path
import subprocess as sp
from typing import Any
from typing import Any, AsyncContextManager

from ._browser import OcrdBrowser
from ._browser import Channel, OcrdBrowser
from ._port import Port
from ._websocketchannel import WebSocketChannel

_docker_run = "docker run --rm -d --name {} -v {}:/data -p {}:8085 ocrd-browser:latest"
_docker_stop = "docker stop {}"
Expand Down Expand Up @@ -48,6 +49,9 @@ def stop(self) -> None:
self._port.release()
self.id = None

def open_channel(self) -> AsyncContextManager[Channel]:
return WebSocketChannel(self.address() + "/socket")

def _container_name(self) -> str:
workspace = path.basename(self.workspace())
return f"ocrd-browser-{self.owner()}-{workspace}"
Expand Down
8 changes: 6 additions & 2 deletions ocrd_monitor/ocrdbrowser/_subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import os
import subprocess as sp
from shutil import which
from typing import Optional
from typing import AsyncContextManager, Optional

from ._browser import OcrdBrowser
from ._browser import Channel, OcrdBrowser
from ._port import Port
from ._websocketchannel import WebSocketChannel

BROADWAY_BASE_PORT = 8080

Expand Down Expand Up @@ -65,6 +66,9 @@ def stop(self) -> None:
self._process.terminate()
self._localport.release()

def open_channel(self) -> AsyncContextManager[Channel]:
return WebSocketChannel(self.address() + "/socket")


class SubProcessOcrdBrowserFactory:
def __init__(self, available_ports: set[int]) -> None:
Expand Down
60 changes: 60 additions & 0 deletions ocrd_monitor/ocrdbrowser/_websocketchannel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

from types import TracebackType
from typing import Type, cast

from websockets import client
from websockets.exceptions import ConnectionClosed
from websockets.legacy.client import WebSocketClientProtocol
from websockets.typing import Subprotocol

from ._browser import ChannelClosed


class WebSocketChannel:
def __init__(self, url: str) -> None:
url = url.replace("http://", "ws://").replace("https://", "wss://")
self._connection = client.connect(
url,
subprotocols=[Subprotocol("broadway")],
open_timeout=None,
ping_timeout=None,
close_timeout=None,
max_size=2**32,
)

self._open_connection: WebSocketClientProtocol | None = None

async def __aenter__(self) -> "WebSocketChannel":
self._open_connection = await self._connection
return self

async def __aexit__(
self,
exc_type: Type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
if not self._open_connection:
return

await self._open_connection.close()
self._open_connection = None

async def receive_bytes(self) -> bytes:
try:
if not self._open_connection:
return bytes()

return cast(bytes, await self._open_connection.recv())
except ConnectionClosed:
raise ChannelClosed()

async def send_bytes(self, data: bytes) -> None:
try:
if not self._open_connection:
return

await self._open_connection.send(data)
except ConnectionClosed:
raise ChannelClosed()
63 changes: 3 additions & 60 deletions ocrd_monitor/ocrdmonitor/server/proxy.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,15 @@
from __future__ import annotations

import asyncio
from types import TracebackType
from typing import Protocol, Sequence, Type, cast

from fastapi import Response
from ocrdbrowser import Channel
from requests import request
from websockets import client
from websockets.legacy.client import WebSocketClientProtocol
from websockets.typing import Subprotocol

from .redirect import WorkspaceRedirect
from .redirect import BrowserRedirect


class WebSocketAdapter:
def __init__(
self, url: str, protocols: Sequence[Subprotocol] | None = None
) -> None:
url = url.replace("http://", "ws://").replace("https://", "wss://")
self._connection = client.connect(
url,
subprotocols=protocols,
open_timeout=None,
ping_timeout=None,
close_timeout=None,
max_size=2**32,
)

self._open_connection: WebSocketClientProtocol | None = None

async def __aenter__(self) -> "WebSocketAdapter":
self._open_connection = await self._connection
return self

async def __aexit__(
self,
exc_type: Type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
if not self._open_connection:
return

await self._open_connection.close()
self._open_connection = None

async def receive_bytes(self) -> bytes:
if not self._open_connection:
return bytes()

return cast(bytes, await self._open_connection.recv())

async def send_bytes(self, data: bytes) -> None:
if not self._open_connection:
return

await self._open_connection.send(data)


def forward(redirect: WorkspaceRedirect, url: str) -> Response:
def forward(redirect: BrowserRedirect, url: str) -> Response:
redirect_url = redirect.redirect_url(url)
response = request("GET", redirect_url, allow_redirects=False)
return Response(
Expand All @@ -68,14 +19,6 @@ def forward(redirect: WorkspaceRedirect, url: str) -> Response:
)


class Channel(Protocol):
async def receive_bytes(self) -> bytes:
...

async def send_bytes(self, data: bytes) -> None:
...


async def tunnel(
source: Channel,
target: Channel,
Expand Down
35 changes: 17 additions & 18 deletions ocrd_monitor/ocrdmonitor/server/redirect.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

from pathlib import Path
from typing import Callable, Protocol
from typing import Callable

from ocrdbrowser import OcrdBrowser


def removeprefix(string: str, prefix: str) -> str:
Expand Down Expand Up @@ -33,15 +35,14 @@ def __removesuffix(suffix: str) -> str:
return _removesuffix(suffix)


class WorkspaceServer(Protocol):
def address(self) -> str:
...


class WorkspaceRedirect:
def __init__(self, workspace: Path, server: WorkspaceServer) -> None:
class BrowserRedirect:
def __init__(self, workspace: Path, browser: OcrdBrowser) -> None:
self._workspace = workspace
self._server = server
self._browser = browser

@property
def browser(self) -> OcrdBrowser:
return self._browser

@property
def workspace(self) -> Path:
Expand All @@ -50,7 +51,7 @@ def workspace(self) -> Path:
def redirect_url(self, url: str) -> str:
url = removeprefix(url, str(self._workspace))
url = removeprefix(url, "/")
address = removesuffix(self._server.address(), "/")
address = removesuffix(self._browser.address(), "/")
return removesuffix(address + "/" + url, "/")

def matches(self, path: str) -> bool:
Expand All @@ -59,24 +60,24 @@ def matches(self, path: str) -> bool:

class RedirectMap:
def __init__(self) -> None:
self._redirects: dict[str, set[WorkspaceRedirect]] = {}
self._redirects: dict[str, set[BrowserRedirect]] = {}

def add(
self, session_id: str, workspace: Path, server: WorkspaceServer
) -> WorkspaceRedirect:
self, session_id: str, workspace: Path, server: OcrdBrowser
) -> BrowserRedirect:
try:
redirect = self.get(session_id, workspace)
return redirect
except KeyError:
redirect = WorkspaceRedirect(workspace, server)
redirect = BrowserRedirect(workspace, server)
self._redirects.setdefault(session_id, set()).add(redirect)
return redirect

def remove(self, session_id: str, workspace: Path) -> None:
redirect = self.get(session_id, workspace)
self._redirects[session_id].remove(redirect)

def get(self, session_id: str, workspace: Path) -> WorkspaceRedirect:
def get(self, session_id: str, workspace: Path) -> BrowserRedirect:
redirect = next(
(
redirect
Expand All @@ -88,9 +89,7 @@ def get(self, session_id: str, workspace: Path) -> WorkspaceRedirect:

return self._instance_or_raise(redirect)

def _instance_or_raise(
self, redirect: WorkspaceRedirect | None
) -> WorkspaceRedirect:
def _instance_or_raise(self, redirect: BrowserRedirect | None) -> BrowserRedirect:
if redirect is None:
raise KeyError("No redirect found")

Expand Down
21 changes: 8 additions & 13 deletions ocrd_monitor/ocrdmonitor/server/workspaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
from pathlib import Path

import ocrdbrowser
import ocrdmonitor.server.proxy as proxy
from fastapi import APIRouter, Cookie, Request, Response, WebSocket
from fastapi.templating import Jinja2Templates
from ocrdbrowser import OcrdBrowser, OcrdBrowserFactory, workspace
from ocrdbrowser import ChannelClosed, OcrdBrowser, OcrdBrowserFactory, workspace
from ocrdmonitor.server.redirect import RedirectMap
from requests.exceptions import ConnectionError
from websockets.typing import Subprotocol
from websockets.exceptions import ConnectionClosedError

import ocrdmonitor.server.proxy as proxy
from ocrdmonitor.server.redirect import RedirectMap
from websockets.typing import Subprotocol


def create_workspaces(
Expand Down Expand Up @@ -74,17 +73,13 @@ async def workspace_socket_proxy(
websocket: WebSocket, workspace: Path, session_id: str = Cookie(default=None)
) -> None:
redirect = redirects.get(session_id, workspace)
url = redirect.redirect_url(str(workspace / "socket"))
await websocket.accept(subprotocol="broadway")

async with proxy.WebSocketAdapter(
url, [Subprotocol("broadway")]
) as broadway_socket:
async with redirect.browser.open_channel() as channel:
try:
while True:
await proxy.tunnel(broadway_socket, websocket)
except ConnectionClosedError:
_stop_browsers_in_workspace(workspace, session_id)
await proxy.tunnel(channel, websocket)
except ChannelClosed:
redirect.browser.stop()

def _launch_browser(session_id: str, workspace: Path) -> OcrdBrowser:
browser = ocrdbrowser.launch(
Expand Down
1 change: 1 addition & 0 deletions ocrd_monitor/requirements.dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ httpx>=0.23.1
mypy>=0.982
nox>=2022.11.21
pytest>=7.2.0
pytest-asyncio>=0.20.3
testcontainers>=3.7.0
2 changes: 1 addition & 1 deletion ocrd_monitor/tests/fakes/_backgroundprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class AnyFunc(Protocol):
def __call__(self, *args, **kwargs) -> Any:
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...


Expand Down
Loading

0 comments on commit 66c4b1a

Please sign in to comment.