Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Session cookie fix #2401

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/middleware.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ The following arguments are supported:
* `max_age` - Session expiry time in seconds. Defaults to 2 weeks. If set to `None` then the cookie will last as long as the browser session.
* `same_site` - SameSite flag prevents the browser from sending session cookie along with cross-site requests. Defaults to `'lax'`.
* `https_only` - Indicate that Secure flag should be set (can be used with HTTPS only). Defaults to `False`.
* `persist_session` - Sets the session cookie that's created to persist between connections. Defaults to `False`.
* `auto_refresh_window` - Refresh window in seconds before max_age. If the cookies age is max_age - auto_refresh_window the cookie will be refreshed with a new session cookie. Default is 0 seconds, and if set overrides persist_session.
* `domain` - Domain of the cookie used to share cookie between subdomains or cross-domains. The browser defaults the domain to the same host that set the cookie, excluding subdomains [refrence](https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#domain_attribute).


Expand Down
64 changes: 52 additions & 12 deletions starlette/middleware/sessions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import typing
from base64 import b64decode, b64encode
from datetime import datetime, timedelta, timezone

import itsdangerous
from itsdangerous.exc import BadSignature
from itsdangerous.exc import BadSignature, SignatureExpired

from starlette.datastructures import MutableHeaders, Secret
from starlette.requests import HTTPConnection
Expand All @@ -20,6 +21,8 @@ def __init__(
path: str = "/",
same_site: typing.Literal["lax", "strict", "none"] = "lax",
https_only: bool = False,
persist_session: bool = False,
auto_refresh_window: int = 0,
domain: typing.Optional[str] = None,
) -> None:
self.app = app
Expand All @@ -28,34 +31,72 @@ def __init__(
self.max_age = max_age
self.path = path
self.security_flags = "httponly; samesite=" + same_site
self.persist_session = persist_session
self.auto_refresh_window = auto_refresh_window
if https_only: # Secure flag can be used with HTTPS only
self.security_flags += "; secure"
if domain is not None:
self.security_flags += f"; domain={domain}"

def decode_cookie(self, cookie: bytes) -> typing.Dict[str, typing.Any]:
result: typing.Dict[str, typing.Any] = {"session": {}}
try:
data = self.signer.unsign(
cookie, max_age=self.max_age, return_timestamp=True
)
result["session"] = json.loads(b64decode(data[0]))

result["datetime"] = data[1] # DateTime obj
except (BadSignature, SignatureExpired):
return result
return result

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] not in ("http", "websocket"): # pragma: no cover
await self.app(scope, receive, send)
return

connection = HTTPConnection(scope)
initial_session_was_empty = True
update_session = True

if self.session_cookie in connection.cookies:
data = connection.cookies[self.session_cookie].encode("utf-8")
try:
data = self.signer.unsign(data, max_age=self.max_age)
scope["session"] = json.loads(b64decode(data))
initial_session_was_empty = False
except BadSignature:
data = self.decode_cookie(
connection.cookies[self.session_cookie].encode("utf-8")
) # noqa E501
if data["session"]:
scope["session"] = data["session"]
expiration = data["datetime"] + timedelta(seconds=self.max_age) # type: ignore[arg-type]

if self.auto_refresh_window:
now = datetime.now(timezone.utc)
# if the expiry date not inside of the expiry window, do not update.
if not (
now
>= (expiration - timedelta(seconds=self.auto_refresh_window))
and now <= expiration
): # noqa E501
update_session = False
elif self.persist_session:
update_session = False
else:
scope["session"] = {}
else:
scope["session"] = {}

async def send_wrapper(message: Message) -> None:
session_changed = False
if message["type"] == "http.response.start":
if scope["session"]:
# We have session data to persist.
if self.session_cookie in connection.cookies:
previous_session_data = self.decode_cookie(
connection.cookies[self.session_cookie].encode("utf-8")
) # noqa E501
if (
previous_session_data["session"] and scope["session"]
) and previous_session_data["session"] != scope["session"]: # noqa E501
session_changed = True

if scope["session"] and (update_session or session_changed):
# We have data that needs to be persisted or refreshed.
data = b64encode(json.dumps(scope["session"]).encode("utf-8"))
data = self.signer.sign(data)
headers = MutableHeaders(scope=message)
Expand All @@ -67,8 +108,7 @@ async def send_wrapper(message: Message) -> None:
security_flags=self.security_flags,
)
headers.append("Set-Cookie", header_value)
elif not initial_session_was_empty:
# The session has been cleared.
elif update_session and not scope["session"]:
headers = MutableHeaders(scope=message)
header_value = "{session_cookie}={data}; path={path}; {expires}{security_flags}".format( # noqa E501
session_cookie=self.session_cookie,
Expand Down
Loading