diff --git a/pygls/io_.py b/pygls/io_.py new file mode 100644 index 00000000..85eba53a --- /dev/null +++ b/pygls/io_.py @@ -0,0 +1,193 @@ +############################################################################ +# Copyright(c) Open Law Library. All rights reserved. # +# See ThirdPartyNotices.txt in the project root for additional notices. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +from __future__ import annotations + +import asyncio +import json +import logging +import re +import typing + +from pygls.exceptions import JsonRpcException + +if typing.TYPE_CHECKING: + import logging + import threading + from collections.abc import Awaitable + from concurrent.futures import ThreadPoolExecutor + from typing import Any, BinaryIO, Callable, Protocol + + from pygls.protocol import JsonRPCProtocol + + class Reader(Protocol): + """An synchronous reader.""" + + def readline(self) -> bytes: ... + + def read(self, n: int) -> bytes: ... + + class AsyncReader(typing.Protocol): + """An asynchronous reader.""" + + def readline(self) -> Awaitable[bytes]: ... + + def readexactly(self, n: int) -> Awaitable[bytes]: ... + + +class StdinAsyncReader: + """Read from stdin asynchronously.""" + + def __init__(self, stdin: BinaryIO, executor: ThreadPoolExecutor | None = None): + self.stdin = stdin + self._loop: asyncio.AbstractEventLoop | None = None + self.executor = executor + + @property + def loop(self): + if self._loop is None: + self._loop = asyncio.get_running_loop() + + return self._loop + + def readline(self) -> Awaitable[bytes]: + return self.loop.run_in_executor(self.executor, self.stdin.readline) + + def readexactly(self, n: int) -> Awaitable[bytes]: + return self.loop.run_in_executor(self.executor, self.stdin.read, n) + + +async def run_async( + stop_event: threading.Event, + reader: AsyncReader, + protocol: JsonRPCProtocol, + logger: logging.Logger | None = None, + error_handler: Callable[[Exception, type[JsonRpcException]], Any] | None = None, +): + """Run a main message processing loop, asynchronously + + Parameters + ---------- + stop_event + A ``threading.Event`` used to break the main loop + + reader + The reader to read messages from + + protocol + The protocol instance that should handle the messages + + logger + The logger instance to use + """ + + CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$") + content_length = 0 + logger = logger or logging.getLogger(__name__) + + while not stop_event.is_set(): + # Read a header line + header = await reader.readline() + if not header: + break + + # Extract content length if possible + if not content_length: + match = CONTENT_LENGTH_PATTERN.fullmatch(header) + if match: + content_length = int(match.group(1)) + logger.debug("Content length: %s", content_length) + + # Check if all headers have been read (as indicated by an empty line \r\n) + if content_length and not header.strip(): + # Read body + body = await reader.readexactly(content_length) + if not body: + break + + try: + message = json.loads(body, object_hook=protocol.structure_message) + protocol.handle_message(message) + except Exception as exc: + logger.exception("Unable to handle message") + if error_handler: + error_handler(exc, JsonRpcException) + finally: + # Reset + content_length = 0 + + +def run( + stop_event: threading.Event, + reader: Reader, + protocol: JsonRPCProtocol, + logger: logging.Logger | None = None, + error_handler: Callable[[Exception, type[JsonRpcException]], Any] | None = None, +): + """Run a main message processing loop, synchronously + + Parameters + ---------- + stop_event + A ``threading.Event`` used to break the main loop + + reader + The reader to read messages from + + protocol + The protocol instance that should handle the messages + + logger + The logger instance to use + + error_handler + Function to call when an error is encountered. + """ + + CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$") + content_length = 0 + logger = logger or logging.getLogger(__name__) + + while not stop_event.is_set(): + # Read a header line + header = reader.readline() + if not header: + break + + # Extract content length if possible + if not content_length: + match = CONTENT_LENGTH_PATTERN.fullmatch(header) + if match: + content_length = int(match.group(1)) + logger.debug("Content length: %s", content_length) + + # Check if all headers have been read (as indicated by an empty line \r\n) + if content_length and not header.strip(): + # Read body + body = reader.read(content_length) + if not body: + break + + try: + message = json.loads(body, object_hook=protocol.structure_message) + protocol.handle_message(message) + except Exception as exc: + logger.exception("Unable to handle message") + if error_handler: + error_handler(exc, JsonRpcException) + finally: + # Reset + content_length = 0 diff --git a/pygls/server.py b/pygls/server.py index 69ee8816..31318b00 100644 --- a/pygls/server.py +++ b/pygls/server.py @@ -19,73 +19,22 @@ import asyncio import json import logging -import re import sys from concurrent.futures import ThreadPoolExecutor from threading import Event from typing import Any, BinaryIO, Callable, Optional, Type, TypeVar, Union import cattrs -from pygls.exceptions import ( - FeatureNotificationError, - JsonRpcInternalError, - PyglsError, - JsonRpcException, - FeatureRequestError, -) -from pygls.protocol import JsonRPCProtocol +from pygls.exceptions import JsonRpcException, PyglsError +from pygls.io_ import StdinAsyncReader, run_async +from pygls.protocol import JsonRPCProtocol logger = logging.getLogger(__name__) F = TypeVar("F", bound=Callable) -ServerErrors = Union[ - PyglsError, - JsonRpcException, - Type[JsonRpcInternalError], - Type[FeatureNotificationError], - Type[FeatureRequestError], -] - - -async def aio_readline(loop, executor, stop_event, rfile, proxy): - """Reads data from stdin in separate thread (asynchronously).""" - - CONTENT_LENGTH_PATTERN = re.compile(rb"^Content-Length: (\d+)\r\n$") - - # Initialize message buffer - message = [] - content_length = 0 - - while not stop_event.is_set() and not rfile.closed: - # Read a header line - header = await loop.run_in_executor(executor, rfile.readline) - if not header: - break - message.append(header) - - # Extract content length if possible - if not content_length: - match = CONTENT_LENGTH_PATTERN.fullmatch(header) - if match: - content_length = int(match.group(1)) - logger.debug("Content length: %s", content_length) - - # Check if all headers have been read (as indicated by an empty line \r\n) - if content_length and not header.strip(): - # Read body - body = await loop.run_in_executor(executor, rfile.read, content_length) - if not body: - break - message.append(body) - - # Pass message to language server protocol - proxy(b"".join(message)) - - # Reset the buffer - message = [] - content_length = 0 +ServerErrors = Union[type[PyglsError], type[JsonRpcException]] class StdOutTransportAdapter: @@ -228,19 +177,20 @@ def start_io( logger.info("Starting IO server") self._stop_event = Event() + reader = StdinAsyncReader(stdin or sys.stdin.buffer, self.thread_pool) transport = StdOutTransportAdapter( stdin or sys.stdin.buffer, stdout or sys.stdout.buffer ) self.protocol.connection_made(transport) # type: ignore[arg-type] try: - self.loop.run_until_complete( - aio_readline( - self.loop, - self.thread_pool, - self._stop_event, - stdin or sys.stdin.buffer, - self.protocol.data_received, + asyncio.run( + run_async( + stop_event=self._stop_event, + reader=reader, + protocol=self.protocol, + logger=logger, + error_handler=self.report_server_error, ) ) except BrokenPipeError: diff --git a/tests/servers/invalid_json.py b/tests/servers/invalid_json.py index f2c30da3..dccd2fff 100644 --- a/tests/servers/invalid_json.py +++ b/tests/servers/invalid_json.py @@ -1,28 +1,28 @@ """This server does nothing but print invalid JSON.""" -import asyncio -import threading import sys -from concurrent.futures import ThreadPoolExecutor +import threading + +from pygls.io_ import run +from pygls.protocol import JsonRPCProtocol, default_converter -from pygls.server import aio_readline +class InvalidJsonProtocol(JsonRPCProtocol): + """A protocol that only sends messages containing invalid JSON.""" -def handler(data): - content = 'Content-Length: 5\r\n\r\n{"ll}'.encode("utf8") - sys.stdout.buffer.write(content) - sys.stdout.flush() + def handle_message(self, message): + content = 'Content-Length: 5\r\n\r\n{"ll}'.encode("utf8") + sys.stdout.buffer.write(content) + sys.stdout.flush() -async def main(): - await aio_readline( - asyncio.get_running_loop(), - ThreadPoolExecutor(), +def main(): + run( threading.Event(), sys.stdin.buffer, - handler, + InvalidJsonProtocol(None, default_converter()), ) if __name__ == "__main__": - asyncio.run(main()) + main()