diff --git a/pyproject.toml b/pyproject.toml index ed45425..942e7fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ demo = [ "Bug Tracker" = "https://github.com/catalystneuro/tqdm_publisher/issues" [project.scripts] -tqdm_publisher = "tqdm_publisher.demo.demo_command_line_interface:main" +tqdm_publisher = "tqdm_publisher._demo._demo_command_line_interface:_command_line_interface" [tool.black] line-length = 120 diff --git a/src/tqdm_publisher/demo/__init__.py b/src/tqdm_publisher/_demo/__init__.py similarity index 100% rename from src/tqdm_publisher/demo/__init__.py rename to src/tqdm_publisher/_demo/__init__.py diff --git a/src/tqdm_publisher/demo/client.html b/src/tqdm_publisher/_demo/_client.html similarity index 94% rename from src/tqdm_publisher/demo/client.html rename to src/tqdm_publisher/_demo/_client.html index 3b716e2..0e01ef3 100644 --- a/src/tqdm_publisher/demo/client.html +++ b/src/tqdm_publisher/_demo/_client.html @@ -131,10 +131,10 @@

tqdm_progress

const { element, progress } = createProgressBar(); barElements.appendChild(element); - const id = Math.random().toString(36).substring(7); - bars[id] = progress; + const progress_bar_id = Math.random().toString(36).substring(7); + bars[progress_bar_id] = progress; - client.socket.send(JSON.stringify({ command: 'start', id })); + client.socket.send(JSON.stringify({ command: 'start', progress_bar_id })); }) diff --git a/src/tqdm_publisher/_demo/_demo_command_line_interface.py b/src/tqdm_publisher/_demo/_demo_command_line_interface.py new file mode 100644 index 0000000..2f1764c --- /dev/null +++ b/src/tqdm_publisher/_demo/_demo_command_line_interface.py @@ -0,0 +1,43 @@ +import os +import subprocess +import sys +from pathlib import Path + +from ._server import run_demo + +DEMO_BASE_FOLDER_PATH = Path(__file__).parent + +CLIENT_FILE_PATH = DEMO_BASE_FOLDER_PATH / "_client.html" +SERVER_FILE_PATH = DEMO_BASE_FOLDER_PATH / "_server.py" + + +def _command_line_interface(): + """A simple command line interface for running the demo for TQDM Publisher.""" + if len(sys.argv) <= 1: + print("No input provided. Please specify a command (e.g. 'demo').") + return + + command = sys.argv[1] + if "-" in command: + print( + f"No command provided, but flag {command} was received. " + "Please specify a command before the flag (e.g. 'demo')." + ) + return + + flags_list = sys.argv[2:] + if len(flags_list) > 0: + print(f"No flags are accepted at this time, but flags {flags_list} were received.") + return + + if command == "demo": + # For convenience - automatically pop-up a browser window on the locally hosted HTML page + if sys.platform == "win32": + os.system(f'start "" "{CLIENT_FILE_PATH}"') + else: + subprocess.run(["open", CLIENT_FILE_PATH]) + + run_demo() + + else: + print(f"{command} is an invalid command.") diff --git a/src/tqdm_publisher/_demo/_server.py b/src/tqdm_publisher/_demo/_server.py new file mode 100644 index 0000000..dbd18ab --- /dev/null +++ b/src/tqdm_publisher/_demo/_server.py @@ -0,0 +1,89 @@ +import asyncio +import json +import threading +import time +from typing import Dict, Any, List +from uuid import uuid4 + +import websockets + +from tqdm_publisher import TQDMPublisher + + +def start_progress_bar(*, client_id: str, progress_bar_id: str, client_callback: callable) -> None: + """ + Emulate running the specified number of tasks by sleeping the specified amount of time on each iteration. + + Defaults are chosen for a deterministic and regular update period of one second for a total time of one minute. + """ + all_task_durations_in_seconds = [1.0 for _ in range(60)] # One minute at one second per update + progress_bar = TQDMPublisher(iterable=all_task_durations_in_seconds) + + def run_function_on_progress_update(format_dict: dict) -> None: + """ + This is the injected callback that will be run on each update of the TQDM object. + + Its first and only positional argument must be the `format_dict` of the TQDM instance. Additional customization + on outside parameters must be achieved by defining those fields at an outer scope and defining this + server-specific callback inside the local scope. + + In this demo, we will execute the `client_callback` whose protocol is known only to the WebSocketHandler. + It has + """ + client_callback(client_id=client_id, progress_bar_id=progress_bar_id, format_dict=format_dict) + + progress_bar.subscribe(callback=run_function_on_progress_update) + + for task_duration in progress_bar: + time.sleep(task_duration) + + +class WebSocketHandler: + """Describe this class.""" + + def __init__(self) -> None: + """Initialize the mapping of client IDs to .""" + self.clients: Dict[str, Any] = dict() + + def function_to_run_on_progress_update(self, *, client_id: str, progress_bar_id: str, format_dict: dict) -> None: + """This is...""" + asyncio.run(self.send(client_id=client_id, data=dict(progress_bar_id=progress_bar_id, format_dict=format_dict))) + + async def send(self, client_id: str, data: dict) -> None: + """Send an arbitrary JSON object `data` to the client identifier by `client_id`.""" + await self.clients[client_id].send(json.dumps(obj=data)) + + async def handler(self, websocket) -> None: + """Describe what the handler does.""" + client_id = str(uuid4()) + self.clients[client_id] = websocket # Register client connection + + try: + async for message in websocket: + message_from_client = json.loads(message) + + if message_from_client["command"] == "start": + thread = threading.Thread( + target=start_progress_bar, + kwargs=dict( + client_id=client_id, + progress_bar_id=message_from_client["progress_bar_id"], + client_callback=self.function_to_run_on_progress_update, + ), + ) + thread.start() + finally: # This is called when the connection is closed + if client_id in self.clients: + del self.clients[client_id] + + +async def spawn_server() -> None: + """Spawn the server asynchronously.""" + handler = WebSocketHandler().handler + async with websockets.serve(handler, "", 8000): + await asyncio.Future() + + +def run_demo() -> None: + """Trigger the execution of the asynchronous spawn.""" + asyncio.run(spawn_server()) diff --git a/src/tqdm_publisher/demo/client.py b/src/tqdm_publisher/demo/client.py deleted file mode 100644 index 73bf883..0000000 --- a/src/tqdm_publisher/demo/client.py +++ /dev/null @@ -1,12 +0,0 @@ -import subprocess -from pathlib import Path - -client_path = Path(__file__).parent / "client.html" - - -def main(): - subprocess.run(["open", client_path]) - - -if __name__ == "__main__": - main() diff --git a/src/tqdm_publisher/demo/demo_command_line_interface.py b/src/tqdm_publisher/demo/demo_command_line_interface.py deleted file mode 100644 index 766d806..0000000 --- a/src/tqdm_publisher/demo/demo_command_line_interface.py +++ /dev/null @@ -1,41 +0,0 @@ -import subprocess -import sys -from pathlib import Path - -demo_base_path = Path(__file__).parent - -client_path = demo_base_path / "client.html" -server_path = demo_base_path / "server.py" - - -def main(): - if len(sys.argv) <= 1: - print("No command provided. Please specify a command (e.g. 'demo').") - return - - command = sys.argv[1] - - flags_list = sys.argv[2:] - - client_flag = "--client" in flags_list - server_flag = "--server" in flags_list - both_flags = "--server" in flags_list and "--client" in flags_list - - flags = dict( - client=not server_flag or both_flags, - server=not client_flag or both_flags, - ) - - if command == "demo": - if flags["client"]: - subprocess.run(["open", client_path]) - - if flags["server"]: - subprocess.run(["python", server_path]) - - else: - print(f"{command} is an invalid command.") - - -if __name__ == "__main__": - main() diff --git a/src/tqdm_publisher/demo/server.py b/src/tqdm_publisher/demo/server.py deleted file mode 100644 index abaa212..0000000 --- a/src/tqdm_publisher/demo/server.py +++ /dev/null @@ -1,68 +0,0 @@ -#!/usr/bin/env python - -import asyncio -import json -import random -import threading -import time -from typing import List -from uuid import uuid4 - -import websockets - -from tqdm_publisher import TQDMPublisher - - -def generate_task_durations(n=100) -> List[float]: - return [random.uniform(0, 1.0) for _ in range(n)] - - -def start_progress_bar(id, callback): - durations = generate_task_durations() - progress_bar = TQDMPublisher(durations) - progress_bar.subscribe(lambda info: callback(id, info)) - for duration in progress_bar: - time.sleep(duration) - - -class WebSocketHandler: - def __init__(self): - self.clients = {} - pass - - async def send(self, id, data): - await self.clients[id].send(json.dumps(data)) - - async def handler(self, websocket): - identifier = str(uuid4()) - self.clients[identifier] = websocket # Register client connection - - def on_progress(id, info): - - asyncio.run(self.send(identifier, dict(id=id, payload=info))) - - try: - async for message in websocket: - - info = json.loads(message) - - if info["command"] == "start": - thread = threading.Thread(target=start_progress_bar, args=[info["id"], on_progress]) - thread.start() - - finally: - del self.clients[identifier] # This is called when the connection is closed - - -async def spawn_server(): - handler = WebSocketHandler().handler - async with websockets.serve(handler, "", 8000): - await asyncio.Future() - - -def main(): - asyncio.run(spawn_server()) - - -if __name__ == "__main__": - main()