Skip to content

Commit

Permalink
Make Python TCP server fully async
Browse files Browse the repository at this point in the history
  • Loading branch information
Vincent Wilms committed Jan 24, 2025
1 parent 89a8218 commit dc42f06
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 78 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
## v2.0.0-beta.50 - 2025-01-24

- Make Python TCP server fully async

## v2.0.0-beta.49 - 2025-01-23

- Make resource locator optional
Expand Down
2 changes: 1 addition & 1 deletion src/agent/python/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

async def main():
await agent_service.load_packages()
agent_service.accept_clients()
await agent_service.accept_clients()

main_task_reference: asyncio.Task[None] # prevents task to be garbage collected

Expand Down
78 changes: 39 additions & 39 deletions src/agent/python/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,17 @@
import uuid
from datetime import timedelta
from logging import Logger
from typing import Any, Awaitable, Coroutine, Optional
from typing import Any, Coroutine, Optional

from apollo3zehn_package_management import ExtensionHive, PackageService
from nexus_remoting._remoting import RemoteCommunicator


class TcpClientPair:
comm: Optional[socket.socket] = None
data: Optional[socket.socket] = None
comm_reader: Optional[asyncio.StreamReader] = None
comm_writer: Optional[asyncio.StreamWriter] = None
data_reader: Optional[asyncio.StreamReader] = None
data_writer: Optional[asyncio.StreamWriter] = None
remote_communicator: Optional[RemoteCommunicator] = None
watchdog_timer = time.time()
task: Optional[asyncio.Task] = None
Expand Down Expand Up @@ -46,18 +48,7 @@ async def load_packages(self):
package_reference_map = await self._package_service.get_all()
await self._extension_hive.load_packages(package_reference_map)

def accept_clients(self):

self._logger.info(
"Listening for JSON-RPC communication on %s:%d",
self._json_rpc_listen_address,
self._json_rpc_listen_port
)

tcp_listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
tcp_listener.bind((self._json_rpc_listen_address, self._json_rpc_listen_port))
tcp_listener.listen()
async def accept_clients(self):

async def detect_and_remove_inactive_clients():

Expand All @@ -71,7 +62,7 @@ async def detect_and_remove_inactive_clients():
watchdog_timer_elasped = timedelta(seconds=now - pair.watchdog_timer)

is_dead =\
(pair.comm is None or pair.data is None) and watchdog_timer_elasped >= self.CLIENT_TIMEOUT or \
(pair.comm_reader is None or pair.data_reader is None) and watchdog_timer_elasped >= self.CLIENT_TIMEOUT or \
pair.remote_communicator is not None and pair.remote_communicator.last_communication >= self.CLIENT_TIMEOUT

if is_dead:
Expand All @@ -87,34 +78,37 @@ async def detect_and_remove_inactive_clients():

self._create_task(detect_and_remove_inactive_clients())

async def accept_new_clients():

loop = asyncio.get_event_loop()

while True:
client, _ = await loop.sock_accept(tcp_listener)
self._create_task(self._handle_client(client))
self._logger.info(
"Listening for JSON-RPC communication on %s:%d",
self._json_rpc_listen_address,
self._json_rpc_listen_port
)

self._create_task(accept_new_clients())
server = await asyncio.start_server(
self._handle_client,
host=self._json_rpc_listen_address,
port=self._json_rpc_listen_port
)

async def _handle_client(self, client: socket.socket):
async with server:
await server.serve_forever()

stream_read_timeout = 1
client.settimeout(stream_read_timeout)
async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):

# Get connection id
buffer1 = client.recv(36)
buffer1 = await asyncio.wait_for(reader.readexactly(36), timeout=5)
id_string = buffer1.decode("utf-8")

# Get connection type
buffer2 = client.recv(4)
buffer2 = await asyncio.wait_for(reader.readexactly(4), timeout=5)
type_string = buffer2.decode("utf-8")

try:
id = uuid.UUID(id_string)

except:
client.close()
writer.close()
await writer.wait_closed()
return

self._logger.debug("Accept TCP client with connection ID %s and communication type %s", id_string, type_string)
Expand All @@ -127,30 +121,39 @@ async def _handle_client(self, client: socket.socket):
if id not in self._tcp_client_pairs:
self._tcp_client_pairs[id] = TcpClientPair()

self._tcp_client_pairs[id].comm = client
self._tcp_client_pairs[id].comm_reader = reader
self._tcp_client_pairs[id].comm_writer = writer

# We got a "data" tcp connection
elif type_string == "data":

if id not in self._tcp_client_pairs:
self._tcp_client_pairs[id] = TcpClientPair()

self._tcp_client_pairs[id].data = client
self._tcp_client_pairs[id].data_reader = reader
self._tcp_client_pairs[id].data_writer = writer

# Something went wrong, close the socket and return
else:
client.close()
writer.close()
await writer.wait_closed()
return

pair = self._tcp_client_pairs[id]

if pair.comm and pair.data and not pair.remote_communicator:
if pair.comm_reader and \
pair.comm_writer and \
pair.data_reader and \
pair.data_writer and \
not pair.remote_communicator:

self._logger.debug("Accept remoting client with connection ID %s", id)

pair.remote_communicator = RemoteCommunicator(
pair.comm,
pair.data,
pair.comm_reader,
pair.comm_writer,
pair.data_reader,
pair.data_writer,
get_data_source=lambda type: self._extension_hive.get_instance(type)
)

Expand All @@ -163,6 +166,3 @@ def _create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
task.add_done_callback(self._background_tasks.discard)

return task



70 changes: 33 additions & 37 deletions src/remoting/python/nexus_remoting/_remoting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
import json
import socket
import struct
import time
from datetime import datetime, timedelta
from threading import Lock
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, cast
from urllib.parse import urlparse

Expand All @@ -19,14 +18,10 @@
property_name_decoder=to_snake_case
)

_lock: Lock = Lock()

class _Logger(ILogger):

_comm_socket: socket.socket

def __init__(self, tcp_comm_socket: socket.socket):
self._comm_socket = tcp_comm_socket
def __init__(self, tcp_comm_socket: asyncio.StreamWriter):
self._comm_writer = tcp_comm_socket

def log(self, log_level: LogLevel, message: str):

Expand All @@ -36,7 +31,7 @@ def log(self, log_level: LogLevel, message: str):
"params": [log_level.name, message]
}

_send_to_server(notification, self._comm_socket)
asyncio.create_task(_send_to_server(notification, self._comm_writer))

class RemoteCommunicator:
"""A remote communicator."""
Expand All @@ -47,8 +42,10 @@ class RemoteCommunicator:

def __init__(
self,
comm_socket: socket.socket,
data_socket: socket.socket,
comm_reader: asyncio.StreamReader,
comm_writer: asyncio.StreamWriter,
data_reader: asyncio.StreamReader,
data_writer: asyncio.StreamWriter,
get_data_source: Callable[[str], IDataSource]
):
"""
Expand All @@ -60,8 +57,10 @@ def __init__(
get_data_source: A func to get a new data source instance by its type name.
"""

self._comm_socket = comm_socket
self._data_socket = data_socket
self._comm_reader = comm_reader
self._comm_writer = comm_writer
self._data_reader = data_reader
self._data_writer = data_writer
self._get_data_source = get_data_source

@property
Expand All @@ -80,11 +79,8 @@ async def run(self) -> Awaitable:
# https://www.jsonrpc.org/specification

# get request message
size = self._read_size(self._comm_socket)
json_request = self._comm_socket.recv(size, socket.MSG_WAITALL)

if len(json_request) == 0:
raise Exception("The connection has been closed.")
size = await self._read_size(self._comm_reader)
json_request = await asyncio.wait_for(self._comm_reader.readexactly(size), timeout=60)

request: Dict[str, Any] = json.loads(json_request)

Expand Down Expand Up @@ -124,12 +120,15 @@ async def run(self) -> Awaitable:
response["id"] = request["id"]

# send response
_send_to_server(response, self._comm_socket)
await _send_to_server(response, self._comm_writer)

# send data
if data is not None and status is not None:
self._data_socket.sendall(data)
self._data_socket.sendall(status)

self._data_writer.write(data)
self._data_writer.write(status)

await self._data_writer.drain()

async def _process_invocation(self, request: dict[str, Any]) \
-> Tuple[
Expand Down Expand Up @@ -170,7 +169,7 @@ async def _process_invocation(self, request: dict[str, Any]) \
request_configuration = raw_context["requestConfiguration"] \
if "requestConfiguration" in raw_context else None

self._logger = _Logger(self._comm_socket)
self._logger = _Logger(self._comm_writer)

context = DataSourceContext(
resource_locator,
Expand Down Expand Up @@ -279,13 +278,10 @@ async def _handle_read_data(self, resource_path: str, begin: datetime, end: date
"params": [resource_path, begin, end]
}

_send_to_server(read_data_request, self._comm_socket)
await _send_to_server(read_data_request, self._comm_writer)

size = self._read_size(self._data_socket)
data = self._data_socket.recv(size, socket.MSG_WAITALL)

if len(data) == 0:
raise Exception("The connection has been closed.")
size = await self._read_size(self._data_reader)
data = await asyncio.wait_for(self._data_reader.readexactly(size), timeout=600)

# 'cast' is required because of https://github.com/python/cpython/issues/126012
# see also https://github.com/nexus-main/nexus/issues/184
Expand All @@ -294,20 +290,20 @@ async def _handle_read_data(self, resource_path: str, begin: datetime, end: date
def _handle_report_progress(self, progress_value: float):
pass # not implemented

def _read_size(self, current_socket: socket.socket) -> int:
size_buffer = current_socket.recv(4, socket.MSG_WAITALL)

if len(size_buffer) == 0:
raise Exception("The connection has been closed.")
async def _read_size(self, reader: asyncio.StreamReader) -> int:

size_buffer = await asyncio.wait_for(reader.readexactly(4), timeout=60)
size = struct.unpack(">I", size_buffer)[0]

return size

def _send_to_server(message: Any, current_socket: socket.socket):
async def _send_to_server(message: Any, writer: asyncio.StreamWriter):

encoded = JsonEncoder.encode(message, _json_encoder_options)
json_response = json.dumps(encoded)
encoded_response = json_response.encode()

with _lock:
current_socket.sendall(struct.pack(">I", len(encoded_response)))
current_socket.sendall(encoded_response)
writer.write(struct.pack(">I", len(encoded_response)))
writer.write(encoded_response)

await writer.drain()
2 changes: 1 addition & 1 deletion version.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"version": "2.0.0",
"suffix": "beta.49"
"suffix": "beta.50"
}

0 comments on commit dc42f06

Please sign in to comment.