From 69d657ed60053623a8193d4cb0194ec2c26492fd Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 29 Jan 2025 15:08:18 -0500 Subject: [PATCH] PYTHON-4864 - Create async version of SpecRunnerThread --- test/asynchronous/unified_format.py | 23 ++++-- test/asynchronous/utils_spec_runner.py | 109 +++++++++++++++++-------- test/unified_format.py | 9 +- test/utils_spec_runner.py | 109 +++++++++++++++++-------- tools/synchro.py | 2 + 5 files changed, 177 insertions(+), 75 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 52d964eb3e..6963945b46 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -35,6 +35,7 @@ client_knobs, unittest, ) +from test.asynchronous.utils_spec_runner import SpecRunnerTask from test.unified_format_shared import ( KMS_TLS_OPTS, PLACEHOLDER_MAP, @@ -58,7 +59,6 @@ snake_to_camel, wait_until, ) -from test.utils_spec_runner import SpecRunnerThread from test.version import Version from typing import Any, Dict, List, Mapping, Optional @@ -382,8 +382,8 @@ async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None: return elif entity_type == "thread": name = spec["id"] - thread = SpecRunnerThread(name) - thread.start() + thread = SpecRunnerTask(name) + await thread.start() self[name] = thread return @@ -1177,16 +1177,23 @@ def primary_changed() -> bool: wait_until(primary_changed, "change primary", timeout=timeout) - def _testOperation_runOnThread(self, spec): + async def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + if _IS_SYNC: + await thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + else: + + async def op(): + await self.run_entity_operation(spec["operation"]) + + await thread.schedule(op) - def _testOperation_waitForThread(self, spec): + async def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.stop() - thread.join(10) + await thread.stop() + await thread.join(10) if thread.exc: raise thread.exc self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"])) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index b79e5258b5..e59ecd9b94 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -54,39 +54,82 @@ _IS_SYNC = False - -class SpecRunnerThread(threading.Thread): - def __init__(self, name): - super().__init__() - self.name = name - self.exc = None - self.daemon = True - self.cond = threading.Condition() - self.ops = [] - self.stopped = False - - def schedule(self, work): - self.ops.append(work) - with self.cond: - self.cond.notify() - - def stop(self): - self.stopped = True - with self.cond: - self.cond.notify() - - def run(self): - while not self.stopped or self.ops: - if not self.ops: - with self.cond: - self.cond.wait(10) - if self.ops: - try: - work = self.ops.pop(0) - work() - except Exception as exc: - self.exc = exc - self.stop() +if _IS_SYNC: + + class SpecRunnerThread(threading.Thread): + def __init__(self, name): + super().__init__() + self.name = name + self.exc = None + self.daemon = True + self.cond = threading.Condition() + self.ops = [] + self.stopped = False + + def schedule(self, work): + self.ops.append(work) + with self.cond: + self.cond.notify() + + def stop(self): + self.stopped = True + with self.cond: + self.cond.notify() + + def run(self): + while not self.stopped or self.ops: + if not self.ops: + with self.cond: + self.cond.wait(10) + if self.ops: + try: + work = self.ops.pop(0) + work() + except Exception as exc: + self.exc = exc + self.stop() +else: + + class SpecRunnerTask: + def __init__(self, name): + self.name = name + self.exc = None + self.cond = asyncio.Condition() + self.ops = [] + self.stopped = False + self.task = None + + async def schedule(self, work): + self.ops.append(work) + async with self.cond: + self.cond.notify() + + async def stop(self): + self.stopped = True + async with self.cond: + self.cond.notify() + + async def start(self): + self.task = asyncio.create_task(self.run(), name=self.name) + + async def join(self, timeout: int = 0): + await asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + async def run(self): + while not self.stopped or self.ops: + if not self.ops: + async with self.cond: + await asyncio.wait_for(self.cond.wait(), timeout=10) + if self.ops: + try: + work = self.ops.pop(0) + await work() + except Exception as exc: + self.exc = exc + await self.stop() class AsyncSpecTestCreator: diff --git a/test/unified_format.py b/test/unified_format.py index 372eb8abba..28369a5e87 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -1167,7 +1167,14 @@ def primary_changed() -> bool: def _testOperation_runOnThread(self, spec): """Run the 'runOnThread' operation.""" thread = self.entity_map[spec["thread"]] - thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + if _IS_SYNC: + thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + else: + + def op(): + self.run_entity_operation(spec["operation"]) + + thread.schedule(op) def _testOperation_waitForThread(self, spec): """Run the 'waitForThread' operation.""" diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 4508502cd0..4b24c5c2e8 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -54,39 +54,82 @@ _IS_SYNC = True - -class SpecRunnerThread(threading.Thread): - def __init__(self, name): - super().__init__() - self.name = name - self.exc = None - self.daemon = True - self.cond = threading.Condition() - self.ops = [] - self.stopped = False - - def schedule(self, work): - self.ops.append(work) - with self.cond: - self.cond.notify() - - def stop(self): - self.stopped = True - with self.cond: - self.cond.notify() - - def run(self): - while not self.stopped or self.ops: - if not self.ops: - with self.cond: - self.cond.wait(10) - if self.ops: - try: - work = self.ops.pop(0) - work() - except Exception as exc: - self.exc = exc - self.stop() +if _IS_SYNC: + + class SpecRunnerThread(threading.Thread): + def __init__(self, name): + super().__init__() + self.name = name + self.exc = None + self.daemon = True + self.cond = threading.Condition() + self.ops = [] + self.stopped = False + + def schedule(self, work): + self.ops.append(work) + with self.cond: + self.cond.notify() + + def stop(self): + self.stopped = True + with self.cond: + self.cond.notify() + + def run(self): + while not self.stopped or self.ops: + if not self.ops: + with self.cond: + self.cond.wait(10) + if self.ops: + try: + work = self.ops.pop(0) + work() + except Exception as exc: + self.exc = exc + self.stop() +else: + + class SpecRunnerThread: + def __init__(self, name): + self.name = name + self.exc = None + self.cond = asyncio.Condition() + self.ops = [] + self.stopped = False + self.task = None + + def schedule(self, work): + self.ops.append(work) + with self.cond: + self.cond.notify() + + def stop(self): + self.stopped = True + with self.cond: + self.cond.notify() + + def start(self): + self.task = asyncio.create_task(self.run(), name=self.name) + + def join(self, timeout: int = 0): + asyncio.wait([self.task], timeout=timeout) + + def is_alive(self): + return not self.stopped + + def run(self): + while not self.stopped or self.ops: + if not self.ops: + with self.cond: + asyncio.wait_for(self.cond.wait(), timeout=10) + if self.ops: + try: + work = self.ops.pop(0) + work() + except Exception as exc: + self.exc = exc + self.stop() class SpecTestCreator: diff --git a/tools/synchro.py b/tools/synchro.py index 897e5e8018..6444a06922 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -119,6 +119,8 @@ "_async_create_lock": "_create_lock", "_async_create_condition": "_create_condition", "_async_cond_wait": "_cond_wait", + "AsyncDummyMonitor": "DummyMonitor", + "SpecRunnerTask": "SpecRunnerThread", } docstring_replacements: dict[tuple[str, str], str] = {