Skip to content

Commit

Permalink
Allow overriding the scheme used in websocket_connect to support wss (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shevron authored Dec 30, 2021
1 parent 0e92f98 commit 56779ef
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 22 deletions.
34 changes: 13 additions & 21 deletions async_asgi_testclient/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def __init__(
scope: Optional[dict] = None,
):
self.application = guarantee_single_callable(application)
self.cookie_jar: Optional[
SimpleCookie
] = SimpleCookie() if use_cookies else None
self.cookie_jar: Optional[SimpleCookie] = (
SimpleCookie() if use_cookies else None
)
self.timeout = timeout
self.headers = headers or {}
self._scope = scope or {}
Expand Down Expand Up @@ -106,8 +106,8 @@ async def send_lifespan(self, action):
elif message["type"] == f"lifespan.{action}.failed":
raise Exception(message)

def websocket_connect(self, path, headers=None, cookies=None):
return WebSocketSession(self, path, headers, cookies)
def websocket_connect(self, *args: Any, **kwargs: Any) -> WebSocketSession:
return WebSocketSession(self, *args, **kwargs)

async def open(
self,
Expand Down Expand Up @@ -304,41 +304,33 @@ async def wait_response(self, receive_or_fail, type_):
return message

async def delete(self, *args: Any, **kwargs: Any) -> Response:
"""Make a DELETE request.
"""
"""Make a DELETE request."""
return await self.open(*args, method="DELETE", **kwargs)

async def get(self, *args: Any, **kwargs: Any) -> Response:
"""Make a GET request.
"""
"""Make a GET request."""
return await self.open(*args, method="GET", **kwargs)

async def head(self, *args: Any, **kwargs: Any) -> Response:
"""Make a HEAD request.
"""
"""Make a HEAD request."""
return await self.open(*args, method="HEAD", **kwargs)

async def options(self, *args: Any, **kwargs: Any) -> Response:
"""Make a OPTIONS request.
"""
"""Make a OPTIONS request."""
return await self.open(*args, method="OPTIONS", **kwargs)

async def patch(self, *args: Any, **kwargs: Any) -> Response:
"""Make a PATCH request.
"""
"""Make a PATCH request."""
return await self.open(*args, method="PATCH", **kwargs)

async def post(self, *args: Any, **kwargs: Any) -> Response:
"""Make a POST request.
"""
"""Make a POST request."""
return await self.open(*args, method="POST", **kwargs)

async def put(self, *args: Any, **kwargs: Any) -> Response:
"""Make a PUT request.
"""
"""Make a PUT request."""
return await self.open(*args, method="PUT", **kwargs)

async def trace(self, *args: Any, **kwargs: Any) -> Response:
"""Make a TRACE request.
"""
"""Make a TRACE request."""
return await self.open(*args, method="TRACE", **kwargs)
20 changes: 20 additions & 0 deletions async_asgi_testclient/tests/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ class Echo(WebSocketEndpoint):
async def on_receive(self, websocket, data):
if data == "cookies":
await websocket.send_text(dumps(websocket.cookies))
elif data == "url":
await websocket.send_text(str(websocket.url))
else:
await websocket.send_text(f"Message text was: {data}")

Expand Down Expand Up @@ -430,6 +432,24 @@ async def test_ws_connect_inherits_test_client_cookies(starlette_app):
assert msg == '{"session": "abc"}'


@pytest.mark.asyncio
async def test_ws_connect_default_scheme(starlette_app):
async with TestClient(starlette_app, timeout=0.1) as client:
async with client.websocket_connect("/ws") as ws:
await ws.send_text("url")
msg = await ws.receive_text()
assert msg.startswith("ws://")


@pytest.mark.asyncio
async def test_ws_connect_custom_scheme(starlette_app):
async with TestClient(starlette_app, timeout=0.1) as client:
async with client.websocket_connect("/ws", scheme="wss") as ws:
await ws.send_text("url")
msg = await ws.receive_text()
assert msg.startswith("wss://")


@pytest.mark.asyncio
async def test_request_stream(starlette_app):
from starlette.responses import StreamingResponse
Expand Down
4 changes: 3 additions & 1 deletion async_asgi_testclient/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ def __init__(
path,
headers: Optional[Dict] = None,
cookies: Optional[Dict] = None,
scheme: str = "ws",
):
self.testclient = testclient
self.path = path
self.headers = headers or {}
self.cookies = cookies
self.scheme = scheme
self.input_queue: asyncio.Queue[dict] = asyncio.Queue()
self.output_queue: asyncio.Queue[dict] = asyncio.Queue()
self._app_task = None # Necessary to keep a hard reference to running task
Expand Down Expand Up @@ -118,7 +120,7 @@ async def connect(self):
"path": path,
"query_string": query_string_bytes,
"root_path": "",
"scheme": "ws",
"scheme": self.scheme,
"subprotocols": [],
}

Expand Down

0 comments on commit 56779ef

Please sign in to comment.