Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add way to read http response for failed ws connect #6515

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/6515.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `session.try_ws_connect` method that allows access to ClientResponse in case WS handshake failed.
202 changes: 163 additions & 39 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
)
from .client_ws import (
DEFAULT_WS_CLIENT_TIMEOUT,
ClientWebSocketHandshakeResponse,
ClientWebSocketResponse as ClientWebSocketResponse,
ClientWSTimeout,
)
Expand Down Expand Up @@ -714,7 +715,97 @@ async def _ws_connect(
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
) -> ClientWebSocketResponse:
) -> "ClientWebSocketResponse":
ws_handshake_resp = await self.try_ws_connect(
url,
method=method,
protocols=protocols,
timeout=timeout,
receive_timeout=receive_timeout,
autoclose=autoclose,
autoping=autoping,
heartbeat=heartbeat,
auth=auth,
origin=origin,
params=params,
headers=headers,
proxy=proxy,
proxy_auth=proxy_auth,
ssl=ssl,
proxy_headers=proxy_headers,
compress=compress,
max_msg_size=max_msg_size,
)
return ws_handshake_resp.upgrade()

def try_ws_connect(
self,
url: StrOrURL,
*,
method: str = hdrs.METH_GET,
protocols: Iterable[str] = (),
timeout: Union[ClientWSTimeout, float, _SENTINEL, None] = sentinel,
receive_timeout: Optional[float] = None,
autoclose: bool = True,
autoping: bool = True,
heartbeat: Optional[float] = None,
auth: Optional[BasicAuth] = None,
origin: Optional[str] = None,
params: Optional[Mapping[str, str]] = None,
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
) -> "_WSTryRequestContextManager":
"""Initiate websocket connection."""
return _WSTryRequestContextManager(
self._try_ws_connect(
url,
method=method,
protocols=protocols,
timeout=timeout,
receive_timeout=receive_timeout,
autoclose=autoclose,
autoping=autoping,
heartbeat=heartbeat,
auth=auth,
origin=origin,
params=params,
headers=headers,
proxy=proxy,
proxy_auth=proxy_auth,
ssl=ssl,
proxy_headers=proxy_headers,
compress=compress,
max_msg_size=max_msg_size,
)
)

async def _try_ws_connect(
self,
url: StrOrURL,
*,
method: str = hdrs.METH_GET,
protocols: Iterable[str] = (),
timeout: Union[ClientWSTimeout, float, _SENTINEL, None] = sentinel,
receive_timeout: Optional[float] = None,
autoclose: bool = True,
autoping: bool = True,
heartbeat: Optional[float] = None,
auth: Optional[BasicAuth] = None,
origin: Optional[str] = None,
params: Optional[Mapping[str, str]] = None,
headers: Optional[LooseHeaders] = None,
proxy: Optional[StrOrURL] = None,
proxy_auth: Optional[BasicAuth] = None,
ssl: Union[SSLContext, bool, None, Fingerprint] = None,
proxy_headers: Optional[LooseHeaders] = None,
compress: int = 0,
max_msg_size: int = 4 * 1024 * 1024,
) -> ClientWebSocketHandshakeResponse:
if timeout is sentinel or timeout is None:
ws_timeout = DEFAULT_WS_CLIENT_TIMEOUT
else:
Expand Down Expand Up @@ -788,42 +879,54 @@ async def _ws_connect(
try:
# check handshake
if resp.status != 101:
raise WSServerHandshakeError(
resp.request_info,
resp.history,
message="Invalid response status",
status=resp.status,
headers=resp.headers,
return ClientWebSocketHandshakeResponse(
error=WSServerHandshakeError(
resp.request_info,
resp.history,
message="Invalid response status",
status=resp.status,
headers=resp.headers,
),
error_response=resp,
)

if resp.headers.get(hdrs.UPGRADE, "").lower() != "websocket":
raise WSServerHandshakeError(
resp.request_info,
resp.history,
message="Invalid upgrade header",
status=resp.status,
headers=resp.headers,
return ClientWebSocketHandshakeResponse(
error=WSServerHandshakeError(
resp.request_info,
resp.history,
message="Invalid upgrade header",
status=resp.status,
headers=resp.headers,
),
error_response=resp,
)

if resp.headers.get(hdrs.CONNECTION, "").lower() != "upgrade":
raise WSServerHandshakeError(
resp.request_info,
resp.history,
message="Invalid connection header",
status=resp.status,
headers=resp.headers,
return ClientWebSocketHandshakeResponse(
error=WSServerHandshakeError(
resp.request_info,
resp.history,
message="Invalid connection header",
status=resp.status,
headers=resp.headers,
),
error_response=resp,
)

# key calculation
r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, "")
match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode()
if r_key != match:
raise WSServerHandshakeError(
resp.request_info,
resp.history,
message="Invalid challenge response",
status=resp.status,
headers=resp.headers,
return ClientWebSocketHandshakeResponse(
error=WSServerHandshakeError(
resp.request_info,
resp.history,
message="Invalid challenge response",
status=resp.status,
headers=resp.headers,
),
error_response=resp,
)

# websocket protocol
Expand All @@ -847,13 +950,18 @@ async def _ws_connect(
try:
compress, notakeover = ws_ext_parse(compress_hdrs)
except WSHandshakeError as exc:
raise WSServerHandshakeError(
error = WSServerHandshakeError(
resp.request_info,
resp.history,
message=exc.args[0],
status=resp.status,
headers=resp.headers,
) from exc
)
error.__cause__ = exc
return ClientWebSocketHandshakeResponse(
error=error,
error_response=resp,
)
else:
compress = 0
notakeover = False
Expand All @@ -879,18 +987,20 @@ async def _ws_connect(
resp.close()
raise
else:
return self._ws_response_class(
reader,
writer,
protocol,
resp,
ws_timeout,
autoclose,
autoping,
self._loop,
heartbeat=heartbeat,
compress=compress,
client_notakeover=notakeover,
return ClientWebSocketHandshakeResponse(
ws_response=self._ws_response_class(
reader,
writer,
protocol,
resp,
ws_timeout,
autoclose,
autoping,
self._loop,
heartbeat=heartbeat,
compress=compress,
client_notakeover=notakeover,
)
)

def _prepare_headers(self, headers: Optional[LooseHeaders]) -> "CIMultiDict[str]":
Expand Down Expand Up @@ -1136,6 +1246,20 @@ async def __aexit__(
await self._resp.close()


class _WSTryRequestContextManager(
_BaseRequestContextManager[ClientWebSocketHandshakeResponse]
):
__slots__ = ()

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
await self._resp.close()


class _SessionRequestContextManager:

__slots__ = ("_coro", "_resp", "_session")
Expand Down
37 changes: 36 additions & 1 deletion aiohttp/client_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import async_timeout
from typing_extensions import Final

from .client_exceptions import ClientError
from .client_exceptions import ClientError, WSServerHandshakeError
from .client_reqrep import ClientResponse
from .helpers import call_later, set_result
from .http import (
Expand Down Expand Up @@ -39,6 +39,41 @@ class ClientWSTimeout:
)


class ClientWebSocketHandshakeResponse:
def __init__(
self,
*,
ws_response: Optional["ClientWebSocketResponse"] = None,
error: Optional[WSServerHandshakeError] = None,
error_response: Optional[ClientResponse] = None,
):
self._error = error
self._error_response = error_response
self._ws_response = ws_response

def upgrade(self) -> "ClientWebSocketResponse":
if self._error:
if self._error_response:
self._error_response.close()
raise self._error
assert self._ws_response
return self._ws_response

@property
def error(self) -> Optional[WSServerHandshakeError]:
return self._error

@property
def error_response(self) -> Optional[ClientResponse]:
return self._error_response

async def close(self) -> None:
if self._ws_response:
await self._ws_response.close()
elif self._error_response:
self._error_response.close()


class ClientWebSocketResponse:
def __init__(
self,
Expand Down
51 changes: 49 additions & 2 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ The client session supports the context manager protocol for self closing.
:return ClientResponse: a :class:`client response
<ClientResponse>` object.

.. comethod:: ws_connect(url, *, method='GET', \
.. comethod:: try_ws_connect(url, *, method='GET', \
protocols=(), timeout=10.0,\
receive_timeout=None,\
auth=None,\
Expand All @@ -660,7 +660,7 @@ The client session supports the context manager protocol for self closing.
:coroutine:

Create a websocket connection. Returns a
:class:`ClientWebSocketResponse` object.
:class:`ClientWebSocketHandshakeResponse` object.

:param url: Websocket server url, :class:`str` or :class:`~yarl.URL`

Expand Down Expand Up @@ -784,6 +784,25 @@ The client session supports the context manager protocol for self closing.

.. versionadded:: 3.5

.. comethod:: ws_connect(url, *, **kwargs)
:async-with:
:coroutine:

Create a websocket connection. Returns a
:class:`ClientWebSocketResponse` object.

This is shortcut to::

async with session.try_ws_connect(url, **kwargs) as handshake_resp:
resp = handshake_resp.upgrade()

In order to modify inner
:meth:`try_ws_connect<aiohttp.ClientSession.try_ws_connect>`
parameters, provide `kwargs`.

:param url: Request URL, :class:`str` or :class:`~yarl.URL`



.. comethod:: close()

Expand Down Expand Up @@ -1663,6 +1682,34 @@ manually.
:raise ValueError: if message is not valid JSON.


ClientWebSocketHandshakeResponse
--------------------------------


.. class:: ClientWebSocketHandshakeResponse()

Class for handling client-side websockets handshake result.

.. method:: upgrade()

Get a underlying :class:`ClientWebSocketResponse` if handshake
was a successful or raise an exception

:return: :class:`ClientWebSocketResponse`.

.. attribute:: error

Read-only property, :exc:`WSServerHandshakeError` if handshake
failed or ``None`` otherwise.

.. attribute:: error_response

Read-only property, :class:`ClientResponse` of initial http
request if handshake failed or ``None`` otherwise.

This property allows to read error response body.


Utilities
---------

Expand Down
Loading