Skip to content

Commit

Permalink
And all tests pass now.
Browse files Browse the repository at this point in the history
  • Loading branch information
ToasterChicken committed Jan 7, 2024
1 parent 28b2be0 commit 6885bce
Showing 1 changed file with 37 additions and 24 deletions.
61 changes: 37 additions & 24 deletions starlette/middleware/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(
same_site: typing.Literal["lax", "strict", "none"] = "lax",
https_only: bool = False,
persist_session: bool = False,
auto_refresh_window: int = 0,
auto_refresh_window: int = 0,
domain: typing.Optional[str] = None,
) -> None:
self.app = app
Expand All @@ -38,17 +38,19 @@ def __init__(
if domain is not None:
self.security_flags += f"; domain={domain}"


def decode_cookie(self,cookie: bytes) -> typing.Dict[str,typing.Any]:
result: typing.Dict[str, typing.Any] = {"session": {}}
def decode_cookie(self, cookie: bytes) -> typing.Dict[str, typing.Any]:
result: typing.Dict[str, typing.Any] = {"session": {}}
try:
data = self.signer.unsign(cookie,max_age=self.max_age,return_timestamp=True)
result["session"] = json.loads(b64decode(data[0]))
result["datetime"] = data[1] #DateTime obj
data = self.signer.unsign(
cookie, max_age=self.max_age, return_timestamp=True
)
result["session"] = json.loads(b64decode(data[0]))

result["datetime"] = data[1] # DateTime obj
except (BadSignature, SignatureExpired):
return result
return result

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
await self.app(scope, receive, send)
Expand All @@ -58,29 +60,41 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
update_session = True

if self.session_cookie in connection.cookies:
data = self.decode_cookie(connection.cookies[self.session_cookie].encode("utf-8")) # noqa E501
scope["session"] = data["session"]
scope["exp"] = data["datetime"] + timedelta(seconds=self.max_age) # type: ignore[arg-type]

if self.auto_refresh_window:
now = datetime.now(timezone.utc)
#if the expiry date not inside of the expiry window, do not update.
if not (now >= (scope["exp"] - timedelta(seconds=self.auto_refresh_window)) and now <= scope["exp"]): # noqa E501
update_session = False
elif self.persist_session:
data = self.decode_cookie(
connection.cookies[self.session_cookie].encode("utf-8")
) # noqa E501
if data["session"]:
scope["session"] = data["session"]
expiration = data["datetime"] + timedelta(seconds=self.max_age) # type: ignore[arg-type]

if self.auto_refresh_window:
now = datetime.now(timezone.utc)
# if the expiry date not inside of the expiry window, do not update.
if not (
now
>= (expiration - timedelta(seconds=self.auto_refresh_window))
and now <= expiration
): # noqa E501
update_session = False
elif self.persist_session:
update_session = False
else:
scope["session"] = {}
else:
scope["session"] = {}


async def send_wrapper(message: Message) -> None:
session_changed = False
if message["type"] == "http.response.start":
if self.session_cookie in connection.cookies:
previous_session_data = self.decode_cookie(connection.cookies[self.session_cookie].encode("utf-8")) # noqa E501
if (previous_session_data["session"] and scope["session"]) and previous_session_data["session"] != scope["session"]: # noqa E501
session_changed = True

previous_session_data = self.decode_cookie(
connection.cookies[self.session_cookie].encode("utf-8")
) # noqa E501
if (
previous_session_data["session"] and scope["session"]
) and previous_session_data["session"] != scope["session"]: # noqa E501
session_changed = True

if scope["session"] and (update_session or session_changed):
# We have data that needs to be persisted or refreshed.
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
Expand All @@ -107,4 +121,3 @@ async def send_wrapper(message: Message) -> None:
await send(message)

await self.app(scope, receive, send_wrapper)

0 comments on commit 6885bce

Please sign in to comment.