From ec30d7b33d9c5994bebb81d8537acf0889db92ac Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Mon, 14 Oct 2024 08:53:51 +0100 Subject: [PATCH] add fix for ki protection leaking accross local functions --- src/trio/_core/_ki.py | 73 ++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 69 insertions(+), 4 deletions(-) diff --git a/src/trio/_core/_ki.py b/src/trio/_core/_ki.py index 15a62aada..bd86b5f26 100644 --- a/src/trio/_core/_ki.py +++ b/src/trio/_core/_ki.py @@ -4,7 +4,7 @@ import sys import types import weakref -from typing import TYPE_CHECKING, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar import attrs @@ -14,7 +14,7 @@ import types from collections.abc import Callable - from typing_extensions import TypeGuard + from typing_extensions import Self, TypeGuard # In ordinary single-threaded Python code, when you hit control-C, it raises # an exception and automatically does all the regular unwinding stuff. # @@ -77,10 +77,75 @@ # for any Python program that's written to catch and ignore # KeyboardInterrupt.) -_CODE_KI_PROTECTION_STATUS_WMAP: weakref.WeakKeyDictionary[ +_T = TypeVar("_T") + + +class _IdRef(weakref.ref[_T]): + slots = "_hash" + _hash: int + + def __new__(cls, ob: _T, callback: Callable[[Self], Any] | None = None, /) -> Self: + self: Self = weakref.ref.__new__(cls, ob, callback) + self._hash = object.__hash__(ob) + return self + + def __eq__(self, other: object) -> bool: + if self is other: + return True + + if not isinstance(other, _IdRef): + return NotImplemented + + my_obj = None + other_obj: Any = None + try: + my_obj = self() + other_obj = other() + return my_obj is not None and other_obj is not None and my_obj is other_obj + finally: + del my_obj, other_obj + + def __hash__(self) -> int: + return self._hash + + +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +# see also: https://github.com/python/cpython/issues/88306 +class WeakKeyIdentityDictionary(Generic[_KT, _VT]): + def __init__(self) -> None: + self._data: dict[_IdRef[_KT], _VT] = {} + + def remove( + k: _IdRef[_KT], + selfref: weakref.ref[ + WeakKeyIdentityDictionary[_KT, _VT] + ] = weakref.ref( # noqa: B008 # function-call-in-default-argument + self, + ), + ) -> None: + self = selfref() + if self is not None: + try: # noqa: SIM105 # supressible-exception + del self._data[k] + except KeyError: + pass + + self._remove = remove + + def __getitem__(self, k: _KT) -> _VT: + return self._data[_IdRef(k)] + + def __setitem__(self, k: _KT, v: _VT) -> None: + self._data[_IdRef(k, self._remove)] = v + + +_CODE_KI_PROTECTION_STATUS_WMAP: WeakKeyIdentityDictionary[ types.CodeType, bool, -] = weakref.WeakKeyDictionary() +] = WeakKeyIdentityDictionary() # This is to support the async_generator package necessary for aclosing on <3.10