Skip to content

Commit

Permalink
fix: don't cancel the future handling the shutdown request
Browse files Browse the repository at this point in the history
This commit makes use of a `ContextVar` to keep track of the current
request's id, allowing handlers to reference it. Most importantly so
that the shutdown request handler does not cancel its own future!
  • Loading branch information
alcarney committed Nov 30, 2024
1 parent 2d9513f commit e4862c1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
43 changes: 32 additions & 11 deletions pygls/protocol/json_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import asyncio
import contextvars
import enum
import inspect
import json
Expand Down Expand Up @@ -62,6 +63,10 @@

logger = logging.getLogger(__name__)

# cattrs needs access to this type definition so we cannot include it in the
# TYPE_CHECKING block above
MsgId = Union[str | int]


@attrs.define
class JsonRPCNotification:
Expand All @@ -80,7 +85,7 @@ class JsonRPCRequestMessage:
Used as a fallback for unknown types.
"""

id: Union[int, str]
id: MsgId
method: str
jsonrpc: str
params: Any
Expand All @@ -92,7 +97,7 @@ class JsonRPCResponseMessage:
Used as a fallback for unknown types.
"""

id: Union[int, str]
id: MsgId
jsonrpc: str
result: Any

Expand All @@ -117,8 +122,11 @@ def __init__(self, server: JsonRPCServer, converter: Converter):
self._shutdown = False

# Book keeping for in-flight requests
self._request_futures: dict[str | int, Future[Any]] = {}
self._result_types: dict[str | int, Any] = {}
self._ctx_msg_id: contextvars.ContextVar[MsgId | None] = contextvars.ContextVar(
"msg_id", default=None
)
self._request_futures: dict[MsgId, Future[Any]] = {}
self._result_types: dict[MsgId, Any] = {}

self.fm = FeatureManager(server, converter)
self.writer: AsyncWriter | Writer | None = None
Expand All @@ -127,9 +135,15 @@ def __init__(self, server: JsonRPCServer, converter: Converter):
def __call__(self):
return self

@property
def msg_id(self) -> MsgId | None:
"""Returns the id of the current context (if it exists)."""
ctx = contextvars.copy_context()
return ctx.get(self._ctx_msg_id)

def _execute_handler(
self,
msg_id: str | int,
msg_id: MsgId,
handler: MessageHandler,
callback: MessageCallback,
args: tuple[Any, ...] | None = None,
Expand Down Expand Up @@ -300,7 +314,7 @@ def _handle_cancel_notification(self, msg_id):
if future.cancel():
logger.info('Cancelled request with id "%s"', msg_id)

def _handle_notification(self, method_name, params):
def _handle_notification(self, method_name: str, params: Any):
"""Handles a notification from the client."""
if method_name == CANCEL_REQUEST:
self._handle_cancel_notification(params.id)
Expand All @@ -325,11 +339,13 @@ def _handle_notification(self, method_name, params):
)
self._server._report_server_error(error, FeatureNotificationError)

def _handle_request(self, msg_id, method_name, params):
def _handle_request(self, msg_id: MsgId, method_name: str, params: Any):
"""Handles a request from the client."""
try:
handler = self._get_handler(method_name)

# Set the request id within the current context.
self._ctx_msg_id.set(msg_id)
self._execute_handler(
msg_id=msg_id,
handler=handler,
Expand Down Expand Up @@ -438,20 +454,25 @@ def handle_message(self, message):
logger.warning("Server shutting down. No more requests!")
return

# Run each handler within its own context.
ctx = contextvars.copy_context()

if hasattr(message, "method"):
if hasattr(message, "id"):
logger.debug("Request %r received", message.method)
self._handle_request(message.id, message.method, message.params)
ctx.run(
self._handle_request, message.id, message.method, message.params
)
else:
logger.debug("Notification %r received", message.method)
self._handle_notification(message.method, message.params)
ctx.run(self._handle_notification, message.method, message.params)
else:
if hasattr(message, "error"):
logger.debug("Error message received.")
self._handle_response(message.id, None, message.error)
ctx.run(self._handle_response, message.id, None, message.error)
else:
logger.debug("Response message received.")
self._handle_response(message.id, message.result)
ctx.run(self._handle_response, message.id, message.result)

def _send_data(self, data):
"""Sends data to the client."""
Expand Down
8 changes: 6 additions & 2 deletions pygls/protocol/language_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,12 @@ def lsp_shutdown(self, *args) -> None:
if (user_handler := self.fm.features.get(types.SHUTDOWN)) is not None:
yield user_handler, args, None

for future in self._request_futures.values():
future.cancel()
# Don't cancel the future for this request!
current_id = self.msg_id

for msg_id, future in self._request_futures.items():
if msg_id != current_id and not future.done():
future.cancel()

self._shutdown = True
return None
Expand Down

0 comments on commit e4862c1

Please sign in to comment.