diff --git a/poetry.lock b/poetry.lock index 60456a0e..e939750b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "alabaster" @@ -436,13 +436,13 @@ i18n = ["Babel (>=2.7)"] [[package]] name = "lsprotocol" -version = "2023.0.0a3" +version = "2023.0.0b1" description = "Python implementation of the Language Server Protocol." optional = false python-versions = ">=3.7" files = [ - {file = "lsprotocol-2023.0.0a3-py3-none-any.whl", hash = "sha256:2896c5a30c34846e3d5687e35715961f49bf7b92a36e4fb2b707ff65f19087f7"}, - {file = "lsprotocol-2023.0.0a3.tar.gz", hash = "sha256:d704e4e00419f74bece9795de4b34d02aa555fc0131fec49f59ac9eb46816e51"}, + {file = "lsprotocol-2023.0.0b1-py3-none-any.whl", hash = "sha256:ade2cd0fa0ede7965698cb59cd05d3adbd19178fd73e83f72ef57a032fbb9d62"}, + {file = "lsprotocol-2023.0.0b1.tar.gz", hash = "sha256:f7a2d4655cbd5639f373ddd1789807450c543341fa0a32b064ad30dbb9f510d4"}, ] [package.dependencies] @@ -761,13 +761,13 @@ testing = ["coverage (>=6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy [[package]] name = "pytz" -version = "2023.3" +version = "2023.3.post1" description = "World timezone definitions, modern and historical" optional = false python-versions = "*" files = [ - {file = "pytz-2023.3-py2.py3-none-any.whl", hash = "sha256:a151b3abb88eda1d4e34a9814df37de2a80e301e68ba0fd856fb9b46bfbbbffb"}, - {file = "pytz-2023.3.tar.gz", hash = "sha256:1d8ce29db189191fb55338ee6d0387d82ab59f3d00eac103412d64e0ebd0c588"}, + {file = "pytz-2023.3.post1-py2.py3-none-any.whl", hash = "sha256:ce42d816b81b68506614c11e8937d3aa9e41007ceb50bfdcb0749b921bf646c7"}, + {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"}, ] [[package]] @@ -1057,13 +1057,13 @@ sortedcontainers = "*" [[package]] name = "trio-websocket" -version = "0.10.3" +version = "0.10.4" description = "WebSocket library for Trio" optional = false python-versions = ">=3.7" files = [ - {file = "trio-websocket-0.10.3.tar.gz", hash = "sha256:1a748604ad906a7dcab9a43c6eb5681e37de4793ba0847ef0bc9486933ed027b"}, - {file = "trio_websocket-0.10.3-py3-none-any.whl", hash = "sha256:a9937d48e8132ebf833019efde2a52ca82d223a30a7ea3e8d60a7d28f75a4e3a"}, + {file = "trio-websocket-0.10.4.tar.gz", hash = "sha256:e66b3db3e2453017431dfbd352081006654e1241c2a6800dc2f43d7df54d55c5"}, + {file = "trio_websocket-0.10.4-py3-none-any.whl", hash = "sha256:c7a620c4013c34b7e4477d89fe76695da1e455e4510a8d7ae13f81c632bdce1d"}, ] [package.dependencies] @@ -1153,13 +1153,13 @@ files = [ [[package]] name = "urllib3" -version = "2.0.4" +version = "2.0.5" description = "HTTP library with thread-safe connection pooling, file post, and more." 
optional = false python-versions = ">=3.7" files = [ - {file = "urllib3-2.0.4-py3-none-any.whl", hash = "sha256:de7df1803967d2c2a98e4b11bb7d6bd9210474c46e8a0401514e3a42a75ebde4"}, - {file = "urllib3-2.0.4.tar.gz", hash = "sha256:8d22f86aae8ef5e410d4f539fde9ce6b2113a001bb4d189e0aed70642d602b11"}, + {file = "urllib3-2.0.5-py3-none-any.whl", hash = "sha256:ef16afa8ba34a1f989db38e1dbbe0c302e4289a47856990d0682e374563ce35e"}, + {file = "urllib3-2.0.5.tar.gz", hash = "sha256:13abf37382ea2ce6fb744d4dad67838eec857c9f4f57009891805e0b5e123594"}, ] [package.dependencies] @@ -1285,4 +1285,4 @@ ws = ["websockets"] [metadata] lock-version = "2.0" python-versions = ">=3.7.9,<4" -content-hash = "27ee5cd8b82f9c490eed22daa698893ebaed1bd98c080ad314545e168298b6e9" +content-hash = "36d67b26c8878f526a5d76cf5f7eb8173f55b35bec57b164972a8b8ea1d98b8c" diff --git a/pygls/capabilities.py b/pygls/capabilities.py index c36832b0..852b5cf2 100644 --- a/pygls/capabilities.py +++ b/pygls/capabilities.py @@ -15,7 +15,8 @@ # limitations under the License. # ############################################################################ from functools import reduce -from typing import Any, Dict, List, Set, Union +from typing import Any, Dict, List, Optional, Set, Union +import logging from lsprotocol.types import ( INLAY_HINT_RESOLVE, @@ -64,6 +65,7 @@ WORKSPACE_WILL_DELETE_FILES, WORKSPACE_WILL_RENAME_FILES, InlayHintOptions, + PositionEncodingKind, ) from lsprotocol.types import ( ClientCapabilities, @@ -86,6 +88,8 @@ WorkspaceFoldersServerCapabilities, ) +logger = logging.getLogger(__name__) + def get_capability( client_capabilities: ClientCapabilities, field: str, default: Any = None @@ -115,7 +119,7 @@ def __init__( feature_options: Dict[str, Any], commands: List[str], text_document_sync_kind: TextDocumentSyncKind, - notebook_document_sync: NotebookDocumentSyncOptions, + notebook_document_sync: Optional[NotebookDocumentSyncOptions] = None, ): self.client_capabilities = client_capabilities self.features = features @@ -429,6 +433,32 @@ def _with_inline_value_provider(self): self.server_cap.inline_value_provider = value return self + def _with_position_encodings(self): + self.server_cap.position_encoding = PositionEncodingKind.Utf16 + + general = self.client_capabilities.general + if general is None: + return self + + encodings = general.position_encodings + if encodings is None: + return self + + if PositionEncodingKind.Utf16 in encodings: + return self + + if PositionEncodingKind.Utf32 in encodings: + self.server_cap.position_encoding = PositionEncodingKind.Utf32 + return self + + if PositionEncodingKind.Utf8 in encodings: + self.server_cap.position_encoding = PositionEncodingKind.Utf8 + return self + + logger.warning(f"Unknown `PositionEncoding`s: {encodings}") + + return self + def _build(self): return self.server_cap @@ -467,5 +497,6 @@ def build(self): ._with_workspace_capabilities() ._with_diagnostic_provider() ._with_inline_value_provider() + ._with_position_encodings() ._build() ) diff --git a/pygls/client.py b/pygls/client.py index b2ea96bd..577f05e0 100644 --- a/pygls/client.py +++ b/pygls/client.py @@ -79,7 +79,10 @@ def __init__( self, protocol_cls: Type[JsonRPCProtocol] = JsonRPCProtocol, converter_factory: Callable[[], Converter] = default_converter, ): - self.protocol = protocol_cls(self, converter_factory()) + # Strictly speaking `JsonRPCProtocol` wants a `LanguageServer`, not a + # `JsonRPCClient`. However, they're similar enough for our purposes, given + # that this client will mostly be used in testing contexts.
+ self.protocol = protocol_cls(self, converter_factory()) # type: ignore self._server: Optional[asyncio.subprocess.Process] = None self._stop_event = Event() diff --git a/pygls/protocol.py b/pygls/protocol.py index b15757b0..3d8c6e03 100644 --- a/pygls/protocol.py +++ b/pygls/protocol.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # # limitations under the License. # ############################################################################ +from __future__ import annotations import asyncio import enum import functools @@ -27,7 +28,21 @@ from concurrent.futures import Future from functools import lru_cache, partial from itertools import zip_longest -from typing import Any, Callable, List, Optional, Type, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Type, + TypeVar, + Union, + TYPE_CHECKING, +) + +if TYPE_CHECKING: + from pygls.server import LanguageServer, WebSocketTransportAdapter + import attrs from cattrs.errors import ClassValidationError @@ -239,19 +254,21 @@ class JsonRPCProtocol(asyncio.Protocol): VERSION = "2.0" - def __init__(self, server, converter): + def __init__(self, server: LanguageServer, converter): self._server = server self._converter = converter self._shutdown = False # Book keeping for in-flight requests - self._request_futures = {} - self._result_types = {} + self._request_futures: Dict[str, Future[Any]] = {} + self._result_types: Dict[str, Any] = {} self.fm = FeatureManager(server) - self.transport = None - self._message_buf = [] + self.transport: Optional[ + Union[asyncio.WriteTransport, WebSocketTransportAdapter] + ] = None + self._message_buf: List[bytes] = [] self._send_only_body = False @@ -504,7 +521,9 @@ def _send_data(self, data): logger.info("Sending data: %s", body) if self._send_only_body: - self.transport.write(body) + # Mypy/Pyright seem to think `write()` wants `"bytes | bytearray | memoryview"` + # But runtime errors with anything but `str`. + self.transport.write(body) # type: ignore return header = ( @@ -544,7 +563,10 @@ def connection_lost(self, exc): logger.error("Connection to the client is lost! 
Shutting down the server.") sys.exit(1) - def connection_made(self, transport: asyncio.BaseTransport): + def connection_made( # type: ignore # see: https://github.com/python/typeshed/issues/3021 + self, + transport: asyncio.Transport, + ): """Method from base class, called when connection is established""" self.transport = transport @@ -805,12 +827,17 @@ def lsp_initialize(self, params: InitializeParams) -> InitializeResult: ) root_path = params.root_path - root_uri = params.root_uri or from_fs_path(root_path) + root_uri = params.root_uri + if root_path is not None and root_uri is None: + root_uri = from_fs_path(root_path) # Initialize the workspace workspace_folders = params.workspace_folders or [] self._workspace = Workspace( - root_uri, text_document_sync_kind, workspace_folders + root_uri, + text_document_sync_kind, + workspace_folders, + self.server_capabilities.position_encoding, ) self.trace = TraceValues.Off diff --git a/pygls/server.py b/pygls/server.py index b26c9a4b..7717b84e 100644 --- a/pygls/server.py +++ b/pygls/server.py @@ -35,7 +35,13 @@ import cattrs from pygls import IS_PYODIDE from pygls.lsp import ConfigCallbackType, ShowDocumentCallbackType -from pygls.exceptions import PyglsError, JsonRpcException, FeatureRequestError +from pygls.exceptions import ( + FeatureNotificationError, + JsonRpcInternalError, + PyglsError, + JsonRpcException, + FeatureRequestError, +) from lsprotocol.types import ( ClientCapabilities, Diagnostic, @@ -62,6 +68,14 @@ F = TypeVar("F", bound=Callable) +ServerErrors = Union[ + PyglsError, + JsonRpcException, + Type[JsonRpcInternalError], + Type[FeatureNotificationError], + Type[FeatureRequestError], +] + async def aio_readline(loop, executor, stop_event, rfile, proxy): """Reads data from stdin in separate thread (asynchronously).""" @@ -204,7 +218,9 @@ def __init__( self._owns_loop = False self.loop = loop - self.lsp = protocol_cls(self, converter_factory()) + + # TODO: Will move this to `LanguageServer` soon + self.lsp = protocol_cls(self, converter_factory()) # type: ignore def shutdown(self): """Shutdown server.""" @@ -404,6 +420,7 @@ def __init__( self.version = version self._text_document_sync_kind = text_document_sync_kind self._notebook_document_sync = notebook_document_sync + self.process_id: Optional[Union[int, None]] = None super().__init__(protocol_cls, converter_factory, loop, max_workers) def apply_edit( @@ -541,7 +558,9 @@ def show_message_log(self, message, msg_type=MessageType.Log) -> None: self.lsp.show_message_log(message, msg_type) def _report_server_error( - self, error: Exception, source: Union[PyglsError, JsonRpcException] + self, + error: Exception, + source: ServerErrors, ): # Prevent recursive error reporting try: @@ -549,9 +568,7 @@ def _report_server_error( except Exception: logger.warning("Failed to report error to client") - def report_server_error( - self, error: Exception, source: Union[PyglsError, JsonRpcException] - ): + def report_server_error(self, error: Exception, source: ServerErrors): """ Sends error to the client for displaying. diff --git a/pygls/uris.py b/pygls/uris.py index 2b9997db..8c40f70b 100644 --- a/pygls/uris.py +++ b/pygls/uris.py @@ -16,10 +16,13 @@ # See the License for the specific language governing permissions and # # limitations under the License. # ############################################################################ -"""A collection of URI utilities with logic built on the VSCode URI library. +""" +A collection of URI utilities with logic built on the VSCode URI library. 
https://github.com/Microsoft/vscode-uri/blob/e59cab84f5df6265aed18ae5f43552d3eef13bb9/lib/index.ts """ +from typing import Optional, Tuple + import re from urllib import parse @@ -27,8 +30,10 @@ RE_DRIVE_LETTER_PATH = re.compile(r"^\/[a-zA-Z]:") +URLParts = Tuple[str, str, str, str, str, str] -def _normalize_win_path(path): + +def _normalize_win_path(path: str): netloc = "" # normalize to fwd-slashes on windows, @@ -59,7 +64,7 @@ def _normalize_win_path(path): return path, netloc -def from_fs_path(path): +def from_fs_path(path: str): """Returns a URI for the given filesystem path.""" try: scheme = "file" @@ -70,8 +75,9 @@ def from_fs_path(path): return None -def to_fs_path(uri): - """Returns the filesystem path of the given URI. +def to_fs_path(uri: str): + """ + Returns the filesystem path of the given URI. Will handle UNC paths and normalize windows drive letters to lower-case. Also uses the platform specific path separator. Will *not* validate the @@ -80,7 +86,7 @@ def to_fs_path(uri): """ try: # scheme://netloc/path;parameters?query#fragment - scheme, netloc, path, _params, _query, _fragment = urlparse(uri) + scheme, netloc, path, _, _, _ = urlparse(uri) if netloc and path and scheme == "file": # unc path: file://shares/c$/far/boo @@ -102,25 +108,35 @@ def to_fs_path(uri): return None -def uri_scheme(uri): +def uri_scheme(uri: str): try: return urlparse(uri)[0] except (TypeError, IndexError): return None +# TODO: Use `URLParts` type def uri_with( - uri, scheme=None, netloc=None, path=None, params=None, query=None, fragment=None + uri: str, + scheme: Optional[str] = None, + netloc: Optional[str] = None, + path: Optional[str] = None, + params: Optional[str] = None, + query: Optional[str] = None, + fragment: Optional[str] = None, ): - """Return a URI with the given part(s) replaced. - + """ + Return a URI with the given part(s) replaced. Parts are decoded / encoded. """ old_scheme, old_netloc, old_path, old_params, old_query, old_fragment = urlparse( uri ) - path, _netloc = _normalize_win_path(path) + if path is None: + raise Exception("`path` must not be None") + + path, _ = _normalize_win_path(path) return urlunparse( ( scheme or old_scheme, @@ -133,7 +149,7 @@ def uri_with( ) -def urlparse(uri): +def urlparse(uri: str): """Parse and decode the parts of a URI.""" scheme, netloc, path, params, query, fragment = parse.urlparse(uri) return ( @@ -146,7 +162,7 @@ def urlparse(uri): ) -def urlunparse(parts): +def urlunparse(parts: URLParts) -> str: """Unparse and encode parts of a URI.""" scheme, netloc, path, params, query, fragment = parts diff --git a/pygls/workspace.py b/pygls/workspace.py deleted file mode 100644 index fad94ef6..00000000 --- a/pygls/workspace.py +++ /dev/null @@ -1,648 +0,0 @@ -############################################################################ -# Original work Copyright 2017 Palantir Technologies, Inc. # -# Original work licensed under the MIT License. # -# See ThirdPartyNotices.txt in the project root for license information. # -# All modifications Copyright (c) Open Law Library. All rights reserved. # -# # -# Licensed under the Apache License, Version 2.0 (the "License") # -# you may not use this file except in compliance with the License. 
# -# You may obtain a copy of the License at # -# # -# http: // www.apache.org/licenses/LICENSE-2.0 # -# # -# Unless required by applicable law or agreed to in writing, software # -# distributed under the License is distributed on an "AS IS" BASIS, # -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # -# See the License for the specific language governing permissions and # -# limitations under the License. # -############################################################################ -import copy -import io -import logging -import os -import re -import warnings -from typing import Dict, List, Optional, Pattern - -from lsprotocol import types - -from pygls.uris import to_fs_path, uri_scheme - -# TODO: this is not the best e.g. we capture numbers -RE_END_WORD = re.compile("^[A-Za-z_0-9]*") -RE_START_WORD = re.compile("[A-Za-z_0-9]*$") - -logger = logging.getLogger(__name__) - - -def is_char_beyond_multilingual_plane(char: str) -> bool: - return ord(char) > 0xFFFF - - -def utf16_unit_offset(chars: str): - """Calculate the number of characters which need two utf-16 code units. - - Arguments: - chars (str): The string to count occurrences of utf-16 code units for. - """ - return sum(is_char_beyond_multilingual_plane(ch) for ch in chars) - - -def utf16_num_units(chars: str): - """Calculate the length of `str` in utf-16 code units. - - Arguments: - chars (str): The string to return the length in utf-16 code units for. - """ - return len(chars) + utf16_unit_offset(chars) - - -def position_from_utf16(lines: List[str], position: types.Position) -> types.Position: - """Convert the position.character from utf-16 code units to utf-32. - - A python application can't use the character member of `Position` - directly. As per specification it is represented as a zero-based line and - character offset based on a UTF-16 string representation. - - All characters whose code point exceeds the Basic Multilingual Plane are - represented by 2 UTF-16 code units. - - The offset of the closing quotation mark in x="😋" is - - 5 in UTF-16 representation - - 4 in UTF-32 representation - - see: https://github.com/microsoft/language-server-protocol/issues/376 - - Arguments: - lines (list): - The content of the document which the position refers to. - position (Position): - The line and character offset in utf-16 code units. - - Returns: - The position with `character` being converted to utf-32 code units. 
- """ - if len(lines) == 0: - return types.Position(0, 0) - if position.line >= len(lines): - return types.Position(len(lines) - 1, utf16_num_units(lines[-1])) - - _line = lines[position.line] - _line = _line.replace("\r\n", "\n") # TODO: it's a bit of a hack - _utf16_len = utf16_num_units(_line) - _utf32_len = len(_line) - - if _utf16_len == 0: - return types.Position(position.line, 0) - - _utf16_end_of_line = utf16_num_units(_line) - if position.character > _utf16_end_of_line: - position.character = _utf16_end_of_line - 1 - - _utf16_index = 0 - utf32_index = 0 - while True: - _is_searching_queried_position = _utf16_index < position.character - _is_before_end_of_line = utf32_index < _utf32_len - _is_searching_for_position = ( - _is_searching_queried_position and _is_before_end_of_line - ) - if not _is_searching_for_position: - break - - _current_char = _line[utf32_index] - _is_double_width = is_char_beyond_multilingual_plane(_current_char) - if _is_double_width: - _utf16_index += 2 - else: - _utf16_index += 1 - utf32_index += 1 - - position = types.Position(line=position.line, character=utf32_index) - return position - - -def position_to_utf16(lines: List[str], position: types.Position) -> types.Position: - """Convert the position.character from utf-32 to utf-16 code units. - - A python application can't use the character member of `Position` - directly as per specification it is represented as a zero-based line and - character offset based on a UTF-16 string representation. - - All characters whose code point exceeds the Basic Multilingual Plane are - represented by 2 UTF-16 code units. - - The offset of the closing quotation mark in x="😋" is - - 5 in UTF-16 representation - - 4 in UTF-32 representation - - see: https://github.com/microsoft/language-server-protocol/issues/376 - - Arguments: - lines (list): - The content of the document which the position refers to. - position (Position): - The line and character offset in utf-32 code units. - - Returns: - The position with `character` being converted to utf-16 code units. - """ - try: - return types.Position( - line=position.line, - character=position.character - + utf16_unit_offset(lines[position.line][: position.character]), - ) - except IndexError: - return types.Position(line=len(lines), character=0) - - -def range_from_utf16(lines: List[str], range: types.Range) -> types.Range: - """Convert range.[start|end].character from utf-16 code units to utf-32. - - Arguments: - lines (list): - The content of the document which the range refers to. - range (Range): - The line and character offset in utf-32 code units. - - Returns: - The range with `character` offsets being converted to utf-16 code units. - """ - range_new = types.Range( - start=position_from_utf16(lines, range.start), - end=position_from_utf16(lines, range.end), - ) - return range_new - - -def range_to_utf16(lines: List[str], range: types.Range) -> types.Range: - """Convert range.[start|end].character from utf-32 to utf-16 code units. - - Arguments: - lines (list): - The content of the document which the range refers to. - range (Range): - The line and character offset in utf-16 code units. - - Returns: - The range with `character` offsets being converted to utf-32 code units. 
- """ - return types.Range( - start=position_to_utf16(lines, range.start), - end=position_to_utf16(lines, range.end), - ) - - -class TextDocument(object): - def __init__( - self, - uri: str, - source: Optional[str] = None, - version: Optional[int] = None, - language_id: Optional[str] = None, - local: bool = True, - sync_kind: types.TextDocumentSyncKind = types.TextDocumentSyncKind.Incremental, - ): - self.uri = uri - self.version = version - self.path = to_fs_path(uri) - self.language_id = language_id - self.filename = os.path.basename(self.path) - - self._local = local - self._source = source - - self._is_sync_kind_full = sync_kind == types.TextDocumentSyncKind.Full - self._is_sync_kind_incremental = ( - sync_kind == types.TextDocumentSyncKind.Incremental - ) - self._is_sync_kind_none = sync_kind == types.TextDocumentSyncKind.None_ - - def __str__(self): - return str(self.uri) - - def _apply_incremental_change( - self, change: types.TextDocumentContentChangeEvent_Type1 - ) -> None: - """Apply an ``Incremental`` text change to the document""" - lines = self.lines - text = change.text - change_range = change.range - - range = range_from_utf16(lines, change_range) # type: ignore - start_line = range.start.line - start_col = range.start.character - end_line = range.end.line - end_col = range.end.character - - # Check for an edit occurring at the very end of the file - if start_line == len(lines): - self._source = self.source + text - return - - new = io.StringIO() - - # Iterate over the existing document until we hit the edit range, - # at which point we write the new text, then loop until we hit - # the end of the range and continue writing. - for i, line in enumerate(lines): - if i < start_line: - new.write(line) - continue - - if i > end_line: - new.write(line) - continue - - if i == start_line: - new.write(line[:start_col]) - new.write(text) - - if i == end_line: - new.write(line[end_col:]) - - self._source = new.getvalue() - - def _apply_full_change(self, change: types.TextDocumentContentChangeEvent) -> None: - """Apply a ``Full`` text change to the document.""" - self._source = change.text - - def _apply_none_change(self, change: types.TextDocumentContentChangeEvent) -> None: - """Apply a ``None`` text change to the document - - Currently does nothing, provided for consistency. - """ - pass - - def apply_change(self, change: types.TextDocumentContentChangeEvent) -> None: - """Apply a text change to a document, considering TextDocumentSyncKind - - Performs either - :attr:`~lsprotocol.types.TextDocumentSyncKind.Incremental`, - :attr:`~lsprotocol.types.TextDocumentSyncKind.Full`, or no synchronization - based on both the client request and server capabilities. - - .. admonition:: ``Incremental`` versus ``Full`` synchronization - - Even if a server accepts ``Incremantal`` SyncKinds, clients may request - a ``Full`` SyncKind. In LSP 3.x, clients make this request by omitting - both Range and RangeLength from their request. Consequently, the - attributes "range" and "rangeLength" will be missing from ``Full`` - content update client requests in the pygls Python library. - - """ - if isinstance(change, types.TextDocumentContentChangeEvent_Type1): - if self._is_sync_kind_incremental: - self._apply_incremental_change(change) - return - # Log an error, but still perform full update to preserve existing - # assumptions in test_document/test_document_full_edit. Test breaks - # otherwise, and fixing the tests would require a broader fix to - # protocol.py. 
- logger.error( - "Unsupported client-provided TextDocumentContentChangeEvent. " - "Please update / submit a Pull Request to your LSP client." - ) - - if self._is_sync_kind_none: - self._apply_none_change(change) - else: - self._apply_full_change(change) - - @property - def lines(self) -> List[str]: - return self.source.splitlines(True) - - def offset_at_position(self, position: types.Position) -> int: - """Return the character offset pointed at by the given position.""" - lines = self.lines - pos = position_from_utf16(lines, position) - row, col = pos.line, pos.character - return col + sum(utf16_num_units(line) for line in lines[:row]) - - @property - def source(self) -> str: - if self._source is None: - with io.open(self.path, "r", encoding="utf-8") as f: - return f.read() - return self._source - - def word_at_position( - self, - position: types.Position, - re_start_word: Pattern = RE_START_WORD, - re_end_word: Pattern = RE_END_WORD, - ) -> str: - """Return the word at position. - - The word is constructed in two halves, the first half is found by taking - the first match of ``re_start_word`` on the line up until - ``position.character``. - - The second half is found by taking ``position.character`` up until the - last match of ``re_end_word`` on the line. - - :func:`python:re.findall` is used to find the matches. - - Parameters - ---------- - position - The line and character offset. - - re_start_word - The regular expression for extracting the word backward from - position. The default pattern is ``[A-Za-z_0-9]*$``. - - re_end_word - The regular expression for extracting the word forward from - position. The default pattern is ``^[A-Za-z_0-9]*``. - - Returns - ------- - str - The word (obtained by concatenating the two matches) at position. - """ - lines = self.lines - if position.line >= len(lines): - return "" - - pos = position_from_utf16(lines, position) - row, col = pos.line, pos.character - line = lines[row] - # Split word in two - start = line[:col] - end = line[col:] - - # Take end of start and start of end to find word - # These are guaranteed to match, even if they match the empty string - m_start = re_start_word.findall(start) - m_end = re_end_word.findall(end) - - return m_start[0] + m_end[-1] - - -# For backwards compatibility -Document = TextDocument - - -class Workspace(object): - def __init__(self, root_uri, sync_kind=None, workspace_folders=None): - self._root_uri = root_uri - self._root_uri_scheme = uri_scheme(self._root_uri) - self._root_path = to_fs_path(self._root_uri) - self._sync_kind = sync_kind - self._folders = {} - self._text_documents: Dict[str, TextDocument] = {} - self._notebook_documents: Dict[str, types.NotebookDocument] = {} - - # Used to lookup notebooks which contain a given cell. 
- self._cell_in_notebook: Dict[str, str] = {} - - if workspace_folders is not None: - for folder in workspace_folders: - self.add_folder(folder) - - def _create_text_document( - self, - doc_uri: str, - source: Optional[str] = None, - version: Optional[int] = None, - language_id: Optional[str] = None, - ) -> TextDocument: - return TextDocument( - doc_uri, - source=source, - version=version, - language_id=language_id, - sync_kind=self._sync_kind, - ) - - def add_folder(self, folder: types.WorkspaceFolder): - self._folders[folder.uri] = folder - - @property - def documents(self): - warnings.warn( - "'workspace.documents' has been deprecated, use " - "'workspace.text_documents' instead", - DeprecationWarning, - stacklevel=2, - ) - return self.text_documents - - @property - def notebook_documents(self): - return self._notebook_documents - - @property - def text_documents(self): - return self._text_documents - - @property - def folders(self): - return self._folders - - def get_notebook_document( - self, *, notebook_uri: Optional[str] = None, cell_uri: Optional[str] = None - ) -> Optional[types.NotebookDocument]: - """Return the notebook corresponding with the given uri. - - If both ``notebook_uri`` and ``cell_uri`` are given, ``notebook_uri`` takes - precedence. - - Parameters - ---------- - notebook_uri - If given, return the notebook document with the given uri. - - cell_uri - If given, return the notebook document which contains a cell with the - given uri - - Returns - ------- - Optional[NotebookDocument] - The requested notebook document if found, ``None`` otherwise. - """ - if notebook_uri is not None: - return self._notebook_documents.get(notebook_uri) - - if cell_uri is not None: - notebook_uri = self._cell_in_notebook.get(cell_uri) - if notebook_uri is None: - return None - - return self._notebook_documents.get(notebook_uri) - - return None - - def get_text_document(self, doc_uri: str) -> TextDocument: - """ - Return a managed document if-present, - else create one pointing at disk. - - See https://github.com/Microsoft/language-server-protocol/issues/177 - """ - return self._text_documents.get(doc_uri) or self._create_text_document(doc_uri) - - def is_local(self): - return ( - self._root_uri_scheme == "" or self._root_uri_scheme == "file" - ) and os.path.exists(self._root_path) - - def put_notebook_document(self, params: types.DidOpenNotebookDocumentParams): - notebook = params.notebook_document - - # Create a fresh instance to ensure our copy cannot be accidentally modified. - self._notebook_documents[notebook.uri] = copy.deepcopy(notebook) - - for cell_document in params.cell_text_documents: - self.put_text_document(cell_document, notebook_uri=notebook.uri) - - def put_text_document( - self, - text_document: types.TextDocumentItem, - notebook_uri: Optional[str] = None, - ): - """Add a text document to the workspace. 
- - Parameters - ---------- - text_document - The text document to add - - notebook_uri - If set, indicates that this text document represents a cell in a notebook - document - """ - doc_uri = text_document.uri - - self._text_documents[doc_uri] = self._create_text_document( - doc_uri, - source=text_document.text, - version=text_document.version, - language_id=text_document.language_id, - ) - - if notebook_uri: - self._cell_in_notebook[doc_uri] = notebook_uri - - def remove_notebook_document(self, params: types.DidCloseNotebookDocumentParams): - notebook_uri = params.notebook_document.uri - self._notebook_documents.pop(notebook_uri, None) - - for cell_document in params.cell_text_documents: - self.remove_text_document(cell_document.uri) - - def remove_text_document(self, doc_uri: str): - self._text_documents.pop(doc_uri, None) - self._cell_in_notebook.pop(doc_uri, None) - - def remove_folder(self, folder_uri: str): - self._folders.pop(folder_uri, None) - try: - del self._folders[folder_uri] - except KeyError: - pass - - @property - def root_path(self): - return self._root_path - - @property - def root_uri(self): - return self._root_uri - - def update_notebook_document(self, params: types.DidChangeNotebookDocumentParams): - uri = params.notebook_document.uri - notebook = self._notebook_documents[uri] - notebook.version = params.notebook_document.version - - if params.change.metadata: - notebook.metadata = params.change.metadata - - cell_changes = params.change.cells - if cell_changes is None: - return - - # Process changes to any cell metadata. - nb_cells = {cell.document: cell for cell in notebook.cells} - for new_data in cell_changes.data or []: - nb_cell = nb_cells.get(new_data.document) - if nb_cell is None: - logger.warning( - "Ignoring metadata for '%s': not in notebook.", new_data.document - ) - continue - - nb_cell.kind = new_data.kind - nb_cell.metadata = new_data.metadata - nb_cell.execution_summary = new_data.execution_summary - - # Process changes to the notebook's structure - structure = cell_changes.structure - if structure: - cells = notebook.cells - new_cells = structure.array.cells or [] - - # Re-order the cells - before = cells[: structure.array.start] - after = cells[(structure.array.start + structure.array.delete_count) :] - notebook.cells = [*before, *new_cells, *after] - - for new_cell in structure.did_open or []: - self.put_text_document(new_cell, notebook_uri=uri) - - for removed_cell in structure.did_close or []: - self.remove_text_document(removed_cell.uri) - - # Process changes to the text content of existing cells. 
- for text in cell_changes.text_content or []: - for change in text.changes: - self.update_text_document(text.document, change) - - def update_text_document( - self, - text_doc: types.VersionedTextDocumentIdentifier, - change: types.TextDocumentContentChangeEvent, - ): - doc_uri = text_doc.uri - self._text_documents[doc_uri].apply_change(change) - self._text_documents[doc_uri].version = text_doc.version - - def get_document(self, *args, **kwargs): - warnings.warn( - "'workspace.get_document' has been deprecated, use " - "'workspace.get_text_document' instead", - DeprecationWarning, - stacklevel=2, - ) - return self.get_text_document(*args, **kwargs) - - def remove_document(self, *args, **kwargs): - warnings.warn( - "'workspace.remove_document' has been deprecated, use " - "'workspace.remove_text_document' instead", - DeprecationWarning, - stacklevel=2, - ) - return self.remove_text_document(*args, **kwargs) - - def put_document(self, *args, **kwargs): - warnings.warn( - "'workspace.put_document' has been deprecated, use " - "'workspace.put_text_document' instead", - DeprecationWarning, - stacklevel=2, - ) - return self.put_text_document(*args, **kwargs) - - def update_document(self, *args, **kwargs): - warnings.warn( - "'workspace.update_document' has been deprecated, use " - "'workspace.update_text_document' instead", - DeprecationWarning, - stacklevel=2, - ) - return self.update_text_document(*args, **kwargs) diff --git a/pygls/workspace/__init__.py b/pygls/workspace/__init__.py new file mode 100644 index 00000000..afa25901 --- /dev/null +++ b/pygls/workspace/__init__.py @@ -0,0 +1,81 @@ +from typing import List +import warnings + +from lsprotocol import types + +from .workspace import Workspace +from .text_document import TextDocument +from .position import Position + +Workspace = Workspace +TextDocument = TextDocument +Position = Position + +# For backwards compatibility +Document = TextDocument + + +def utf16_unit_offset(chars: str): + warnings.warn( + "'utf16_unit_offset' has been deprecated, use " + "'Position.utf16_unit_offset' instead", + DeprecationWarning, + stacklevel=2, + ) + _position = Position() + return _position.utf16_unit_offset(chars) + + +def utf16_num_units(chars: str): + warnings.warn( + "'utf16_num_units' has been deprecated, use " + "'Position.client_num_units' instead", + DeprecationWarning, + stacklevel=2, + ) + _position = Position() + return _position.client_num_units(chars) + + +def position_from_utf16(lines: List[str], position: types.Position): + warnings.warn( + "'position_from_utf16' has been deprecated, use " + "'Position.position_from_client_units' instead", + DeprecationWarning, + stacklevel=2, + ) + _position = Position() + return _position.position_from_client_units(lines, position) + + +def position_to_utf16(lines: List[str], position: types.Position): + warnings.warn( + "'position_to_utf16' has been deprecated, use " + "'Position.position_to_client_units' instead", + DeprecationWarning, + stacklevel=2, + ) + _position = Position() + return _position.position_to_client_units(lines, position) + + +def range_from_utf16(lines: List[str], range: types.Range): + warnings.warn( + "'range_from_utf16' has been deprecated, use " + "'Position.range_from_client_units' instead", + DeprecationWarning, + stacklevel=2, + ) + _position = Position() + return _position.range_from_client_units(lines, range) + + +def range_to_utf16(lines: List[str], range: types.Range): + warnings.warn( + "'range_to_utf16' has been deprecated, use " + "'Position.range_to_client_units' 
instead", + DeprecationWarning, + stacklevel=2, + ) + _position = Position() + return _position.range_to_client_units(lines, range) diff --git a/pygls/workspace/position.py b/pygls/workspace/position.py new file mode 100644 index 00000000..0f4616d5 --- /dev/null +++ b/pygls/workspace/position.py @@ -0,0 +1,204 @@ +############################################################################ +# Original work Copyright 2017 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import logging +from typing import List, Optional, Union + +from lsprotocol import types + + +log = logging.getLogger(__name__) + + +class Position: + def __init__( + self, + encoding: Optional[ + Union[types.PositionEncodingKind, str] + ] = types.PositionEncodingKind.Utf16, + ): + self.encoding = encoding + + @classmethod + def is_char_beyond_multilingual_plane(cls, char: str) -> bool: + return ord(char) > 0xFFFF + + def utf16_unit_offset(self, chars: str): + """ + Calculate the number of characters which need two utf-16 code units. + + Arguments: + chars (str): The string to count occurrences of utf-16 code units for. + """ + return sum(self.is_char_beyond_multilingual_plane(ch) for ch in chars) + + def client_num_units(self, chars: str): + """ + Calculate the length of `str` in utf-16 code units. + + Arguments: + chars (str): The string to return the length in utf-16 code units for. + """ + utf32_units = len(chars) + if self.encoding == types.PositionEncodingKind.Utf32: + return utf32_units + + if self.encoding == types.PositionEncodingKind.Utf8: + return utf32_units + (self.utf16_unit_offset(chars) * 2) + + return utf32_units + self.utf16_unit_offset(chars) + + def position_from_client_units( + self, lines: List[str], position: types.Position + ) -> types.Position: + """ + Convert the position.character from UTF-[32|16|8] code units to UTF-32. + + A python application can't use the character member of `Position` + directly. As per specification it is represented as a zero-based line and + character offset based on posible a UTF-[32|16|8] string representation. + + All characters whose code point exceeds the Basic Multilingual Plane are + represented by 2 UTF-16 or 4 UTF-8 code units. + + The offset of the closing quotation mark in x="😋" is + - 7 in UTF-8 representation + - 5 in UTF-16 representation + - 4 in UTF-32 representation + + see: https://github.com/microsoft/language-server-protocol/issues/376 + + Arguments: + lines (list): + The content of the document which the position refers to. + position (Position): + The line and character offset in UTF-[32|16|8] code units. + + Returns: + The position with `character` being converted to UTF-32 code units. 
+ """ + if len(lines) == 0: + return types.Position(0, 0) + if position.line >= len(lines): + return types.Position(len(lines) - 1, self.client_num_units(lines[-1])) + + _line = lines[position.line] + _line = _line.replace("\r\n", "\n") # TODO: it's a bit of a hack + _client_len = self.client_num_units(_line) + _utf32_len = len(_line) + + if _client_len == 0: + return types.Position(position.line, 0) + + _client_end_of_line = self.client_num_units(_line) + if position.character > _client_end_of_line: + position.character = _client_end_of_line - 1 + + _client_index = 0 + utf32_index = 0 + while True: + _is_searching_queried_position = _client_index < position.character + _is_before_end_of_line = utf32_index < _utf32_len + _is_searching_for_position = ( + _is_searching_queried_position and _is_before_end_of_line + ) + if not _is_searching_for_position: + break + + _current_char = _line[utf32_index] + _is_double_width = Position.is_char_beyond_multilingual_plane(_current_char) + if _is_double_width: + if self.encoding == types.PositionEncodingKind.Utf32: + _client_index += 1 + if self.encoding == types.PositionEncodingKind.Utf8: + _client_index += 4 + _client_index += 2 + else: + _client_index += 1 + utf32_index += 1 + + position = types.Position(line=position.line, character=utf32_index) + return position + + def position_to_client_units( + self, lines: List[str], position: types.Position + ) -> types.Position: + """ + Convert the position.character from its internal UTF-32 representation + to client-supported UTF-[32|16|8] code units. + + Arguments: + lines (list): + The content of the document which the position refers to. + position (Position): + The line and character offset in UTF-32 code units. + + Returns: + The position with `character` being converted to UTF-[32|16|8] code units. + """ + try: + character = self.client_num_units( + lines[position.line][: position.character] + ) + return types.Position( + line=position.line, + character=character, + ) + except IndexError: + return types.Position(line=len(lines), character=0) + + def range_from_client_units( + self, lines: List[str], range: types.Range + ) -> types.Range: + """ + Convert range.[start|end].character from UTF-[32|16|8] code units to UTF-32. + + Arguments: + lines (list): + The content of the document which the range refers to. + range (Range): + The line and character offset in UTF-[32|16|8] code units. + + Returns: + The range with `character` offsets being converted to UTF-32 code units. + """ + range_new = types.Range( + start=self.position_from_client_units(lines, range.start), + end=self.position_from_client_units(lines, range.end), + ) + return range_new + + def range_to_client_units( + self, lines: List[str], range: types.Range + ) -> types.Range: + """ + Convert range.[start|end].character from UTF-32 to UTF-[32|16|8] code units. + + Arguments: + lines (list): + The content of the document which the range refers to. + range (Range): + The line and character offset in code units. + + Returns: + The range with `character` offsets being converted to UTF-[32|16|8] code units. 
+ """ + return types.Range( + start=self.position_to_client_units(lines, range.start), + end=self.position_to_client_units(lines, range.end), + ) diff --git a/pygls/workspace/text_document.py b/pygls/workspace/text_document.py new file mode 100644 index 00000000..27b300ab --- /dev/null +++ b/pygls/workspace/text_document.py @@ -0,0 +1,234 @@ +############################################################################ +# Original work Copyright 2017 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import io +import logging +import os +import re +from typing import List, Optional, Pattern, Union + +from lsprotocol import types + +from pygls.uris import to_fs_path +from .position import Position + +# TODO: this is not the best e.g. we capture numbers +RE_END_WORD = re.compile("^[A-Za-z_0-9]*") +RE_START_WORD = re.compile("[A-Za-z_0-9]*$") + +logger = logging.getLogger(__name__) + + +class TextDocument(object): + def __init__( + self, + uri: str, + source: Optional[str] = None, + version: Optional[int] = None, + language_id: Optional[str] = None, + local: bool = True, + sync_kind: types.TextDocumentSyncKind = types.TextDocumentSyncKind.Incremental, + position_encoding: Optional[ + Union[types.PositionEncodingKind, str] + ] = types.PositionEncodingKind.Utf16, + ): + self.uri = uri + self.version = version + path = to_fs_path(uri) + if path is None: + raise Exception("`path` cannot be None") + self.path = path + self.language_id = language_id + self.filename: Optional[str] = os.path.basename(self.path) + + self._local = local + self._source = source + + self._is_sync_kind_full = sync_kind == types.TextDocumentSyncKind.Full + self._is_sync_kind_incremental = ( + sync_kind == types.TextDocumentSyncKind.Incremental + ) + self._is_sync_kind_none = sync_kind == types.TextDocumentSyncKind.None_ + + self.position = Position(encoding=position_encoding) + + def __str__(self): + return str(self.uri) + + def _apply_incremental_change( + self, change: types.TextDocumentContentChangeEvent_Type1 + ) -> None: + """Apply an ``Incremental`` text change to the document""" + lines = self.lines + text = change.text + change_range = change.range + + range = self.position.range_from_client_units(lines, change_range) + start_line = range.start.line + start_col = range.start.character + end_line = range.end.line + end_col = range.end.character + + # Check for an edit occurring at the very end of the file + if start_line == len(lines): + self._source = self.source + text + return + + new = io.StringIO() + + # Iterate over the existing document until we hit the edit range, + # at which point we write the new text, then loop until we hit + # the end of the range and continue writing. 
+ for i, line in enumerate(lines): + if i < start_line: + new.write(line) + continue + + if i > end_line: + new.write(line) + continue + + if i == start_line: + new.write(line[:start_col]) + new.write(text) + + if i == end_line: + new.write(line[end_col:]) + + self._source = new.getvalue() + + def _apply_full_change(self, change: types.TextDocumentContentChangeEvent) -> None: + """Apply a ``Full`` text change to the document.""" + self._source = change.text + + def _apply_none_change(self, _: types.TextDocumentContentChangeEvent) -> None: + """Apply a ``None`` text change to the document + + Currently does nothing, provided for consistency. + """ + pass + + def apply_change(self, change: types.TextDocumentContentChangeEvent) -> None: + """Apply a text change to a document, considering TextDocumentSyncKind + + Performs either + :attr:`~lsprotocol.types.TextDocumentSyncKind.Incremental`, + :attr:`~lsprotocol.types.TextDocumentSyncKind.Full`, or no synchronization + based on both the client request and server capabilities. + + .. admonition:: ``Incremental`` versus ``Full`` synchronization + + Even if a server accepts ``Incremental`` SyncKinds, clients may request + a ``Full`` SyncKind. In LSP 3.x, clients make this request by omitting + both Range and RangeLength from their request. Consequently, the + attributes "range" and "rangeLength" will be missing from ``Full`` + content update client requests in the pygls Python library. + + """ + if isinstance(change, types.TextDocumentContentChangeEvent_Type1): + if self._is_sync_kind_incremental: + self._apply_incremental_change(change) + return + # Log an error, but still perform full update to preserve existing + # assumptions in test_document/test_document_full_edit. Test breaks + # otherwise, and fixing the tests would require a broader fix to + # protocol.py. + logger.error( + "Unsupported client-provided TextDocumentContentChangeEvent. " + "Please update / submit a Pull Request to your LSP client." + ) + + if self._is_sync_kind_none: + self._apply_none_change(change) + else: + self._apply_full_change(change) + + @property + def lines(self) -> List[str]: + return self.source.splitlines(True) + + def offset_at_position(self, client_position: types.Position) -> int: + """Return the character offset pointed at by the given client_position.""" + lines = self.lines + server_position = self.position.position_from_client_units( + lines, client_position + ) + row, col = server_position.line, server_position.character + return col + sum(self.position.client_num_units(line) for line in lines[:row]) + + @property + def source(self) -> str: + if self._source is None: + with io.open(self.path, "r", encoding="utf-8") as f: + return f.read() + return self._source + + def word_at_position( + self, + client_position: types.Position, + re_start_word: Pattern[str] = RE_START_WORD, + re_end_word: Pattern[str] = RE_END_WORD, + ) -> str: + """Return the word at position. + + The word is constructed in two halves, the first half is found by taking + the first match of ``re_start_word`` on the line up until + ``position.character``. + + The second half is found by taking ``position.character`` up until the + last match of ``re_end_word`` on the line. + + :func:`python:re.findall` is used to find the matches. + + Parameters + ---------- + position + The line and character offset. + + re_start_word + The regular expression for extracting the word backward from + position. The default pattern is ``[A-Za-z_0-9]*$``.
+ + re_end_word + The regular expression for extracting the word forward from + position. The default pattern is ``^[A-Za-z_0-9]*``. + + Returns + ------- + str + The word (obtained by concatenating the two matches) at position. + """ + lines = self.lines + if client_position.line >= len(lines): + return "" + + server_position = self.position.position_from_client_units( + lines, client_position + ) + row, col = server_position.line, server_position.character + line = lines[row] + # Split word in two + start = line[:col] + end = line[col:] + + # Take end of start and start of end to find word + # These are guaranteed to match, even if they match the empty string + m_start = re_start_word.findall(start) + m_end = re_end_word.findall(end) + + return m_start[0] + m_end[-1] diff --git a/pygls/workspace/workspace.py b/pygls/workspace/workspace.py new file mode 100644 index 00000000..1ae25283 --- /dev/null +++ b/pygls/workspace/workspace.py @@ -0,0 +1,311 @@ +############################################################################ +# Original work Copyright 2017 Palantir Technologies, Inc. # +# Original work licensed under the MIT License. # +# See ThirdPartyNotices.txt in the project root for license information. # +# All modifications Copyright (c) Open Law Library. All rights reserved. # +# # +# Licensed under the Apache License, Version 2.0 (the "License") # +# you may not use this file except in compliance with the License. # +# You may obtain a copy of the License at # +# # +# http: // www.apache.org/licenses/LICENSE-2.0 # +# # +# Unless required by applicable law or agreed to in writing, software # +# distributed under the License is distributed on an "AS IS" BASIS, # +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # +# See the License for the specific language governing permissions and # +# limitations under the License. # +############################################################################ +import copy +import logging +import os +import warnings +from typing import Dict, List, Optional, Union + +from lsprotocol import types +from lsprotocol.types import ( + PositionEncodingKind, + TextDocumentSyncKind, + WorkspaceFolder, +) +from pygls.uris import to_fs_path, uri_scheme +from pygls.workspace.text_document import TextDocument + +logger = logging.getLogger(__name__) + + +class Workspace(object): + def __init__( + self, + root_uri: Optional[str], + sync_kind: TextDocumentSyncKind = TextDocumentSyncKind.Incremental, + workspace_folders: Optional[List[WorkspaceFolder]] = None, + position_encoding: Optional[ + Union[PositionEncodingKind, str] + ] = PositionEncodingKind.Utf16, + ): + self._root_uri = root_uri + if self._root_uri is not None: + self._root_uri_scheme = uri_scheme(self._root_uri) + root_path = to_fs_path(self._root_uri) + if root_path is None: + raise Exception("Couldn't get `root_path` from `root_uri`") + self._root_path = root_path + self._sync_kind = sync_kind + self._text_documents: Dict[str, TextDocument] = {} + self._notebook_documents: Dict[str, types.NotebookDocument] = {} + + # Used to lookup notebooks which contain a given cell. 
+ self._cell_in_notebook: Dict[str, str] = {} + self._folders: Dict[str, WorkspaceFolder] = {} + self._docs: Dict[str, TextDocument] = {} + self._position_encoding = position_encoding + + if workspace_folders is not None: + for folder in workspace_folders: + self.add_folder(folder) + + def _create_text_document( + self, + doc_uri: str, + source: Optional[str] = None, + version: Optional[int] = None, + language_id: Optional[str] = None, + ) -> TextDocument: + return TextDocument( + doc_uri, + source=source, + version=version, + language_id=language_id, + sync_kind=self._sync_kind, + position_encoding=self._position_encoding, + ) + + def add_folder(self, folder: WorkspaceFolder): + self._folders[folder.uri] = folder + + @property + def documents(self): + warnings.warn( + "'workspace.documents' has been deprecated, use " + "'workspace.text_documents' instead", + DeprecationWarning, + stacklevel=2, + ) + return self.text_documents + + @property + def notebook_documents(self): + return self._notebook_documents + + @property + def text_documents(self): + return self._text_documents + + @property + def folders(self): + return self._folders + + def get_notebook_document( + self, *, notebook_uri: Optional[str] = None, cell_uri: Optional[str] = None + ) -> Optional[types.NotebookDocument]: + """Return the notebook corresponding with the given uri. + + If both ``notebook_uri`` and ``cell_uri`` are given, ``notebook_uri`` takes + precedence. + + Parameters + ---------- + notebook_uri + If given, return the notebook document with the given uri. + + cell_uri + If given, return the notebook document which contains a cell with the + given uri + + Returns + ------- + Optional[NotebookDocument] + The requested notebook document if found, ``None`` otherwise. + """ + if notebook_uri is not None: + return self._notebook_documents.get(notebook_uri) + + if cell_uri is not None: + notebook_uri = self._cell_in_notebook.get(cell_uri) + if notebook_uri is None: + return None + + return self._notebook_documents.get(notebook_uri) + + return None + + def get_text_document(self, doc_uri: str) -> TextDocument: + """ + Return a managed document if-present, + else create one pointing at disk. + + See https://github.com/Microsoft/language-server-protocol/issues/177 + """ + return self._text_documents.get(doc_uri) or self._create_text_document(doc_uri) + + def is_local(self): + return ( + self._root_uri_scheme == "" or self._root_uri_scheme == "file" + ) and os.path.exists(self._root_path) + + def put_notebook_document(self, params: types.DidOpenNotebookDocumentParams): + notebook = params.notebook_document + + # Create a fresh instance to ensure our copy cannot be accidentally modified. + self._notebook_documents[notebook.uri] = copy.deepcopy(notebook) + + for cell_document in params.cell_text_documents: + self.put_text_document(cell_document, notebook_uri=notebook.uri) + + def put_text_document( + self, + text_document: types.TextDocumentItem, + notebook_uri: Optional[str] = None, + ): + """Add a text document to the workspace. 
+ + Parameters + ---------- + text_document + The text document to add + + notebook_uri + If set, indicates that this text document represents a cell in a notebook + document + """ + doc_uri = text_document.uri + + self._text_documents[doc_uri] = self._create_text_document( + doc_uri, + source=text_document.text, + version=text_document.version, + language_id=text_document.language_id, + ) + + if notebook_uri: + self._cell_in_notebook[doc_uri] = notebook_uri + + def remove_notebook_document(self, params: types.DidCloseNotebookDocumentParams): + notebook_uri = params.notebook_document.uri + self._notebook_documents.pop(notebook_uri, None) + + for cell_document in params.cell_text_documents: + self.remove_text_document(cell_document.uri) + + def remove_text_document(self, doc_uri: str): + self._text_documents.pop(doc_uri, None) + self._cell_in_notebook.pop(doc_uri, None) + + def remove_folder(self, folder_uri: str): + self._folders.pop(folder_uri, None) + try: + del self._folders[folder_uri] + except KeyError: + pass + + @property + def root_path(self): + return self._root_path + + @property + def root_uri(self): + return self._root_uri + + def update_notebook_document(self, params: types.DidChangeNotebookDocumentParams): + uri = params.notebook_document.uri + notebook = self._notebook_documents[uri] + notebook.version = params.notebook_document.version + + if params.change.metadata: + notebook.metadata = params.change.metadata + + cell_changes = params.change.cells + if cell_changes is None: + return + + # Process changes to any cell metadata. + nb_cells = {cell.document: cell for cell in notebook.cells} + for new_data in cell_changes.data or []: + nb_cell = nb_cells.get(new_data.document) + if nb_cell is None: + logger.warning( + "Ignoring metadata for '%s': not in notebook.", new_data.document + ) + continue + + nb_cell.kind = new_data.kind + nb_cell.metadata = new_data.metadata + nb_cell.execution_summary = new_data.execution_summary + + # Process changes to the notebook's structure + structure = cell_changes.structure + if structure: + cells = notebook.cells + new_cells = structure.array.cells or [] + + # Re-order the cells + before = cells[: structure.array.start] + after = cells[(structure.array.start + structure.array.delete_count) :] + notebook.cells = [*before, *new_cells, *after] + + for new_cell in structure.did_open or []: + self.put_text_document(new_cell, notebook_uri=uri) + + for removed_cell in structure.did_close or []: + self.remove_text_document(removed_cell.uri) + + # Process changes to the text content of existing cells. 
+        for text in cell_changes.text_content or []:
+            for change in text.changes:
+                self.update_text_document(text.document, change)
+
+    def update_text_document(
+        self,
+        text_doc: types.VersionedTextDocumentIdentifier,
+        change: types.TextDocumentContentChangeEvent,
+    ):
+        doc_uri = text_doc.uri
+        self._text_documents[doc_uri].apply_change(change)
+        self._text_documents[doc_uri].version = text_doc.version
+
+    def get_document(self, *args, **kwargs):
+        warnings.warn(
+            "'workspace.get_document' has been deprecated, use "
+            "'workspace.get_text_document' instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.get_text_document(*args, **kwargs)
+
+    def remove_document(self, *args, **kwargs):
+        warnings.warn(
+            "'workspace.remove_document' has been deprecated, use "
+            "'workspace.remove_text_document' instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.remove_text_document(*args, **kwargs)
+
+    def put_document(self, *args, **kwargs):
+        warnings.warn(
+            "'workspace.put_document' has been deprecated, use "
+            "'workspace.put_text_document' instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.put_text_document(*args, **kwargs)
+
+    def update_document(self, *args, **kwargs):
+        warnings.warn(
+            "'workspace.update_document' has been deprecated, use "
+            "'workspace.update_text_document' instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.update_text_document(*args, **kwargs)
diff --git a/pyproject.toml b/pyproject.toml
index 93e4cfab..a1eb850f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ readme = "README.md"
 
 [tool.poetry.dependencies]
 python = ">=3.7.9,<4"
-lsprotocol = "2023.0.0a3"
+lsprotocol = "2023.0.0b1"
 typeguard = "^3.0.0"
 websockets = {version = "^11.0.3", optional = true}
 
@@ -64,6 +64,9 @@ generate_client = "python scripts/generate_client.py --output pygls/lsp/client.p
 generate_contributors_md = "python scripts/generate_contributors_md.py"
 black_check = "black --check ."
 
+[tool.pyright]
+strict = ["pygls"]
+
 [tool.ruff]
 # Sometimes Black can't reduce line length without breaking more important rules.
 # So allow Ruff to be more lenient.
diff --git a/tests/conftest.py b/tests/conftest.py
index c816bd13..70d0a3ad 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -25,7 +25,7 @@
 
 from pygls import uris, IS_PYODIDE, IS_WIN
 from pygls.feature_manager import FeatureManager
-from pygls.workspace import Document, Workspace
+from pygls.workspace import Workspace
 
 from .ls_setup import (
     NativeClientServer,
@@ -40,7 +40,7 @@
 testing
 with "😋" unicode.
""" -DOC_URI = uris.from_fs_path(__file__) +DOC_URI = uris.from_fs_path(__file__) or "" ClientServer = NativeClientServer @@ -110,11 +110,6 @@ def server_dir(): json_server_client = create_client_for_server("json_server.py") -@pytest.fixture -def doc(): - return Document(DOC_URI, DOC) - - @pytest.fixture def feature_manager(): """Return a feature manager""" diff --git a/tests/test_document.py b/tests/test_document.py index 402c43fe..859f5084 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -18,27 +18,17 @@ ############################################################################ import re -from lsprotocol.types import ( - Position, - Range, - TextDocumentContentChangeEvent_Type1, - TextDocumentSyncKind, -) -from pygls.workspace import ( - Document, - position_from_utf16, - position_to_utf16, - range_from_utf16, - range_to_utf16, -) +from lsprotocol import types +from pygls.workspace import TextDocument, Position from .conftest import DOC, DOC_URI def test_document_empty_edit(): - doc = Document("file:///uri", "") - change = TextDocumentContentChangeEvent_Type1( - range=Range( - start=Position(line=0, character=0), end=Position(line=0, character=0) + doc = TextDocument("file:///uri", "") + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), ), range_length=0, text="f", @@ -49,11 +39,12 @@ def test_document_empty_edit(): def test_document_end_of_file_edit(): old = ["print 'a'\n", "print 'b'\n"] - doc = Document("file:///uri", "".join(old)) + doc = TextDocument("file:///uri", "".join(old)) - change = TextDocumentContentChangeEvent_Type1( - range=Range( - start=Position(line=2, character=0), end=Position(line=2, character=0) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=2, character=0), + end=types.Position(line=2, character=0), ), range_length=0, text="o", @@ -69,10 +60,13 @@ def test_document_end_of_file_edit(): def test_document_full_edit(): old = ["def hello(a, b):\n", " print a\n", " print b\n"] - doc = Document("file:///uri", "".join(old), sync_kind=TextDocumentSyncKind.Full) - change = TextDocumentContentChangeEvent_Type1( - range=Range( - start=Position(line=1, character=4), end=Position(line=2, character=11) + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.Full + ) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=1, character=4), + end=types.Position(line=2, character=11), ), range_length=0, text="print a, b", @@ -81,18 +75,27 @@ def test_document_full_edit(): assert doc.lines == ["print a, b"] - doc = Document("file:///uri", "".join(old), sync_kind=TextDocumentSyncKind.Full) - change = TextDocumentContentChangeEvent_Type1(range=None, text="print a, b") + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.Full + ) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ), + text="print a, b", + ) doc.apply_change(change) assert doc.lines == ["print a, b"] def test_document_line_edit(): - doc = Document("file:///uri", "itshelloworld") - change = TextDocumentContentChangeEvent_Type1( - range=Range( - start=Position(line=0, character=3), end=Position(line=0, character=8) + doc = TextDocument("file:///uri", "itshelloworld") + change = 
types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=8), ), range_length=0, text="goodbye", @@ -101,19 +104,21 @@ def test_document_line_edit(): assert doc.source == "itsgoodbyeworld" -def test_document_lines(doc): +def test_document_lines(): + doc = TextDocument(DOC_URI, DOC) assert len(doc.lines) == 4 assert doc.lines[0] == "document\n" def test_document_multiline_edit(): old = ["def hello(a, b):\n", " print a\n", " print b\n"] - doc = Document( - "file:///uri", "".join(old), sync_kind=TextDocumentSyncKind.Incremental + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.Incremental ) - change = TextDocumentContentChangeEvent_Type1( - range=Range( - start=Position(line=1, character=4), end=Position(line=2, character=11) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=1, character=4), + end=types.Position(line=2, character=11), ), range_length=0, text="print a, b", @@ -122,12 +127,13 @@ def test_document_multiline_edit(): assert doc.lines == ["def hello(a, b):\n", " print a, b\n"] - doc = Document( - "file:///uri", "".join(old), sync_kind=TextDocumentSyncKind.Incremental + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.Incremental ) - change = TextDocumentContentChangeEvent_Type1( - range=Range( - start=Position(line=1, character=4), end=Position(line=2, character=11) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=1, character=4), + end=types.Position(line=2, character=11), ), text="print a, b", ) @@ -138,10 +144,13 @@ def test_document_multiline_edit(): def test_document_no_edit(): old = ["def hello(a, b):\n", " print a\n", " print b\n"] - doc = Document("file:///uri", "".join(old), sync_kind=TextDocumentSyncKind.None_) - change = TextDocumentContentChangeEvent_Type1( - range=Range( - start=Position(line=1, character=4), end=Position(line=2, character=11) + doc = TextDocument( + "file:///uri", "".join(old), sync_kind=types.TextDocumentSyncKind.None_ + ) + change = types.TextDocumentContentChangeEvent_Type1( + range=types.Range( + start=types.Position(line=1, character=4), + end=types.Position(line=2, character=11), ), range_length=0, text="print a, b", @@ -151,145 +160,218 @@ def test_document_no_edit(): assert doc.lines == old -def test_document_props(doc): +def test_document_props(): + doc = TextDocument(DOC_URI, DOC) + assert doc.uri == DOC_URI assert doc.source == DOC def test_document_source_unicode(): - document_mem = Document(DOC_URI, "my source") - document_disk = Document(DOC_URI) + document_mem = TextDocument(DOC_URI, "my source") + document_disk = TextDocument(DOC_URI) assert isinstance(document_mem.source, type(document_disk.source)) def test_position_from_utf16(): - assert position_from_utf16(['x="😋"'], Position(line=0, character=3)) == Position( - line=0, character=3 - ) - assert position_from_utf16(['x="😋"'], Position(line=0, character=5)) == Position( - line=0, character=4 - ) - - position = Position(line=0, character=5) - position_from_utf16(['x="😋"'], position) - assert position == Position(line=0, character=5) + position = Position(encoding=types.PositionEncodingKind.Utf16) + assert position.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + assert position.position_from_client_units( + ['x="😋"'], 
types.Position(line=0, character=5) + ) == types.Position(line=0, character=4) + + +def test_position_from_utf32(): + position = Position(encoding=types.PositionEncodingKind.Utf32) + assert position.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + assert position.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=4) + ) == types.Position(line=0, character=4) + + +def test_position_from_utf8(): + position = Position(encoding=types.PositionEncodingKind.Utf8) + assert position.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + assert position.position_from_client_units( + ['x="😋"'], types.Position(line=0, character=7) + ) == types.Position(line=0, character=4) def test_position_to_utf16(): - assert position_to_utf16(['x="😋"'], Position(line=0, character=3)) == Position( - line=0, character=3 - ) + position = Position(encoding=types.PositionEncodingKind.Utf16) + assert position.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) - assert position_to_utf16(['x="😋"'], Position(line=0, character=4)) == Position( - line=0, character=5 - ) + assert position.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=4) + ) == types.Position(line=0, character=5) + + +def test_position_to_utf32(): + position = Position(encoding=types.PositionEncodingKind.Utf32) + assert position.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) - position = Position(line=0, character=4) - position_to_utf16(['x="😋"'], position) - assert position == Position(line=0, character=4) + assert position.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=4) + ) == types.Position(line=0, character=4) + + +def test_position_to_utf8(): + position = Position(encoding=types.PositionEncodingKind.Utf8) + assert position.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=3) + ) == types.Position(line=0, character=3) + + assert position.position_to_client_units( + ['x="😋"'], types.Position(line=0, character=4) + ) == types.Position(line=0, character=6) def test_range_from_utf16(): - assert range_from_utf16( + position = Position(encoding=types.PositionEncodingKind.Utf16) + assert position.range_from_client_units( ['x="😋"'], - Range(start=Position(line=0, character=3), end=Position(line=0, character=5)), - ) == Range(start=Position(line=0, character=3), end=Position(line=0, character=4)) + types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=5), + ), + ) == types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=4), + ) - range = Range( - start=Position(line=0, character=3), end=Position(line=0, character=5) + range = types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=5), ) - actual = range_from_utf16(['x="😋😋"'], range) - expected = Range( - start=Position(line=0, character=3), end=Position(line=0, character=4) + actual = position.range_from_client_units(['x="😋😋"'], range) + expected = types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=4), ) assert actual == expected def test_range_to_utf16(): - assert range_to_utf16( + position = Position(encoding=types.PositionEncodingKind.Utf16) + assert 
position.range_to_client_units( ['x="😋"'], - Range(start=Position(line=0, character=3), end=Position(line=0, character=4)), - ) == Range(start=Position(line=0, character=3), end=Position(line=0, character=5)) + types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=4), + ), + ) == types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=5), + ) - range = Range( - start=Position(line=0, character=3), end=Position(line=0, character=4) + range = types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=4), ) - actual = range_to_utf16(['x="😋😋"'], range) - expected = Range( - start=Position(line=0, character=3), end=Position(line=0, character=5) + actual = position.range_to_client_units(['x="😋😋"'], range) + expected = types.Range( + start=types.Position(line=0, character=3), + end=types.Position(line=0, character=5), ) assert actual == expected -def test_offset_at_position(doc): - assert doc.offset_at_position(Position(line=0, character=8)) == 8 - assert doc.offset_at_position(Position(line=1, character=5)) == 12 - assert doc.offset_at_position(Position(line=2, character=0)) == 13 - assert doc.offset_at_position(Position(line=2, character=4)) == 17 - assert doc.offset_at_position(Position(line=3, character=6)) == 27 - assert doc.offset_at_position(Position(line=3, character=7)) == 28 - assert doc.offset_at_position(Position(line=3, character=8)) == 28 - assert doc.offset_at_position(Position(line=4, character=0)) == 40 - assert doc.offset_at_position(Position(line=5, character=0)) == 40 +def test_offset_at_position_utf16(): + doc = TextDocument(DOC_URI, DOC) + assert doc.offset_at_position(types.Position(line=0, character=8)) == 8 + assert doc.offset_at_position(types.Position(line=1, character=5)) == 12 + assert doc.offset_at_position(types.Position(line=2, character=0)) == 13 + assert doc.offset_at_position(types.Position(line=2, character=4)) == 17 + assert doc.offset_at_position(types.Position(line=3, character=6)) == 27 + assert doc.offset_at_position(types.Position(line=3, character=7)) == 28 + assert doc.offset_at_position(types.Position(line=3, character=8)) == 28 + assert doc.offset_at_position(types.Position(line=4, character=0)) == 40 + assert doc.offset_at_position(types.Position(line=5, character=0)) == 40 -def test_utf16_to_utf32_position_cast(doc): - lines = ["", "😋😋", ""] - assert position_from_utf16(lines, Position(line=0, character=0)) == Position( - line=0, character=0 - ) - assert position_from_utf16(lines, Position(line=0, character=1)) == Position( - line=0, character=0 - ) - assert position_from_utf16(lines, Position(line=1, character=0)) == Position( - line=1, character=0 - ) - assert position_from_utf16(lines, Position(line=1, character=2)) == Position( - line=1, character=1 - ) - assert position_from_utf16(lines, Position(line=1, character=3)) == Position( - line=1, character=2 - ) - assert position_from_utf16(lines, Position(line=1, character=4)) == Position( - line=1, character=2 - ) - assert position_from_utf16(lines, Position(line=1, character=100)) == Position( - line=1, character=2 - ) - assert position_from_utf16(lines, Position(line=3, character=0)) == Position( - line=2, character=0 - ) - assert position_from_utf16(lines, Position(line=4, character=10)) == Position( - line=2, character=0 - ) +def test_offset_at_position_utf32(): + doc = TextDocument(DOC_URI, DOC, position_encoding=types.PositionEncodingKind.Utf32) + assert 
doc.offset_at_position(types.Position(line=0, character=8)) == 8 + assert doc.offset_at_position(types.Position(line=5, character=0)) == 39 + +def test_offset_at_position_utf8(): + doc = TextDocument(DOC_URI, DOC, position_encoding=types.PositionEncodingKind.Utf8) + assert doc.offset_at_position(types.Position(line=0, character=8)) == 8 + assert doc.offset_at_position(types.Position(line=5, character=0)) == 41 -def test_position_for_line_endings(doc): + +def test_utf16_to_utf32_position_cast(): + position = Position(encoding=types.PositionEncodingKind.Utf16) + lines = ["", "😋😋", ""] + assert position.position_from_client_units( + lines, types.Position(line=0, character=0) + ) == types.Position(line=0, character=0) + assert position.position_from_client_units( + lines, types.Position(line=0, character=1) + ) == types.Position(line=0, character=0) + assert position.position_from_client_units( + lines, types.Position(line=1, character=0) + ) == types.Position(line=1, character=0) + assert position.position_from_client_units( + lines, types.Position(line=1, character=2) + ) == types.Position(line=1, character=1) + assert position.position_from_client_units( + lines, types.Position(line=1, character=3) + ) == types.Position(line=1, character=2) + assert position.position_from_client_units( + lines, types.Position(line=1, character=4) + ) == types.Position(line=1, character=2) + assert position.position_from_client_units( + lines, types.Position(line=1, character=100) + ) == types.Position(line=1, character=2) + assert position.position_from_client_units( + lines, types.Position(line=3, character=0) + ) == types.Position(line=2, character=0) + assert position.position_from_client_units( + lines, types.Position(line=4, character=10) + ) == types.Position(line=2, character=0) + + +def test_position_for_line_endings(): + position = Position(encoding=types.PositionEncodingKind.Utf16) lines = ["x\r\n", "y\n"] - assert position_from_utf16(lines, Position(line=0, character=10)) == Position( - line=0, character=1 - ) - assert position_from_utf16(lines, Position(line=1, character=10)) == Position( - line=1, character=1 - ) + assert position.position_from_client_units( + lines, types.Position(line=0, character=10) + ) == types.Position(line=0, character=1) + assert position.position_from_client_units( + lines, types.Position(line=1, character=10) + ) == types.Position(line=1, character=1) -def test_word_at_position(doc): +def test_word_at_position(): """ Return word under the cursor (or last in line if past the end) """ - assert doc.word_at_position(Position(line=0, character=8)) == "document" - assert doc.word_at_position(Position(line=0, character=1000)) == "document" - assert doc.word_at_position(Position(line=1, character=5)) == "for" - assert doc.word_at_position(Position(line=2, character=0)) == "testing" - assert doc.word_at_position(Position(line=3, character=10)) == "unicode" - assert doc.word_at_position(Position(line=4, character=0)) == "" - assert doc.word_at_position(Position(line=4, character=0)) == "" + doc = TextDocument(DOC_URI, DOC) + + assert doc.word_at_position(types.Position(line=0, character=8)) == "document" + assert doc.word_at_position(types.Position(line=0, character=1000)) == "document" + assert doc.word_at_position(types.Position(line=1, character=5)) == "for" + assert doc.word_at_position(types.Position(line=2, character=0)) == "testing" + assert doc.word_at_position(types.Position(line=3, character=10)) == "unicode" + assert doc.word_at_position(types.Position(line=4, 
character=0)) == "" + assert doc.word_at_position(types.Position(line=4, character=0)) == "" re_start_word = re.compile(r"[A-Za-z_0-9.]*$") re_end_word = re.compile(r"^[A-Za-z_0-9.]*") assert ( doc.word_at_position( - Position( + types.Position( line=3, character=10, ), diff --git a/tests/test_feature_manager.py b/tests/test_feature_manager.py index d9655b85..4e5c852b 100644 --- a/tests/test_feature_manager.py +++ b/tests/test_feature_manager.py @@ -215,12 +215,38 @@ def server_capabilities(**kwargs): file_operations=lsp.FileOperationOptions(), ) + if "position_encoding" not in kwargs: + kwargs["position_encoding"] = lsp.PositionEncodingKind.Utf16 + return lsp.ServerCapabilities(**kwargs) @pytest.mark.parametrize( "method, options, capabilities, expected", [ + ( + lsp.INITIALIZE, + None, + lsp.ClientCapabilities( + general=lsp.GeneralClientCapabilities( + position_encodings=[lsp.PositionEncodingKind.Utf8] + ) + ), + server_capabilities(position_encoding=lsp.PositionEncodingKind.Utf8), + ), + ( + lsp.INITIALIZE, + None, + lsp.ClientCapabilities( + general=lsp.GeneralClientCapabilities( + position_encodings=[ + lsp.PositionEncodingKind.Utf8, + lsp.PositionEncodingKind.Utf32, + ] + ) + ), + server_capabilities(position_encoding=lsp.PositionEncodingKind.Utf32), + ), ( lsp.TEXT_DOCUMENT_DID_SAVE, lsp.SaveOptions(include_text=True),