From d24f4a6ff15af10faaf5647ca3409b90d61148b0 Mon Sep 17 00:00:00 2001 From: Cody Baker <51133164+CodyCBakerPhD@users.noreply.github.com> Date: Mon, 11 Mar 2024 14:24:19 -0400 Subject: [PATCH] Functional websocket declaration (#42) * Update _server.py * more keyword arguments; break up docstring; remove unused capture * more keyword arguments --------- Co-authored-by: Garrett Michael Flynn --- src/tqdm_publisher/_demo/_server.py | 79 ++++++++++------------------- 1 file changed, 27 insertions(+), 52 deletions(-) diff --git a/src/tqdm_publisher/_demo/_server.py b/src/tqdm_publisher/_demo/_server.py index 1a697ab..ac821dc 100644 --- a/src/tqdm_publisher/_demo/_server.py +++ b/src/tqdm_publisher/_demo/_server.py @@ -2,22 +2,20 @@ import json import threading import time -from typing import Any, Dict, List -from uuid import uuid4 import websockets -from tqdm_publisher import TQDMPublisher +import tqdm_publisher -def start_progress_bar(*, client_id: str, progress_bar_id: str, client_callback: callable) -> None: +def start_progress_bar(*, progress_bar_id: str, client_callback: callable) -> None: """ Emulate running the specified number of tasks by sleeping the specified amount of time on each iteration. Defaults are chosen for a deterministic and regular update period of one second for a total time of one minute. """ all_task_durations_in_seconds = [1.0 for _ in range(60)] # One minute at one second per update - progress_bar = TQDMPublisher(iterable=all_task_durations_in_seconds) + progress_bar = tqdm_publisher.TQDMPublisher(iterable=all_task_durations_in_seconds) def run_function_on_progress_update(format_dict: dict) -> None: """ @@ -29,7 +27,7 @@ def run_function_on_progress_update(format_dict: dict) -> None: In this demo, we will execute the `client_callback` whose protocol is known only to the WebSocketHandler. """ - client_callback(client_id=client_id, progress_bar_id=progress_bar_id, format_dict=format_dict) + client_callback(progress_bar_id=progress_bar_id, format_dict=format_dict) progress_bar.subscribe(callback=run_function_on_progress_update) @@ -37,60 +35,37 @@ def run_function_on_progress_update(format_dict: dict) -> None: time.sleep(task_duration) -class WebSocketHandler: - """ - This is a class that handles the WebSocket connections and the communication protocol - between the server and the client. - - While we could have implemented this as a function, a class is chosen here to maintain reference - to the clients within a defined scope. - """ - - def __init__(self) -> None: - """Initialize the mapping of client IDs to .""" - self.clients: Dict[str, Any] = dict() - - def forward_progress_to_client(self, *, client_id: str, progress_bar_id: str, format_dict: dict) -> None: - """This is the function that will run on every update of the TQDM object. It will forward the progress to the client.""" - asyncio.run(self.send(client_id=client_id, data=dict(progress_bar_id=progress_bar_id, format_dict=format_dict))) +async def handler(websocket: websockets.WebSocketServerProtocol) -> None: + """Handle messages from the client and manage the client connections.""" - async def send(self, client_id: str, data: dict) -> None: - """Send an arbitrary JSON object `data` to the client identifier by `client_id`.""" - await self.clients[client_id].send(json.dumps(obj=data)) - - async def handler(self, websocket: websockets.WebSocketServerProtocol) -> None: - """Register new WebSocket clients and handle their messages.""" - client_id = str(uuid4()) - - # Register client connection - self.clients[client_id] = websocket + def forward_progress_to_client(*, progress_bar_id: str, format_dict: dict) -> None: + """ + This is the function that will run on every update of the TQDM object. - # Wait for messages from the client - try: - async for message in websocket: - message_from_client = json.loads(message) + It will forward the progress to the client. + """ + asyncio.run( + websocket.send(message=json.dumps(obj=dict(progress_bar_id=progress_bar_id, format_dict=format_dict))) + ) - if message_from_client["command"] == "start": - thread = threading.Thread( - target=start_progress_bar, - kwargs=dict( - client_id=client_id, - progress_bar_id=message_from_client["progress_bar_id"], - client_callback=self.forward_progress_to_client, - ), - ) - thread.start() + # Wait for messages from the client + async for message in websocket: + message_from_client = json.loads(message) - # Deregister the client when the connection is closed - finally: - if client_id in self.clients: - del self.clients[client_id] + if message_from_client["command"] == "start": + thread = threading.Thread( + target=start_progress_bar, + kwargs=dict( + progress_bar_id=message_from_client["progress_bar_id"], + client_callback=forward_progress_to_client, + ), + ) + thread.start() async def spawn_server() -> None: """Spawn the server asynchronously.""" - handler = WebSocketHandler().handler - async with websockets.serve(handler, "", 8000): + async with websockets.serve(ws_handler=handler, host="", port=8000): await asyncio.Future()