Skip to content

Commit

Permalink
refactor(account/stages): Internalize stage related logic
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Sep 20, 2024
1 parent 0812ed7 commit 7eda0e3
Show file tree
Hide file tree
Showing 9 changed files with 44 additions and 47 deletions.
3 changes: 2 additions & 1 deletion allauth/account/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.utils.translation import gettext, gettext_lazy as _, pgettext

from allauth.account.internal import flows
from allauth.account.internal.stagekit import LOGIN_SESSION_KEY
from allauth.account.stages import EmailVerificationStage
from allauth.core import context, ratelimit
from allauth.utils import get_username_max_length, set_form_field_order
Expand Down Expand Up @@ -384,7 +385,7 @@ def try_save(self, request):
resp = flows.signup.prevent_enumeration(request, email)
user = None
# Fake a login stage.
request.session[flows.login.LOGIN_SESSION_KEY] = EmailVerificationStage.key
request.session[LOGIN_SESSION_KEY] = EmailVerificationStage.key
else:
user = self.save(request)
resp = None
Expand Down
1 change: 0 additions & 1 deletion allauth/account/internal/flows/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from allauth.core.exceptions import ImmediateHttpResponse


LOGIN_SESSION_KEY = "account_login"
AUTHENTICATION_METHODS_SESSION_KEY = "account_authentication_methods"


Expand Down
6 changes: 3 additions & 3 deletions allauth/account/internal/flows/login_by_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
record_authentication,
)
from allauth.account.internal.flows.signup import send_unknown_account_mail
from allauth.account.internal.stagekit import clear_login
from allauth.account.internal.stagekit import clear_login, stash_login
from allauth.account.models import Login


Expand All @@ -26,7 +26,7 @@
def request_login_code(
request: HttpRequest, email: str, login: Optional[Login] = None
) -> None:
from allauth.account.utils import filter_users_by_email, stash_login
from allauth.account.utils import filter_users_by_email

initiated_by_user = login is None
adapter = get_adapter()
Expand Down Expand Up @@ -88,7 +88,7 @@ def get_pending_login(


def record_invalid_attempt(request, login: Login) -> bool:
from allauth.account.utils import stash_login, unstash_login
from allauth.account.internal.stagekit import stash_login, unstash_login

pending_login = login.state[LOGIN_CODE_STATE_KEY]
n = pending_login["failed_attempts"]
Expand Down
35 changes: 31 additions & 4 deletions allauth/account/internal/stagekit.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import time
from typing import Optional

from django.http import HttpResponseRedirect
from django.urls import reverse

from allauth.account import app_settings
from allauth.account.models import Login
from allauth.account.stages import LoginStage, LoginStageController


def get_pending_stage(request) -> Optional[LoginStage]:
from allauth.account.utils import unstash_login
LOGIN_SESSION_KEY = "account_login"


def get_pending_stage(request) -> Optional[LoginStage]:
stage = None
if not request.user.is_authenticated:
login = unstash_login(request, peek=True)
Expand All @@ -26,6 +30,29 @@ def redirect_to_pending_stage(request, stage: LoginStage):


def clear_login(request):
from allauth.account.internal.flows.login import LOGIN_SESSION_KEY

request.session.pop(LOGIN_SESSION_KEY, None)


def unstash_login(request, peek=False):
login = None
if peek:
data = request.session.get(LOGIN_SESSION_KEY)
else:
data = request.session.pop(LOGIN_SESSION_KEY, None)
if isinstance(data, dict):
try:
login = Login.deserialize(data)
except ValueError:
pass
else:
if time.time() - login.initiated_at > app_settings.LOGIN_TIMEOUT:
login = None
clear_login(request)
else:
request._account_login_accessed = True
return login


def stash_login(request, login):
request.session[LOGIN_SESSION_KEY] = login.serialize()
request._account_login_accessed = True
7 changes: 3 additions & 4 deletions allauth/account/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def handle(self):
return None, True

def exit(self):
from allauth.account.utils import resume_login
from allauth.account.internal.flows.login import resume_login

self.controller.set_handled(self.key)
return resume_login(self.request, self.login)
Expand All @@ -54,7 +54,7 @@ def __init__(self, request, login):

@classmethod
def enter(cls, request, stage_key):
from allauth.account.utils import unstash_login
from allauth.account.internal.stagekit import unstash_login

login = unstash_login(request, peek=True)
if not login:
Expand Down Expand Up @@ -99,8 +99,7 @@ def get_stages(self):
return stages

def handle(self):
from allauth.account.internal.stagekit import clear_login
from allauth.account.utils import stash_login
from allauth.account.internal.stagekit import clear_login, stash_login

stages = self.get_stages()
for stage in stages:
Expand Down
2 changes: 1 addition & 1 deletion allauth/account/tests/test_login_by_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pytest

from allauth.account.authentication import AUTHENTICATION_METHODS_SESSION_KEY
from allauth.account.internal.flows.login import LOGIN_SESSION_KEY
from allauth.account.internal.flows.login_by_code import LOGIN_CODE_STATE_KEY
from allauth.account.internal.stagekit import LOGIN_SESSION_KEY
from allauth.account.models import EmailAddress


Expand Down
32 changes: 1 addition & 31 deletions allauth/account/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import time
import unicodedata
from collections import OrderedDict
from typing import Optional
Expand All @@ -14,7 +13,7 @@

from allauth.account import app_settings, signals
from allauth.account.adapter import get_adapter
from allauth.account.internal import flows, stagekit
from allauth.account.internal import flows
from allauth.account.models import Login
from allauth.core.internal import httpkit
from allauth.utils import (
Expand Down Expand Up @@ -162,35 +161,6 @@ def perform_login(
return flows.login.perform_login(request, login)


def resume_login(request, login):
return flows.login.resume_login(request, login)


def unstash_login(request, peek=False):
login = None
if peek:
data = request.session.get(flows.login.LOGIN_SESSION_KEY)
else:
data = request.session.pop(flows.login.LOGIN_SESSION_KEY, None)
if isinstance(data, dict):
try:
login = Login.deserialize(data)
except ValueError:
pass
else:
if time.time() - login.initiated_at > app_settings.LOGIN_TIMEOUT:
login = None
stagekit.clear_login(request)
else:
request._account_login_accessed = True
return login


def stash_login(request, login):
request.session[flows.login.LOGIN_SESSION_KEY] = login.serialize()
request._account_login_accessed = True


def complete_signup(request, user, email_verification, success_url, signal_kwargs=None):
if signal_kwargs is None:
signal_kwargs = {}
Expand Down
3 changes: 2 additions & 1 deletion allauth/headless/base/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.authentication import get_authentication_records
from allauth.account.internal import flows
from allauth.account.internal.stagekit import LOGIN_SESSION_KEY
from allauth.headless.adapter import get_adapter
from allauth.headless.constants import Flow
from allauth.headless.internal import authkit
Expand Down Expand Up @@ -61,7 +62,7 @@ def _get_flows(self, request, user):
if stage:
stage_key = stage.key
else:
lsk = request.session.get(flows.login.LOGIN_SESSION_KEY)
lsk = request.session.get(LOGIN_SESSION_KEY)
if isinstance(lsk, str):
stage_key = lsk
if stage_key:
Expand Down
2 changes: 1 addition & 1 deletion allauth/mfa/totp/tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_totp_stage_expires(client, user_with_totp, user_password):
assert resp.status_code == 200
assertTemplateUsed(resp, "mfa/authenticate.html")
with patch(
"allauth.account.utils.time.time",
"allauth.account.internal.stagekit.time.time",
return_value=time.time() + 1.1 * app_settings.LOGIN_TIMEOUT,
):
resp = client.get(reverse("mfa_authenticate"))
Expand Down

0 comments on commit 7eda0e3

Please sign in to comment.