Skip to content

Commit

Permalink
✨ 新增select语句
Browse files Browse the repository at this point in the history
  • Loading branch information
snowykami committed Oct 12, 2024
1 parent 8e94cd9 commit 4975b19
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 4 deletions.
24 changes: 22 additions & 2 deletions magicoca/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,25 @@
from multiprocessing import set_start_method
from typing import Any, Callable, Generator

from magicoca.chan import Chan, T
from magicoca.chan import Chan, T, NoRecvValue

set_start_method("spawn", force=True)
__all__ = [
"Chan",
"select"
]

set_start_method("spawn", force=True)

def select(*args: Chan[T]) -> Generator[T, None, None]:
"""
Return a yield, when a value is received from one of the channels.
Args:
args: channels
"""
while True:
for ch in args:
if ch.is_closed:
continue

if not isinstance(value := ch.recv(0), NoRecvValue):
yield value
14 changes: 12 additions & 2 deletions magicoca/chan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

T = TypeVar("T")

class NoRecvValue(Exception):
"""
Exception raised when there is no value to receive.
"""
pass


class Chan(Generic[T]):
"""
Expand All @@ -31,6 +37,8 @@ def __init__(self):
"""
self.send_conn, self.recv_conn = Pipe()

self.is_closed = False

def send(self, value: T):
"""
Send a value to the channel.
Expand All @@ -39,7 +47,7 @@ def send(self, value: T):
"""
self.send_conn.send(value)

def recv(self, timeout: float | None = None) -> T | None:
def recv(self, timeout: float | None = None) -> T | None | NoRecvValue:
"""Receive a value from the channel.
If the timeout is None, it will block until a value is received.
If the timeout is a positive number, it will wait for the specified time, and if no value is received, it will return None.
Expand All @@ -56,15 +64,17 @@ def recv(self, timeout: float | None = None) -> T | None:
"""
if timeout is not None:
if not self.recv_conn.poll(timeout):
return None
return NoRecvValue("No value to receive.")
return self.recv_conn.recv()


def close(self):
"""
Close the channel. destructor
"""
self.send_conn.close()
self.recv_conn.close()
self.is_closed = True

def __iter__(self) -> "Chan[T]":
"""
Expand Down
5 changes: 5 additions & 0 deletions tests/test_chan.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from select import select

from magicoca.chan import Chan
from multiprocessing import Process, set_start_method
Expand All @@ -21,6 +22,7 @@ def p2f(chan: Chan[int]):
if recv_ans != list(range(10)) + [-1]:
raise ValueError("Chan Shift Test Failed")


class TestChan:

def test_test(self):
Expand Down Expand Up @@ -63,3 +65,6 @@ def test_connect(self):
p2.start()
p1.join()
p2.join()



42 changes: 42 additions & 0 deletions tests/test_select.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from multiprocessing import Process

from magicoca import Chan, select


def sp1(chan: Chan[int]):
for i in range(10):
chan << i


def sp2(chan: Chan[int]):
for i in range(10):
chan << i


def rp(chans: list[Chan[int]]):
rl = []
for t in select(*chans):
rl.append(t)
if len(rl) == 20:
break
print(rl)
assert len(rl) == 20


class TestSelect:
def test_select(self):
chan1 = Chan[int]()
chan2 = Chan[int]()

print("Test Chan Select")

p1 = Process(target=sp1, args=(chan1,))
p2 = Process(target=sp2, args=(chan2,))
p3 = Process(target=rp, args=([chan1, chan2],))
p3.start()
p1.start()
p2.start()

p1.join()
p2.join()
p3.join()

0 comments on commit 4975b19

Please sign in to comment.