diff --git a/paf/control.py b/paf/control.py index 60d4d7b..423c805 100644 --- a/paf/control.py +++ b/paf/control.py @@ -1,24 +1,30 @@ import dataclasses +import threading from contextlib import contextmanager from dataclasses import dataclass from time import sleep, time -from typing import Callable +from typing import Callable, Optional from paf.common import Property from paf.types import Consumer @dataclass() -class Config: +class Config(threading.local): retry_count: int = Property.env(Property.PAF_SEQUENCE_RETRY_COUNT) wait_after_fail: float = Property.env(Property.PAF_SEQUENCE_WAIT_AFTER_FAIL) -__global_config = Config() +__config = Config() -def __get_global_config(): - return __global_config +def get_config(): + return __config + + +def __set_config(config: Config): + global __config + __config = config class Sequence: @@ -59,26 +65,26 @@ def change( retry_count: int = None, wait_after_fail: float = None, ): - global __global_config - config_backup = __global_config - config = dataclasses.replace(__global_config) - __global_config = config + config_backup = get_config() + scope_config = dataclasses.replace(config_backup) + __set_config(scope_config) if retry_count is not None: - config.retry_count = retry_count + scope_config.retry_count = retry_count if wait_after_fail is not None: - config.wait_after_fail = wait_after_fail + scope_config.wait_after_fail = wait_after_fail try: yield finally: - __global_config = config_backup + __set_config(config_backup) def retry(action: Callable, on_fail: Consumer[Exception] = None): - sequence = Sequence(retry_count=__global_config.retry_count, wait_after_fail=__global_config.wait_after_fail) - exception = None + config = get_config() + sequence = Sequence(retry_count=config.retry_count, wait_after_fail=config.wait_after_fail) + exception: Optional[Exception] = None def _run(): nonlocal exception diff --git a/test/test_control.py b/test/test_control.py index e0741f3..7d6bb5a 100644 --- a/test/test_control.py +++ b/test/test_control.py @@ -1,20 +1,55 @@ +import asyncio import dataclasses +from time import sleep -from paf.control import change, __get_global_config, retry +import pytest +from paf.common import Property +from paf.control import change, get_config, retry -def test_config(): - backup = dataclasses.replace(__get_global_config()) - assert backup.retry_count != 99 - assert backup.wait_after_fail != 99 +def test_change(): + + backup_config = dataclasses.replace(get_config()) + assert backup_config.retry_count != 99 + assert backup_config.wait_after_fail != 99 with change(retry_count=99, wait_after_fail=99): - global_config = __get_global_config() + global_config = get_config() assert global_config.wait_after_fail == 99 assert global_config.retry_count == 99 retry(lambda: None) - global_config = __get_global_config() - assert global_config.retry_count == backup.retry_count - assert global_config.wait_after_fail == backup.wait_after_fail + global_config = get_config() + assert global_config.retry_count == backup_config.retry_count + assert global_config.wait_after_fail == backup_config.wait_after_fail + + +def change_first(): + global_config = get_config() + assert global_config.retry_count == Property.env(Property.PAF_SEQUENCE_RETRY_COUNT) + + with change(retry_count=0): + sleep(0.3) + config = get_config() + assert config.retry_count == 0 + + +def change_second(): + global_config = get_config() + assert global_config.retry_count == Property.env(Property.PAF_SEQUENCE_RETRY_COUNT) + + with change(retry_count=99): + sleep(0.1) + config = get_config() + assert config.retry_count == 99 + + +@pytest.mark.asyncio +async def test_thread_safety(): + tasks = [ + asyncio.to_thread(change_first), + asyncio.to_thread(change_second) + ] + + await asyncio.gather(*tasks)