From 0ed35a7f70799f4979561a3f42a61b7eb218fc86 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 14 Oct 2024 08:51:44 +0100 Subject: [PATCH] add test for ki protection leaking accross local functions --- src/trio/_core/_tests/test_ki.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/trio/_core/_tests/test_ki.py b/src/trio/_core/_tests/test_ki.py index e4241fc762..271c563de2 100644 --- a/src/trio/_core/_tests/test_ki.py +++ b/src/trio/_core/_tests/test_ki.py @@ -4,7 +4,7 @@ import inspect import signal import threading -from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator +from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator, TypeVar import outcome import pytest @@ -515,3 +515,26 @@ async def inner() -> None: _core.run(inner) finally: threading._active[thread.ident] = original # type: ignore[attr-defined] + + +_T = TypeVar("_T") + + +def _identity(v: _T) -> _T: + return v + + +async def test_ki_does_not_leak_accross_different_calls_to_inner_functions() -> None: + assert not _core.currently_ki_protected() + + def factory(enabled: bool) -> Callable[[], bool]: + @_identity(_core.enable_ki_protection if enabled else _identity) + def decorated() -> bool: + return _core.currently_ki_protected() + + return decorated + + decorated_enabled = factory(True) + decorated_disabled = factory(False) + assert decorated_enabled() + assert not decorated_disabled()