From 19bc554296264df934f674ea7f32d9f2b70b2c58 Mon Sep 17 00:00:00 2001 From: Raymond Penners Date: Wed, 8 Nov 2023 19:04:20 +0100 Subject: [PATCH] feat(mfa): Reauthenticate --- allauth/account/adapter.py | 26 +++++++ allauth/account/decorators.py | 10 +-- allauth/account/reauthentication.py | 3 +- .../account/tests/test_reauthentication.py | 33 +++++++++ allauth/account/views.py | 68 +++++++++++++++---- allauth/conftest.py | 11 ++- allauth/mfa/urls.py | 1 + allauth/mfa/views.py | 15 ++++ .../account/base_reauthenticate.html | 27 ++++++++ allauth/templates/account/reauthenticate.html | 18 ++--- allauth/templates/allauth/elements/hr.html | 1 + allauth/templates/mfa/reauthenticate.html | 24 +++++++ .../socialaccount/snippets/login.html | 3 +- 13 files changed, 203 insertions(+), 37 deletions(-) create mode 100644 allauth/account/tests/test_reauthentication.py create mode 100644 allauth/templates/account/base_reauthenticate.html create mode 100644 allauth/templates/allauth/elements/hr.html create mode 100644 allauth/templates/mfa/reauthenticate.html diff --git a/allauth/account/adapter.py b/allauth/account/adapter.py index 7fb96cb483..ab52f08ab2 100644 --- a/allauth/account/adapter.py +++ b/allauth/account/adapter.py @@ -703,6 +703,32 @@ def get_login_stages(self): ret.append("allauth.mfa.stages.AuthenticateStage") return ret + def get_reauthentication_methods(self, user): + """The order of the methods returned matters. The first method is the + default when using the `@reauthentication_required` decorator. + """ + ret = [] + if not user.is_authenticated: + return ret + if user.has_usable_password(): + ret.append( + { + "description": _("Use your password"), + "url": reverse("account_reauthenticate"), + } + ) + if allauth_app_settings.MFA_ENABLED: + from allauth.mfa.utils import is_mfa_enabled + + if is_mfa_enabled(user): + ret.append( + { + "description": _("Use your authenticator app"), + "url": reverse("mfa_reauthenticate"), + } + ) + return ret + def get_adapter(request=None): return import_attribute(app_settings.ADAPTER)(request) diff --git a/allauth/account/decorators.py b/allauth/account/decorators.py index 5eb3e0e248..0b9964f0e1 100644 --- a/allauth/account/decorators.py +++ b/allauth/account/decorators.py @@ -5,6 +5,7 @@ from django.shortcuts import render from django.urls import reverse +from allauth.account.adapter import get_adapter from allauth.account.models import EmailAddress from allauth.account.reauthentication import ( did_recently_authenticate, @@ -60,11 +61,10 @@ def _wrapper_view(request, *args, **kwargs): ) if ena and not pass_method: if request.user.is_anonymous or not did_recently_authenticate(request): - redirect_url = reverse( - "account_login" - if request.user.is_anonymous - else "account_reauthenticate" - ) + redirect_url = reverse("account_login") + methods = get_adapter().get_reauthentication_methods(request.user) + if methods: + redirect_url = methods[0]["url"] return suspend_request(request, redirect_url) return view_func(request, *args, **kwargs) diff --git a/allauth/account/reauthentication.py b/allauth/account/reauthentication.py index af2aaa7a69..f2df4e20e6 100644 --- a/allauth/account/reauthentication.py +++ b/allauth/account/reauthentication.py @@ -6,6 +6,7 @@ from django.utils.http import urlencode from allauth.account import app_settings +from allauth.account.adapter import get_adapter from allauth.account.utils import get_next_redirect_url from allauth.core.internal.http import deserialize_request, serialize_request from allauth.utils import import_callable @@ -60,7 +61,7 @@ def reauthenticate_then_callback(request, serialize_state, callback): def did_recently_authenticate(request): if request.user.is_anonymous: return False - if not request.user.has_usable_password(): + if not get_adapter().get_reauthentication_methods(request.user): # TODO: This user only has social accounts attached. Now, ideally, you # would want to reauthenticate over at the social account provider. For # now, this is not implemented. Although definitely suboptimal, this diff --git a/allauth/account/tests/test_reauthentication.py b/allauth/account/tests/test_reauthentication.py new file mode 100644 index 0000000000..a9a62a7420 --- /dev/null +++ b/allauth/account/tests/test_reauthentication.py @@ -0,0 +1,33 @@ +from django.urls import reverse + +import pytest + +from allauth.account.adapter import get_adapter + + +@pytest.mark.parametrize( + "with_totp,with_password,expected_method_urlnames", + [ + (False, True, ["account_reauthenticate"]), + (True, True, ["account_reauthenticate", "mfa_reauthenticate"]), + (True, False, ["mfa_reauthenticate"]), + ], +) +def test_user_with_mfa_only( + user_factory, with_totp, with_password, expected_method_urlnames, client +): + user = user_factory(with_totp=with_totp, password=None if with_password else "!") + assert user.has_usable_password() == with_password + client.force_login(user) + methods = get_adapter().get_reauthentication_methods(user) + assert len(methods) == len(expected_method_urlnames) + assert set([m["url"] for m in methods]) == set( + map(reverse, expected_method_urlnames) + ) + for urlname in ["account_reauthenticate", "mfa_reauthenticate"]: + resp = client.get(reverse(urlname) + "?next=/foo") + if urlname in expected_method_urlnames: + assert resp.status_code == 200 + else: + assert resp.status_code == 302 + assert "next=%2Ffoo" in resp["location"] diff --git a/allauth/account/views.py b/allauth/account/views.py index 2d08bfb457..83e48b8e39 100644 --- a/allauth/account/views.py +++ b/allauth/account/views.py @@ -2,6 +2,7 @@ from django.contrib.auth import REDIRECT_FIELD_NAME from django.contrib.auth.decorators import login_required from django.contrib.sites.shortcuts import get_current_site +from django.core.exceptions import PermissionDenied from django.core.validators import validate_email from django.forms import ValidationError from django.http import ( @@ -981,23 +982,37 @@ class EmailVerificationSentView(TemplateView): email_verification_sent = EmailVerificationSentView.as_view() -class ReauthenticateView(FormView): - form_class = ReauthenticateForm - template_name = "account/reauthenticate." + app_settings.TEMPLATE_EXTENSION +class BaseReauthenticateView(FormView): redirect_field_name = REDIRECT_FIELD_NAME def dispatch(self, request, *args, **kwargs): - r429 = ratelimit.consume_or_429( + resp = self._check_reauthentication_method_available(request) + if resp: + return resp + resp = self._check_ratelimit(request) + if resp: + return resp + return super().dispatch(request, *args, **kwargs) + + def _check_ratelimit(self, request): + return ratelimit.consume_or_429( self.request, action="reauthenticate", user=self.request.user, ) - if r429: - return r429 - return super().dispatch(request, *args, **kwargs) - def get_form_class(self): - return get_form_class(app_settings.FORMS, "reauthenticate", self.form_class) + def _check_reauthentication_method_available(self, request): + methods = get_adapter().get_reauthentication_methods(self.request.user) + if any([m["url"] == request.path for m in methods]): + # Method is available + return None + if not methods: + # Reauthentication not available + raise PermissionDenied("Reauthentication not available") + url = passthrough_next_redirect_url( + request, methods[0]["url"], self.redirect_field_name + ) + return HttpResponseRedirect(url) def get_success_url(self): url = get_next_redirect_url(self.request, self.redirect_field_name) @@ -1005,11 +1020,6 @@ def get_success_url(self): url = get_adapter(self.request).get_login_redirect_url(self.request) return url - def get_form_kwargs(self): - ret = super().get_form_kwargs() - ret["user"] = self.request.user - return ret - def form_valid(self, form): record_authentication(self.request, self.request.user) response = resume_request(self.request) @@ -1018,15 +1028,43 @@ def form_valid(self, form): return super().form_valid(form) def get_context_data(self, **kwargs): - ret = super(ReauthenticateView, self).get_context_data(**kwargs) + ret = super().get_context_data(**kwargs) redirect_field_value = get_request_param(self.request, self.redirect_field_name) ret.update( { "redirect_field_name": self.redirect_field_name, "redirect_field_value": redirect_field_value, + "reauthentication_alternatives": self.get_reauthentication_alternatives(), } ) return ret + def get_reauthentication_alternatives(self): + methods = get_adapter().get_reauthentication_methods(self.request.user) + alts = [] + for method in methods: + alt = dict(method) + if self.request.path == alt["url"]: + continue + alt["url"] = passthrough_next_redirect_url( + self.request, alt["url"], self.redirect_field_name + ) + alts.append(alt) + alts = sorted(alts, key=lambda alt: alt["description"]) + return alts + + +class ReauthenticateView(BaseReauthenticateView): + form_class = ReauthenticateForm + template_name = "account/reauthenticate." + app_settings.TEMPLATE_EXTENSION + + def get_form_class(self): + return get_form_class(app_settings.FORMS, "reauthenticate", self.form_class) + + def get_form_kwargs(self): + ret = super().get_form_kwargs() + ret["user"] = self.request.user + return ret + reauthenticate = login_required(ReauthenticateView.as_view()) diff --git a/allauth/conftest.py b/allauth/conftest.py index 59b1f0b9f9..a5d36bcb0a 100644 --- a/allauth/conftest.py +++ b/allauth/conftest.py @@ -37,6 +37,8 @@ def user_password(password_factory): @pytest.fixture def user_factory(email_factory, db, user_password): + from allauth.mfa import totp + def factory( email=None, username=None, @@ -45,6 +47,7 @@ def factory( email_verified=True, password=None, with_emailaddress=True, + with_totp=False, ): if not username: username = uuid.uuid4().hex @@ -54,7 +57,10 @@ def factory( User = get_user_model() user = User() - user.set_password(user_password if password is None else password) + if password == "!": + user.password = password + else: + user.set_password(user_password if password is None else password) user_username(user, username) user_email(user, email or "") if commit: @@ -63,7 +69,8 @@ def factory( EmailAddress.objects.create( user=user, email=email, verified=email_verified, primary=True ) - + if with_totp: + totp.TOTP.activate(user, totp.generate_totp_secret()) return user return factory diff --git a/allauth/mfa/urls.py b/allauth/mfa/urls.py index 04c99bde3d..3186a29d49 100644 --- a/allauth/mfa/urls.py +++ b/allauth/mfa/urls.py @@ -6,6 +6,7 @@ urlpatterns = [ path("", views.index, name="mfa_index"), path("authenticate/", views.authenticate, name="mfa_authenticate"), + path("reauthenticate/", views.reauthenticate, name="mfa_reauthenticate"), path( "totp/", include( diff --git a/allauth/mfa/views.py b/allauth/mfa/views.py index 5a0623e29b..a41732f1d3 100644 --- a/allauth/mfa/views.py +++ b/allauth/mfa/views.py @@ -14,6 +14,7 @@ from allauth.account.adapter import get_adapter as get_account_adapter from allauth.account.decorators import reauthentication_required from allauth.account.stages import LoginStageController +from allauth.account.views import BaseReauthenticateView from allauth.mfa import app_settings, signals, totp from allauth.mfa.adapter import get_adapter from allauth.mfa.forms import ActivateTOTPForm, AuthenticateForm @@ -48,6 +49,20 @@ def form_valid(self, form): authenticate = AuthenticateView.as_view() +@method_decorator(login_required, name="dispatch") +class ReauthenticateView(BaseReauthenticateView): + form_class = AuthenticateForm + template_name = "mfa/reauthenticate." + account_settings.TEMPLATE_EXTENSION + + def get_form_kwargs(self): + ret = super().get_form_kwargs() + ret["user"] = self.request.user + return ret + + +reauthenticate = ReauthenticateView.as_view() + + @method_decorator(login_required, name="dispatch") class IndexView(TemplateView): template_name = "mfa/index." + account_settings.TEMPLATE_EXTENSION diff --git a/allauth/templates/account/base_reauthenticate.html b/allauth/templates/account/base_reauthenticate.html new file mode 100644 index 0000000000..65d4c495e8 --- /dev/null +++ b/allauth/templates/account/base_reauthenticate.html @@ -0,0 +1,27 @@ +{% extends "account/base_entrance.html" %} +{% load allauth %} +{% load i18n %} +{% block head_title %} + {% trans "Confirm Access" %} +{% endblock head_title %} +{% block content %} + {% element h1 %} + {% trans "Confirm Access" %} + {% endelement %} +

{% blocktranslate %}Please reauthenticate to safeguard your account.{% endblocktranslate %}

+ {% block reauthenticate_content %}{% endblock %} + {% if reauthentication_alternatives %} + {% element hr %} + {% endelement %} + {% element h2 %} + {% translate "Alternative options" %} + {% endelement %} + + {% endif %} +{% endblock content %} diff --git a/allauth/templates/account/reauthenticate.html b/allauth/templates/account/reauthenticate.html index 16ae1fcc1e..38c22def50 100644 --- a/allauth/templates/account/reauthenticate.html +++ b/allauth/templates/account/reauthenticate.html @@ -1,21 +1,13 @@ -{% extends "account/base_manage.html" %} +{% extends "account/base_reauthenticate.html" %} {% load allauth %} {% load i18n %} -{% block head_title %} - {% trans "Confirm Access" %} -{% endblock head_title %} -{% block content %} - {% element h1 %} - {% trans "Confirm Access" %} - {% endelement %} -

- {% blocktranslate %}To safeguard the security of your account, please enter your password:{% endblocktranslate %} -

+{% block reauthenticate_content %} +

{% blocktranslate %}Enter your password:{% endblocktranslate %}

{% url 'account_reauthenticate' as action_url %} {% element form form=form method="post" action=action_url %} {% slot body %} {% csrf_token %} - {% element fields form=form %} + {% element fields form=form unlabeled=True %} {% endelement %} {% if redirect_field_value %} diff --git a/allauth/templates/mfa/reauthenticate.html b/allauth/templates/mfa/reauthenticate.html new file mode 100644 index 0000000000..749af87af1 --- /dev/null +++ b/allauth/templates/mfa/reauthenticate.html @@ -0,0 +1,24 @@ +{% extends "account/base_reauthenticate.html" %} +{% load i18n %} +{% load allauth %} +{% block reauthenticate_content %} +

{% blocktranslate %}Enter an authenticator code:{% endblocktranslate %}

+ {% url 'mfa_reauthenticate' as action_url %} + {% element form form=form method="post" action=action_url %} + {% slot body %} + {% csrf_token %} + {% element fields form=form unlabeled=True %} + {% endelement %} + {% if redirect_field_value %} + + {% endif %} + {% endslot %} + {% slot actions %} + {% element button type="submit" tags="primary,mfa,login" %} + {% trans "Confirm" %} + {% endelement %} + {% endslot %} + {% endelement %} +{% endblock %} diff --git a/allauth/templates/socialaccount/snippets/login.html b/allauth/templates/socialaccount/snippets/login.html index e112ef9389..9ec904e8a2 100644 --- a/allauth/templates/socialaccount/snippets/login.html +++ b/allauth/templates/socialaccount/snippets/login.html @@ -3,7 +3,8 @@ {% load socialaccount %} {% get_providers as socialaccount_providers %} {% if socialaccount_providers %} -
+ {% element hr %} + {% endelement %} {% element h2 %} {% translate "Or use a third-party" %} {% endelement %}