diff --git a/.github/workflows/assess-file-changes.yml b/.github/workflows/assess-file-changes.yml index 7070e92..067ccd7 100644 --- a/.github/workflows/assess-file-changes.yml +++ b/.github/workflows/assess-file-changes.yml @@ -50,4 +50,4 @@ jobs: else echo "Changelog not updated" fi - done \ No newline at end of file + done diff --git a/.github/workflows/auto-publish.yml b/.github/workflows/auto-publish.yml index a8dcc52..8732990 100644 --- a/.github/workflows/auto-publish.yml +++ b/.github/workflows/auto-publish.yml @@ -28,4 +28,4 @@ jobs: with: verbose: true user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} \ No newline at end of file + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml index e022d40..4b93c1e 100644 --- a/.github/workflows/codespell.yml +++ b/.github/workflows/codespell.yml @@ -13,4 +13,4 @@ jobs: - name: Checkout uses: actions/checkout@v3 - name: Codespell - uses: codespell-project/actions-codespell@v1 \ No newline at end of file + uses: codespell-project/actions-codespell@v1 diff --git a/.github/workflows/dailies.yml b/.github/workflows/dailies.yml index c615312..4224204 100644 --- a/.github/workflows/dailies.yml +++ b/.github/workflows/dailies.yml @@ -24,4 +24,4 @@ jobs: subject: TQDM Publisher Daily Test Failure to: garrett.flynn@catalystneuro.com,cody.baker@catalystneuro.com # add more with commas, no separation from: TQDM Publisher - body: "The daily workflow for TQDM Publisher failed: please check status at https://github.com/CatalystNeuro/tqdm_publisher/actions/workflows/dailies.yml" \ No newline at end of file + body: "The daily workflow for TQDM Publisher failed: please check status at https://github.com/CatalystNeuro/tqdm_publisher/actions/workflows/dailies.yml" diff --git a/.github/workflows/deploy-tests.yml b/.github/workflows/deploy-tests.yml index b1ba6db..bbb1fa7 100644 --- a/.github/workflows/deploy-tests.yml +++ b/.github/workflows/deploy-tests.yml @@ -44,4 +44,4 @@ jobs: uses: re-actors/alls-green@release/v1 with: allowed-skips: run-tests # Sometimes only docs are adjusted - jobs: ${{ toJSON(needs) }} \ No newline at end of file + jobs: ${{ toJSON(needs) }} diff --git a/.github/workflows/testing.yml b/.github/workflows/testing.yml index a231b8f..d28b4ea 100644 --- a/.github/workflows/testing.yml +++ b/.github/workflows/testing.yml @@ -29,7 +29,7 @@ jobs: - name: Run full pytest with coverage run: pytest -rsx -n auto --dist loadscope --cov=./ --cov-report xml:./codecov.xml - + - name: Upload full coverage to Codecov if: ${{ matrix.python-version == '3.9' && matrix.os == 'ubuntu-latest' }} uses: codecov/codecov-action@v3 @@ -38,4 +38,4 @@ jobs: file: ./codecov.xml flags: unittests name: codecov-umbrella - yml: ./codecov.yml \ No newline at end of file + yml: ./codecov.yml diff --git a/.gitignore b/.gitignore index 6065a51..34fbcc3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,4 @@ dist .coverage .coverage.* -codecov.xml \ No newline at end of file +codecov.xml diff --git a/README.md b/README.md index 1916a3f..08bb229 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ async def sleep_func(sleep_duration = 1): await asyncio.sleep(delay=sleep_duration) async def run_multiple_sleeps(sleep_durations): - + tasks = [] for sleep_duration in sleep_durations: diff --git a/demo/client.html b/demo/client.html index 4a075b8..67ed769 100644 --- a/demo/client.html +++ b/demo/client.html @@ -1,15 +1,15 @@ - + - + - + - + - + - + Concurrent Client Demo - + - +
@@ -69,9 +69,9 @@

tqdm_progress

- + - + - - \ No newline at end of file + + diff --git a/demo/client.py b/demo/client.py index 4c03083..73bf883 100644 --- a/demo/client.py +++ b/demo/client.py @@ -1,11 +1,12 @@ import subprocess - from pathlib import Path client_path = Path(__file__).parent / "client.html" + def main(): subprocess.run(["open", client_path]) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/demo/demo_cli.py b/demo/demo_cli.py index e95a1a7..766d806 100644 --- a/demo/demo_cli.py +++ b/demo/demo_cli.py @@ -5,14 +5,14 @@ demo_base_path = Path(__file__).parent client_path = demo_base_path / "client.html" -server_path = demo_base_path/ "server.py" +server_path = demo_base_path / "server.py" -def main(): - if (len(sys.argv) <= 1): +def main(): + if len(sys.argv) <= 1: print("No command provided. Please specify a command (e.g. 'demo').") return - + command = sys.argv[1] flags_list = sys.argv[2:] @@ -22,15 +22,14 @@ def main(): both_flags = "--server" in flags_list and "--client" in flags_list flags = dict( - client = not server_flag or both_flags, - server = not client_flag or both_flags, - + client=not server_flag or both_flags, + server=not client_flag or both_flags, ) - if (command == "demo"): + if command == "demo": if flags["client"]: subprocess.run(["open", client_path]) - + if flags["server"]: subprocess.run(["python", server_path]) @@ -39,4 +38,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/demo/server.py b/demo/server.py index e489574..02d68dd 100644 --- a/demo/server.py +++ b/demo/server.py @@ -1,16 +1,16 @@ #!/usr/bin/env python -import random import asyncio +import json +import random +import threading from typing import List - -from tqdm_publisher import TQDMPublisher +from uuid import uuid4 import websockets -import threading -from uuid import uuid4 -import json +from tqdm_publisher import TQDMPublisher + async def sleep_func(sleep_duration: float = 1) -> float: await asyncio.sleep(delay=sleep_duration) @@ -28,31 +28,28 @@ def create_tasks(): return tasks -class ProgressHandler(): - +class ProgressHandler: def __init__(self): self.started = False self.callbacks = [] self.callback_ids = [] - def subscribe(self, callback): + def subscribe(self, callback): self.callbacks.append(callback) - if (hasattr(self, 'progress_bar')): + if hasattr(self, "progress_bar"): self._subscribe(callback) - def unsubscribe(self, callback_id): self.progress_bar.unsubscribe(callback_id) - def clear(self): - self.callbacks = [] + def clear(self): + self.callbacks = [] self._clear() def _clear(self): - for callback_id in self.callback_ids: - self.unsubscribe(callback_id) + self.unsubscribe(callback_id) self.callback_ids = [] @@ -65,15 +62,12 @@ def stop(self): self.clear() self.thread.join() - def _subscribe(self, callback): callback_id = self.progress_bar.subscribe(callback) self.callback_ids.append(callback_id) - async def run(self): - - if (hasattr(self, 'progress_bar')): + if hasattr(self, "progress_bar"): print("Progress bar already running") return @@ -82,35 +76,32 @@ async def run(self): for callback in self.callbacks: self._subscribe(callback) - + for f in self.progress_bar: await f self._clear() del self.progress_bar - - def thread_loop(self): while self.started: asyncio.run(self.run()) - - def start(self): - if (self.started): + def start(self): + if self.started: return - + self.started = True - self.thread = threading.Thread(target=self.thread_loop) # Start infinite loop of progress bar thread + self.thread = threading.Thread(target=self.thread_loop) # Start infinite loop of progress bar thread self.thread.start() progress_handler = ProgressHandler() + class WebSocketHandler: def __init__(self): - self.clients = {} # Initialize with any state you need @@ -126,9 +117,9 @@ def handle_task_result(self, task): async def handler(self, websocket): id = str(uuid4()) - self.clients[id] = websocket # Register client connection + self.clients[id] = websocket # Register client connection - progress_handler.start() # Start if not started + progress_handler.start() # Start if not started def on_progress(info): task = asyncio.create_task(websocket.send(json.dumps(info))) @@ -143,18 +134,19 @@ def on_progress(info): finally: # This is called when the connection is closed del self.clients[id] - if (len(self.clients) == 0): + if len(self.clients) == 0: progress_handler.stop() - async def spawn_server(): handler = WebSocketHandler().handler async with websockets.serve(handler, "", 8000): await asyncio.Future() # run forever + def main(): asyncio.run(spawn_server()) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/pyproject.toml b/pyproject.toml index 4eff15e..b410cf1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,8 +37,8 @@ dependencies = [ [project.optional-dependencies] test = [ - "pytest", - "pytest-asyncio", + "pytest", + "pytest-asyncio", "pytest-cov" ] diff --git a/src/tqdm_publisher/__init__.py b/src/tqdm_publisher/__init__.py index 8cc0942..526c8ed 100644 --- a/src/tqdm_publisher/__init__.py +++ b/src/tqdm_publisher/__init__.py @@ -1 +1 @@ -from .publisher import TQDMPublisher \ No newline at end of file +from .publisher import TQDMPublisher diff --git a/src/tqdm_publisher/publisher.py b/src/tqdm_publisher/publisher.py index 94de1be..acd723d 100644 --- a/src/tqdm_publisher/publisher.py +++ b/src/tqdm_publisher/publisher.py @@ -1,19 +1,16 @@ +from typing import Union +from uuid import uuid4 from tqdm import tqdm as base_tqdm -from uuid import uuid4 -from typing import Union -# This class is a subclass of tqdm that allows for an arbitrary number of callbacks to be registered class TQDMPublisher(base_tqdm): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.callbacks = {} - - # Override the update method to call callbacks - def update(self, n: int=1) -> Union[bool, None]: + # Override the update method to run callbacks + def update(self, n: int = 1) -> Union[bool, None]: displayed = super().update(n) for callback in self.callbacks.values(): @@ -21,7 +18,6 @@ def update(self, n: int=1) -> Union[bool, None]: return displayed - def subscribe(self, callback: callable): """ Subscribe to updates from the progress bar. @@ -35,13 +31,13 @@ def subscribe(self, callback: callable): ---------- callback : callable A callable object (like a function) that will be called back by this object. - The callback should be able to be invoked with a single argument, the progress + The callback should be able to be invoked with a single argument, the progress bar's format_dict. Returns ------- str - A unique identifier for the callback. This ID is a UUID string and can be used + A unique identifier for the callback. This ID is a UUID string and can be used to reference the registered callback in future operations. Examples @@ -57,20 +53,20 @@ def subscribe(self, callback: callable): callback_id = str(uuid4()) self.callbacks[callback_id] = callback return callback_id - + def unsubscribe(self, callback_id: str): """ Unsubscribe a previously registered callback from the progress bar updates. This method removes the callback associated with the given unique ID from the internal - dictionary. It is used to deregister callbacks that were previously added via the - `subscribe` method. Once a callback is removed, it will no longer be called during + dictionary. It is used to deregister callbacks that were previously added via the + `subscribe` method. Once a callback is removed, it will no longer be called during the progress bar's update events. Parameters ---------- callback_id : str - The unique identifier of the callback to be unsubscribed. This is the same UUID string + The unique identifier of the callback to be unsubscribed. This is the same UUID string that was returned by the `subscribe` method when the callback was registered. Returns @@ -97,4 +93,4 @@ def unsubscribe(self, callback_id: str): return False del self.callbacks[callback_id] - return True \ No newline at end of file + return True diff --git a/src/tqdm_publisher/testing.py b/src/tqdm_publisher/testing.py index d237b06..d344c9e 100644 --- a/src/tqdm_publisher/testing.py +++ b/src/tqdm_publisher/testing.py @@ -1,6 +1,7 @@ import asyncio import random + async def sleep_func(sleep_duration: float = 1) -> float: await asyncio.sleep(delay=sleep_duration) diff --git a/tests/test_basic.py b/tests/test_basic.py index 3e9ea93..160ed09 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,7 +1,10 @@ -from tqdm_publisher import TQDMPublisher +import asyncio + import pytest + +from tqdm_publisher import TQDMPublisher from tqdm_publisher.testing import create_tasks -import asyncio + def test_initialization(): publisher = TQDMPublisher() @@ -15,20 +18,22 @@ async def test_subscription_and_callback_execution(): def test_callback(identifier, data): nonlocal n_callback_executions - + if identifier not in n_callback_executions: n_callback_executions[identifier] = 0 n_callback_executions[identifier] += 1 - assert 'n' in data and 'total' in data + assert "n" in data and "total" in data tasks = create_tasks() publisher = TQDMPublisher(asyncio.as_completed(tasks), total=len(tasks)) n_subscriptions = 10 for i in range(n_subscriptions): - callback_id = publisher.subscribe(lambda data, i=i: test_callback(i, data)) # Creates a new scoped i value for subscription + callback_id = publisher.subscribe( + lambda data, i=i: test_callback(i, data) + ) # Creates a new scoped i value for subscription assert callback_id in publisher.callbacks # Simulate an update to trigger the callback @@ -37,9 +42,10 @@ def test_callback(identifier, data): assert len(n_callback_executions) == n_subscriptions - for (identifier, n_executions) in n_callback_executions.items(): + for identifier, n_executions in n_callback_executions.items(): assert n_executions > 1 + def test_unsubscription(): def dummy_callback(data): pass