Skip to content

Commit

Permalink
Add cleanup_socket param on create_unix_server()
Browse files Browse the repository at this point in the history
This is derived from python/cpython#111483 but available on
all Python versions with uvloop, only that it's only enabled
by default for Python 3.13 and above to be consistent with
CPython behavior.
  • Loading branch information
fantix committed Aug 28, 2024
1 parent 0019ff9 commit d6114d2
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 18 deletions.
40 changes: 24 additions & 16 deletions tests/test_unix.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,18 @@ async def start_server():

self.assertFalse(srv.is_serving())

# asyncio doesn't cleanup the sock file
self.assertTrue(os.path.exists(sock_name))
if sys.version_info < (3, 13):
# asyncio doesn't cleanup the sock file under Python 3.13
self.assertTrue(os.path.exists(sock_name))
else:
self.assertFalse(os.path.exists(sock_name))

async def start_server_sock(start_server, is_unix_api=True):
# is_unix_api indicates whether `start_server` is calling
# `loop.create_unix_server()` or `loop.create_server()`,
# because asyncio `loop.create_server()` doesn't cleanup
# the socket file even if it's a UNIX socket.

async def start_server_sock(start_server):
nonlocal CNT
CNT = 0

Expand Down Expand Up @@ -140,8 +148,11 @@ async def start_server_sock(start_server):

self.assertFalse(srv.is_serving())

# asyncio doesn't cleanup the sock file
self.assertTrue(os.path.exists(sock_name))
if sys.version_info < (3, 13) or not is_unix_api:
# asyncio doesn't cleanup the sock file under Python 3.13
self.assertTrue(os.path.exists(sock_name))
else:
self.assertFalse(os.path.exists(sock_name))

with self.subTest(func='start_unix_server(host, port)'):
self.loop.run_until_complete(start_server())
Expand All @@ -160,7 +171,7 @@ async def start_server_sock(start_server):
lambda sock: asyncio.start_server(
handle_client,
None, None,
sock=sock)))
sock=sock), is_unix_api=False))
self.assertEqual(CNT, TOTAL_CNT)

def test_create_unix_server_2(self):
Expand Down Expand Up @@ -455,16 +466,13 @@ def test_create_unix_server_path_stream_bittype(self):
socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
with tempfile.NamedTemporaryFile() as file:
fn = file.name
try:
with sock:
sock.bind(fn)
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=sock)
srv = self.loop.run_until_complete(coro)
srv.close()
self.loop.run_until_complete(srv.wait_closed())
finally:
os.unlink(fn)
with sock:
sock.bind(fn)
coro = self.loop.create_unix_server(lambda: None, path=None,
sock=sock, cleanup_socket=True)
srv = self.loop.run_until_complete(coro)
srv.close()
self.loop.run_until_complete(srv.wait_closed())

@unittest.skipUnless(sys.platform.startswith('linux'), 'requires epoll')
def test_epollhup(self):
Expand Down
21 changes: 21 additions & 0 deletions uvloop/handles/pipe.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,27 @@ cdef class UnixServer(UVStreamServer):
context)
return <UVStream>tr

cdef _close(self):
sock = self._fileobj
if sock is not None and sock in self._loop._unix_server_sockets:
path = sock.getsockname()
else:
path = None

UVStreamServer._close(self)

if path is not None:
prev_ino = self._loop._unix_server_sockets[sock]
del self._loop._unix_server_sockets[sock]
try:
if os_stat(path).st_ino == prev_ino:
os_unlink(path)
except FileNotFoundError:
pass
except OSError as err:
aio_logger.error('Unable to clean up listening UNIX socket '
'%r: %r', path, err)


@cython.no_gc_clear
cdef class UnixTransport(UVStream):
Expand Down
1 change: 1 addition & 0 deletions uvloop/includes/stdlib.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ cdef os_pipe = os.pipe
cdef os_read = os.read
cdef os_remove = os.remove
cdef os_stat = os.stat
cdef os_unlink = os.unlink
cdef os_fspath = os.fspath

cdef stat_S_ISSOCK = stat.S_ISSOCK
Expand Down
1 change: 1 addition & 0 deletions uvloop/loop.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ cdef class Loop:
set _processes
dict _fd_to_reader_fileobj
dict _fd_to_writer_fileobj
dict _unix_server_sockets

set _signals
dict _signal_handlers
Expand Down
24 changes: 22 additions & 2 deletions uvloop/loop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ include "errors.pyx"
cdef:
int PY39 = PY_VERSION_HEX >= 0x03090000
int PY311 = PY_VERSION_HEX >= 0x030b0000
int PY313 = PY_VERSION_HEX >= 0x030d0000
uint64_t MAX_SLEEP = 3600 * 24 * 365 * 100


Expand Down Expand Up @@ -155,6 +156,8 @@ cdef class Loop:
self._fd_to_reader_fileobj = {}
self._fd_to_writer_fileobj = {}

self._unix_server_sockets = {}

self._timers = set()
self._polls = {}

Expand Down Expand Up @@ -1704,7 +1707,10 @@ cdef class Loop:
'host/port and sock can not be specified at the same time')
return await self.create_unix_server(
protocol_factory, sock=sock, backlog=backlog, ssl=ssl,
start_serving=start_serving)
start_serving=start_serving,
# asyncio won't clean up socket file using create_server() API
cleanup_socket=False,
)

server = Server(self)

Expand Down Expand Up @@ -2089,7 +2095,7 @@ cdef class Loop:
*, backlog=100, sock=None, ssl=None,
ssl_handshake_timeout=None,
ssl_shutdown_timeout=None,
start_serving=True):
start_serving=True, cleanup_socket=PY313):
"""A coroutine which creates a UNIX Domain Socket server.
The return value is a Server object, which can be used to stop
Expand All @@ -2114,6 +2120,11 @@ cdef class Loop:
ssl_shutdown_timeout is the time in seconds that an SSL server
will wait for completion of the SSL shutdown before aborting the
connection. Default is 30s.
If *cleanup_socket* is true then the Unix socket will automatically
be removed from the filesystem when the server is closed, unless the
socket has been replaced after the server has been created.
This defaults to True on Python 3.13 and above, or False otherwise.
"""
cdef:
UnixServer pipe
Expand Down Expand Up @@ -2191,6 +2202,15 @@ cdef class Loop:
# we want Python socket object to notice that.
sock.setblocking(False)

if cleanup_socket:
path = sock.getsockname()
# Check for abstract socket. `str` and `bytes` paths are supported.
if path[0] not in (0, '\x00'):
try:
self._unix_server_sockets[sock] = os_stat(path).st_ino
except FileNotFoundError:
pass

pipe = UnixServer.new(
self, protocol_factory, server, backlog,
ssl, ssl_handshake_timeout, ssl_shutdown_timeout)
Expand Down

0 comments on commit d6114d2

Please sign in to comment.