From 6885bce958b371b4d5079a34ff18e2d0e81b603c Mon Sep 17 00:00:00 2001 From: James Scott Date: Sun, 7 Jan 2024 17:03:20 -0500 Subject: [PATCH] And all tests pass now. --- starlette/middleware/sessions.py | 61 +++++++++++++++++++------------- 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py index 9f1a07f1b..df90be0d5 100644 --- a/starlette/middleware/sessions.py +++ b/starlette/middleware/sessions.py @@ -22,7 +22,7 @@ def __init__( same_site: typing.Literal["lax", "strict", "none"] = "lax", https_only: bool = False, persist_session: bool = False, - auto_refresh_window: int = 0, + auto_refresh_window: int = 0, domain: typing.Optional[str] = None, ) -> None: self.app = app @@ -38,17 +38,19 @@ def __init__( if domain is not None: self.security_flags += f"; domain={domain}" - - def decode_cookie(self,cookie: bytes) -> typing.Dict[str,typing.Any]: - result: typing.Dict[str, typing.Any] = {"session": {}} + def decode_cookie(self, cookie: bytes) -> typing.Dict[str, typing.Any]: + result: typing.Dict[str, typing.Any] = {"session": {}} try: - data = self.signer.unsign(cookie,max_age=self.max_age,return_timestamp=True) - result["session"] = json.loads(b64decode(data[0])) - result["datetime"] = data[1] #DateTime obj + data = self.signer.unsign( + cookie, max_age=self.max_age, return_timestamp=True + ) + result["session"] = json.loads(b64decode(data[0])) + + result["datetime"] = data[1] # DateTime obj except (BadSignature, SignatureExpired): return result return result - + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] not in ("http", "websocket"): # pragma: no cover await self.app(scope, receive, send) @@ -58,29 +60,41 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: update_session = True if self.session_cookie in connection.cookies: - data = self.decode_cookie(connection.cookies[self.session_cookie].encode("utf-8")) # noqa E501 - scope["session"] = data["session"] - scope["exp"] = data["datetime"] + timedelta(seconds=self.max_age) # type: ignore[arg-type] - - if self.auto_refresh_window: - now = datetime.now(timezone.utc) - #if the expiry date not inside of the expiry window, do not update. - if not (now >= (scope["exp"] - timedelta(seconds=self.auto_refresh_window)) and now <= scope["exp"]): # noqa E501 - update_session = False - elif self.persist_session: + data = self.decode_cookie( + connection.cookies[self.session_cookie].encode("utf-8") + ) # noqa E501 + if data["session"]: + scope["session"] = data["session"] + expiration = data["datetime"] + timedelta(seconds=self.max_age) # type: ignore[arg-type] + + if self.auto_refresh_window: + now = datetime.now(timezone.utc) + # if the expiry date not inside of the expiry window, do not update. + if not ( + now + >= (expiration - timedelta(seconds=self.auto_refresh_window)) + and now <= expiration + ): # noqa E501 + update_session = False + elif self.persist_session: update_session = False + else: + scope["session"] = {} else: scope["session"] = {} - async def send_wrapper(message: Message) -> None: session_changed = False if message["type"] == "http.response.start": if self.session_cookie in connection.cookies: - previous_session_data = self.decode_cookie(connection.cookies[self.session_cookie].encode("utf-8")) # noqa E501 - if (previous_session_data["session"] and scope["session"]) and previous_session_data["session"] != scope["session"]: # noqa E501 - session_changed = True - + previous_session_data = self.decode_cookie( + connection.cookies[self.session_cookie].encode("utf-8") + ) # noqa E501 + if ( + previous_session_data["session"] and scope["session"] + ) and previous_session_data["session"] != scope["session"]: # noqa E501 + session_changed = True + if scope["session"] and (update_session or session_changed): # We have data that needs to be persisted or refreshed. data = b64encode(json.dumps(scope["session"]).encode("utf-8")) @@ -107,4 +121,3 @@ async def send_wrapper(message: Message) -> None: await send(message) await self.app(scope, receive, send_wrapper) -