Skip to content

Commit

Permalink
change() is now thread safe using thread-local Config() class
Browse files Browse the repository at this point in the history
  • Loading branch information
mreiche committed Sep 5, 2023
1 parent 217b703 commit 6583657
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 23 deletions.
34 changes: 20 additions & 14 deletions paf/control.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import dataclasses
import threading
from contextlib import contextmanager
from dataclasses import dataclass
from time import sleep, time
from typing import Callable
from typing import Callable, Optional

from paf.common import Property
from paf.types import Consumer


@dataclass()
class Config:
class Config(threading.local):
retry_count: int = Property.env(Property.PAF_SEQUENCE_RETRY_COUNT)
wait_after_fail: float = Property.env(Property.PAF_SEQUENCE_WAIT_AFTER_FAIL)


__global_config = Config()
__config = Config()


def __get_global_config():
return __global_config
def get_config():
return __config


def __set_config(config: Config):
global __config
__config = config


class Sequence:
Expand Down Expand Up @@ -59,26 +65,26 @@ def change(
retry_count: int = None,
wait_after_fail: float = None,
):
global __global_config
config_backup = __global_config
config = dataclasses.replace(__global_config)
__global_config = config
config_backup = get_config()
scope_config = dataclasses.replace(config_backup)
__set_config(scope_config)

if retry_count is not None:
config.retry_count = retry_count
scope_config.retry_count = retry_count

if wait_after_fail is not None:
config.wait_after_fail = wait_after_fail
scope_config.wait_after_fail = wait_after_fail

try:
yield
finally:
__global_config = config_backup
__set_config(config_backup)


def retry(action: Callable, on_fail: Consumer[Exception] = None):
sequence = Sequence(retry_count=__global_config.retry_count, wait_after_fail=__global_config.wait_after_fail)
exception = None
config = get_config()
sequence = Sequence(retry_count=config.retry_count, wait_after_fail=config.wait_after_fail)
exception: Optional[Exception] = None

def _run():
nonlocal exception
Expand Down
53 changes: 44 additions & 9 deletions test/test_control.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,55 @@
import asyncio
import dataclasses
from time import sleep

from paf.control import change, __get_global_config, retry
import pytest

from paf.common import Property
from paf.control import change, get_config, retry

def test_config():

backup = dataclasses.replace(__get_global_config())
assert backup.retry_count != 99
assert backup.wait_after_fail != 99
def test_change():

backup_config = dataclasses.replace(get_config())
assert backup_config.retry_count != 99
assert backup_config.wait_after_fail != 99

with change(retry_count=99, wait_after_fail=99):
global_config = __get_global_config()
global_config = get_config()
assert global_config.wait_after_fail == 99
assert global_config.retry_count == 99
retry(lambda: None)

global_config = __get_global_config()
assert global_config.retry_count == backup.retry_count
assert global_config.wait_after_fail == backup.wait_after_fail
global_config = get_config()
assert global_config.retry_count == backup_config.retry_count
assert global_config.wait_after_fail == backup_config.wait_after_fail


def change_first():
global_config = get_config()
assert global_config.retry_count == Property.env(Property.PAF_SEQUENCE_RETRY_COUNT)

with change(retry_count=0):
sleep(0.3)
config = get_config()
assert config.retry_count == 0


def change_second():
global_config = get_config()
assert global_config.retry_count == Property.env(Property.PAF_SEQUENCE_RETRY_COUNT)

with change(retry_count=99):
sleep(0.1)
config = get_config()
assert config.retry_count == 99


@pytest.mark.asyncio
async def test_thread_safety():
tasks = [
asyncio.to_thread(change_first),
asyncio.to_thread(change_second)
]

await asyncio.gather(*tasks)

0 comments on commit 6583657

Please sign in to comment.