diff --git a/CHANGELOG.md b/CHANGELOG.md index a31d450..35b6691 100755 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## v2.0.0-beta.50 - 2025-01-24 + +- Make Python TCP server fully async + ## v2.0.0-beta.49 - 2025-01-23 - Make resource locator optional diff --git a/src/agent/python/main.py b/src/agent/python/main.py index 096b890..8b552bf 100644 --- a/src/agent/python/main.py +++ b/src/agent/python/main.py @@ -28,7 +28,7 @@ async def main(): await agent_service.load_packages() - agent_service.accept_clients() + await agent_service.accept_clients() main_task_reference: asyncio.Task[None] # prevents task to be garbage collected diff --git a/src/agent/python/services.py b/src/agent/python/services.py index 2d0cb7c..c64c93f 100644 --- a/src/agent/python/services.py +++ b/src/agent/python/services.py @@ -4,15 +4,17 @@ import uuid from datetime import timedelta from logging import Logger -from typing import Any, Awaitable, Coroutine, Optional +from typing import Any, Coroutine, Optional from apollo3zehn_package_management import ExtensionHive, PackageService from nexus_remoting._remoting import RemoteCommunicator class TcpClientPair: - comm: Optional[socket.socket] = None - data: Optional[socket.socket] = None + comm_reader: Optional[asyncio.StreamReader] = None + comm_writer: Optional[asyncio.StreamWriter] = None + data_reader: Optional[asyncio.StreamReader] = None + data_writer: Optional[asyncio.StreamWriter] = None remote_communicator: Optional[RemoteCommunicator] = None watchdog_timer = time.time() task: Optional[asyncio.Task] = None @@ -46,18 +48,7 @@ async def load_packages(self): package_reference_map = await self._package_service.get_all() await self._extension_hive.load_packages(package_reference_map) - def accept_clients(self): - - self._logger.info( - "Listening for JSON-RPC communication on %s:%d", - self._json_rpc_listen_address, - self._json_rpc_listen_port - ) - - tcp_listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - tcp_listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - tcp_listener.bind((self._json_rpc_listen_address, self._json_rpc_listen_port)) - tcp_listener.listen() + async def accept_clients(self): async def detect_and_remove_inactive_clients(): @@ -71,7 +62,7 @@ async def detect_and_remove_inactive_clients(): watchdog_timer_elasped = timedelta(seconds=now - pair.watchdog_timer) is_dead =\ - (pair.comm is None or pair.data is None) and watchdog_timer_elasped >= self.CLIENT_TIMEOUT or \ + (pair.comm_reader is None or pair.data_reader is None) and watchdog_timer_elasped >= self.CLIENT_TIMEOUT or \ pair.remote_communicator is not None and pair.remote_communicator.last_communication >= self.CLIENT_TIMEOUT if is_dead: @@ -87,34 +78,37 @@ async def detect_and_remove_inactive_clients(): self._create_task(detect_and_remove_inactive_clients()) - async def accept_new_clients(): - - loop = asyncio.get_event_loop() - - while True: - client, _ = await loop.sock_accept(tcp_listener) - self._create_task(self._handle_client(client)) + self._logger.info( + "Listening for JSON-RPC communication on %s:%d", + self._json_rpc_listen_address, + self._json_rpc_listen_port + ) - self._create_task(accept_new_clients()) + server = await asyncio.start_server( + self._handle_client, + host=self._json_rpc_listen_address, + port=self._json_rpc_listen_port + ) - async def _handle_client(self, client: socket.socket): + async with server: + await server.serve_forever() - stream_read_timeout = 1 - client.settimeout(stream_read_timeout) + async def _handle_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): # Get connection id - buffer1 = client.recv(36) + buffer1 = await asyncio.wait_for(reader.readexactly(36), timeout=5) id_string = buffer1.decode("utf-8") # Get connection type - buffer2 = client.recv(4) + buffer2 = await asyncio.wait_for(reader.readexactly(4), timeout=5) type_string = buffer2.decode("utf-8") try: id = uuid.UUID(id_string) except: - client.close() + writer.close() + await writer.wait_closed() return self._logger.debug("Accept TCP client with connection ID %s and communication type %s", id_string, type_string) @@ -127,7 +121,8 @@ async def _handle_client(self, client: socket.socket): if id not in self._tcp_client_pairs: self._tcp_client_pairs[id] = TcpClientPair() - self._tcp_client_pairs[id].comm = client + self._tcp_client_pairs[id].comm_reader = reader + self._tcp_client_pairs[id].comm_writer = writer # We got a "data" tcp connection elif type_string == "data": @@ -135,22 +130,30 @@ async def _handle_client(self, client: socket.socket): if id not in self._tcp_client_pairs: self._tcp_client_pairs[id] = TcpClientPair() - self._tcp_client_pairs[id].data = client + self._tcp_client_pairs[id].data_reader = reader + self._tcp_client_pairs[id].data_writer = writer # Something went wrong, close the socket and return else: - client.close() + writer.close() + await writer.wait_closed() return pair = self._tcp_client_pairs[id] - if pair.comm and pair.data and not pair.remote_communicator: + if pair.comm_reader and \ + pair.comm_writer and \ + pair.data_reader and \ + pair.data_writer and \ + not pair.remote_communicator: self._logger.debug("Accept remoting client with connection ID %s", id) pair.remote_communicator = RemoteCommunicator( - pair.comm, - pair.data, + pair.comm_reader, + pair.comm_writer, + pair.data_reader, + pair.data_writer, get_data_source=lambda type: self._extension_hive.get_instance(type) ) @@ -163,6 +166,3 @@ def _create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task: task.add_done_callback(self._background_tasks.discard) return task - - - diff --git a/src/remoting/python/nexus_remoting/_remoting.py b/src/remoting/python/nexus_remoting/_remoting.py index 554ba99..c6c896b 100644 --- a/src/remoting/python/nexus_remoting/_remoting.py +++ b/src/remoting/python/nexus_remoting/_remoting.py @@ -1,9 +1,8 @@ +import asyncio import json -import socket import struct import time from datetime import datetime, timedelta -from threading import Lock from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, cast from urllib.parse import urlparse @@ -19,14 +18,10 @@ property_name_decoder=to_snake_case ) -_lock: Lock = Lock() - class _Logger(ILogger): - _comm_socket: socket.socket - - def __init__(self, tcp_comm_socket: socket.socket): - self._comm_socket = tcp_comm_socket + def __init__(self, tcp_comm_socket: asyncio.StreamWriter): + self._comm_writer = tcp_comm_socket def log(self, log_level: LogLevel, message: str): @@ -36,7 +31,7 @@ def log(self, log_level: LogLevel, message: str): "params": [log_level.name, message] } - _send_to_server(notification, self._comm_socket) + asyncio.create_task(_send_to_server(notification, self._comm_writer)) class RemoteCommunicator: """A remote communicator.""" @@ -47,8 +42,10 @@ class RemoteCommunicator: def __init__( self, - comm_socket: socket.socket, - data_socket: socket.socket, + comm_reader: asyncio.StreamReader, + comm_writer: asyncio.StreamWriter, + data_reader: asyncio.StreamReader, + data_writer: asyncio.StreamWriter, get_data_source: Callable[[str], IDataSource] ): """ @@ -60,8 +57,10 @@ def __init__( get_data_source: A func to get a new data source instance by its type name. """ - self._comm_socket = comm_socket - self._data_socket = data_socket + self._comm_reader = comm_reader + self._comm_writer = comm_writer + self._data_reader = data_reader + self._data_writer = data_writer self._get_data_source = get_data_source @property @@ -80,11 +79,8 @@ async def run(self) -> Awaitable: # https://www.jsonrpc.org/specification # get request message - size = self._read_size(self._comm_socket) - json_request = self._comm_socket.recv(size, socket.MSG_WAITALL) - - if len(json_request) == 0: - raise Exception("The connection has been closed.") + size = await self._read_size(self._comm_reader) + json_request = await asyncio.wait_for(self._comm_reader.readexactly(size), timeout=60) request: Dict[str, Any] = json.loads(json_request) @@ -124,12 +120,15 @@ async def run(self) -> Awaitable: response["id"] = request["id"] # send response - _send_to_server(response, self._comm_socket) + await _send_to_server(response, self._comm_writer) # send data if data is not None and status is not None: - self._data_socket.sendall(data) - self._data_socket.sendall(status) + + self._data_writer.write(data) + self._data_writer.write(status) + + await self._data_writer.drain() async def _process_invocation(self, request: dict[str, Any]) \ -> Tuple[ @@ -170,7 +169,7 @@ async def _process_invocation(self, request: dict[str, Any]) \ request_configuration = raw_context["requestConfiguration"] \ if "requestConfiguration" in raw_context else None - self._logger = _Logger(self._comm_socket) + self._logger = _Logger(self._comm_writer) context = DataSourceContext( resource_locator, @@ -279,13 +278,10 @@ async def _handle_read_data(self, resource_path: str, begin: datetime, end: date "params": [resource_path, begin, end] } - _send_to_server(read_data_request, self._comm_socket) + await _send_to_server(read_data_request, self._comm_writer) - size = self._read_size(self._data_socket) - data = self._data_socket.recv(size, socket.MSG_WAITALL) - - if len(data) == 0: - raise Exception("The connection has been closed.") + size = await self._read_size(self._data_reader) + data = await asyncio.wait_for(self._data_reader.readexactly(size), timeout=600) # 'cast' is required because of https://github.com/python/cpython/issues/126012 # see also https://github.com/nexus-main/nexus/issues/184 @@ -294,20 +290,20 @@ async def _handle_read_data(self, resource_path: str, begin: datetime, end: date def _handle_report_progress(self, progress_value: float): pass # not implemented - def _read_size(self, current_socket: socket.socket) -> int: - size_buffer = current_socket.recv(4, socket.MSG_WAITALL) - - if len(size_buffer) == 0: - raise Exception("The connection has been closed.") + async def _read_size(self, reader: asyncio.StreamReader) -> int: + size_buffer = await asyncio.wait_for(reader.readexactly(4), timeout=60) size = struct.unpack(">I", size_buffer)[0] + return size -def _send_to_server(message: Any, current_socket: socket.socket): +async def _send_to_server(message: Any, writer: asyncio.StreamWriter): + encoded = JsonEncoder.encode(message, _json_encoder_options) json_response = json.dumps(encoded) encoded_response = json_response.encode() - with _lock: - current_socket.sendall(struct.pack(">I", len(encoded_response))) - current_socket.sendall(encoded_response) + writer.write(struct.pack(">I", len(encoded_response))) + writer.write(encoded_response) + + await writer.drain() diff --git a/version.json b/version.json index 929a5c5..61d5a39 100644 --- a/version.json +++ b/version.json @@ -1,4 +1,4 @@ { "version": "2.0.0", - "suffix": "beta.49" + "suffix": "beta.50" } \ No newline at end of file