Skip to content

Commit

Permalink
implement Event directly with wait_task_rescheduled rather than Parki…
Browse files Browse the repository at this point in the history
…ngLot

raised in python-trio#1944
  • Loading branch information
belm0 committed Apr 6, 2021
1 parent 52a210f commit dc78d64
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions trio/_sync.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import math

import attr
import outcome

import trio

from . import _core
from ._core import enable_ki_protection, ParkingLot
from ._deprecate import deprecated
from ._util import Final


@attr.s(repr=False, eq=False, hash=False)
@attr.s(frozen=True)
class _EventStatistics:
tasks_waiting = attr.ib()


@attr.s(repr=False, eq=False, hash=False, slots=True)
class Event(metaclass=Final):
"""A waitable boolean value useful for inter-task synchronization,
inspired by :class:`threading.Event`.
Expand All @@ -37,7 +41,7 @@ class Event(metaclass=Final):
"""

_lot = attr.ib(factory=ParkingLot, init=False)
_tasks = attr.ib(factory=set, init=False)
_flag = attr.ib(default=False, init=False)

def is_set(self):
Expand All @@ -47,8 +51,10 @@ def is_set(self):
@enable_ki_protection
def set(self):
"""Set the internal flag value to True, and wake any waiting tasks."""
self._flag = True
self._lot.unpark_all()
if not self._flag:
self._flag = True
for task in self._tasks:
_core.reschedule(task)

async def wait(self):
"""Block until the internal flag value becomes True.
Expand All @@ -59,7 +65,15 @@ async def wait(self):
if self._flag:
await trio.lowlevel.checkpoint()
else:
await self._lot.park()
task = _core.current_task()
self._tasks.add(task)
task.custom_sleep_data = self

def abort_fn(_):
task.custom_sleep_data._tasks.remove(task)
return _core.Abort.SUCCEEDED

await _core.wait_task_rescheduled(abort_fn)

def statistics(self):
"""Return an object containing debugging information.
Expand All @@ -70,7 +84,7 @@ def statistics(self):
:meth:`wait` method.
"""
return self._lot.statistics()
return _EventStatistics(tasks_waiting=len(self._tasks))


def async_cm(cls):
Expand Down

0 comments on commit dc78d64

Please sign in to comment.