Skip to content

Commit

Permalink
feat(mfa): Reauthenticate
Browse files Browse the repository at this point in the history
  • Loading branch information
pennersr committed Nov 8, 2023
1 parent 108f3cd commit 19bc554
Show file tree
Hide file tree
Showing 13 changed files with 203 additions and 37 deletions.
26 changes: 26 additions & 0 deletions allauth/account/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,32 @@ def get_login_stages(self):
ret.append("allauth.mfa.stages.AuthenticateStage")
return ret

def get_reauthentication_methods(self, user):
"""The order of the methods returned matters. The first method is the
default when using the `@reauthentication_required` decorator.
"""
ret = []
if not user.is_authenticated:
return ret
if user.has_usable_password():
ret.append(
{
"description": _("Use your password"),
"url": reverse("account_reauthenticate"),
}
)
if allauth_app_settings.MFA_ENABLED:
from allauth.mfa.utils import is_mfa_enabled

if is_mfa_enabled(user):
ret.append(
{
"description": _("Use your authenticator app"),
"url": reverse("mfa_reauthenticate"),
}
)
return ret


def get_adapter(request=None):
return import_attribute(app_settings.ADAPTER)(request)
10 changes: 5 additions & 5 deletions allauth/account/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.shortcuts import render
from django.urls import reverse

from allauth.account.adapter import get_adapter
from allauth.account.models import EmailAddress
from allauth.account.reauthentication import (
did_recently_authenticate,
Expand Down Expand Up @@ -60,11 +61,10 @@ def _wrapper_view(request, *args, **kwargs):
)
if ena and not pass_method:
if request.user.is_anonymous or not did_recently_authenticate(request):
redirect_url = reverse(
"account_login"
if request.user.is_anonymous
else "account_reauthenticate"
)
redirect_url = reverse("account_login")
methods = get_adapter().get_reauthentication_methods(request.user)
if methods:
redirect_url = methods[0]["url"]
return suspend_request(request, redirect_url)
return view_func(request, *args, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion allauth/account/reauthentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.utils.http import urlencode

from allauth.account import app_settings
from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url
from allauth.core.internal.http import deserialize_request, serialize_request
from allauth.utils import import_callable
Expand Down Expand Up @@ -60,7 +61,7 @@ def reauthenticate_then_callback(request, serialize_state, callback):
def did_recently_authenticate(request):
if request.user.is_anonymous:
return False
if not request.user.has_usable_password():
if not get_adapter().get_reauthentication_methods(request.user):
# TODO: This user only has social accounts attached. Now, ideally, you
# would want to reauthenticate over at the social account provider. For
# now, this is not implemented. Although definitely suboptimal, this
Expand Down
33 changes: 33 additions & 0 deletions allauth/account/tests/test_reauthentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from django.urls import reverse

import pytest

from allauth.account.adapter import get_adapter


@pytest.mark.parametrize(
"with_totp,with_password,expected_method_urlnames",
[
(False, True, ["account_reauthenticate"]),
(True, True, ["account_reauthenticate", "mfa_reauthenticate"]),
(True, False, ["mfa_reauthenticate"]),
],
)
def test_user_with_mfa_only(
user_factory, with_totp, with_password, expected_method_urlnames, client
):
user = user_factory(with_totp=with_totp, password=None if with_password else "!")
assert user.has_usable_password() == with_password
client.force_login(user)
methods = get_adapter().get_reauthentication_methods(user)
assert len(methods) == len(expected_method_urlnames)
assert set([m["url"] for m in methods]) == set(
map(reverse, expected_method_urlnames)
)
for urlname in ["account_reauthenticate", "mfa_reauthenticate"]:
resp = client.get(reverse(urlname) + "?next=/foo")
if urlname in expected_method_urlnames:
assert resp.status_code == 200
else:
assert resp.status_code == 302
assert "next=%2Ffoo" in resp["location"]
68 changes: 53 additions & 15 deletions allauth/account/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from django.contrib.auth import REDIRECT_FIELD_NAME
from django.contrib.auth.decorators import login_required
from django.contrib.sites.shortcuts import get_current_site
from django.core.exceptions import PermissionDenied
from django.core.validators import validate_email
from django.forms import ValidationError
from django.http import (
Expand Down Expand Up @@ -981,35 +982,44 @@ class EmailVerificationSentView(TemplateView):
email_verification_sent = EmailVerificationSentView.as_view()


class ReauthenticateView(FormView):
form_class = ReauthenticateForm
template_name = "account/reauthenticate." + app_settings.TEMPLATE_EXTENSION
class BaseReauthenticateView(FormView):
redirect_field_name = REDIRECT_FIELD_NAME

def dispatch(self, request, *args, **kwargs):
r429 = ratelimit.consume_or_429(
resp = self._check_reauthentication_method_available(request)
if resp:
return resp
resp = self._check_ratelimit(request)
if resp:
return resp
return super().dispatch(request, *args, **kwargs)

def _check_ratelimit(self, request):
return ratelimit.consume_or_429(
self.request,
action="reauthenticate",
user=self.request.user,
)
if r429:
return r429
return super().dispatch(request, *args, **kwargs)

def get_form_class(self):
return get_form_class(app_settings.FORMS, "reauthenticate", self.form_class)
def _check_reauthentication_method_available(self, request):
methods = get_adapter().get_reauthentication_methods(self.request.user)
if any([m["url"] == request.path for m in methods]):
# Method is available
return None
if not methods:
# Reauthentication not available
raise PermissionDenied("Reauthentication not available")
url = passthrough_next_redirect_url(
request, methods[0]["url"], self.redirect_field_name
)
return HttpResponseRedirect(url)

def get_success_url(self):
url = get_next_redirect_url(self.request, self.redirect_field_name)
if not url:
url = get_adapter(self.request).get_login_redirect_url(self.request)
return url

def get_form_kwargs(self):
ret = super().get_form_kwargs()
ret["user"] = self.request.user
return ret

def form_valid(self, form):
record_authentication(self.request, self.request.user)
response = resume_request(self.request)
Expand All @@ -1018,15 +1028,43 @@ def form_valid(self, form):
return super().form_valid(form)

def get_context_data(self, **kwargs):
ret = super(ReauthenticateView, self).get_context_data(**kwargs)
ret = super().get_context_data(**kwargs)
redirect_field_value = get_request_param(self.request, self.redirect_field_name)
ret.update(
{
"redirect_field_name": self.redirect_field_name,
"redirect_field_value": redirect_field_value,
"reauthentication_alternatives": self.get_reauthentication_alternatives(),
}
)
return ret

def get_reauthentication_alternatives(self):
methods = get_adapter().get_reauthentication_methods(self.request.user)
alts = []
for method in methods:
alt = dict(method)
if self.request.path == alt["url"]:
continue
alt["url"] = passthrough_next_redirect_url(
self.request, alt["url"], self.redirect_field_name
)
alts.append(alt)
alts = sorted(alts, key=lambda alt: alt["description"])
return alts


class ReauthenticateView(BaseReauthenticateView):
form_class = ReauthenticateForm
template_name = "account/reauthenticate." + app_settings.TEMPLATE_EXTENSION

def get_form_class(self):
return get_form_class(app_settings.FORMS, "reauthenticate", self.form_class)

def get_form_kwargs(self):
ret = super().get_form_kwargs()
ret["user"] = self.request.user
return ret


reauthenticate = login_required(ReauthenticateView.as_view())
11 changes: 9 additions & 2 deletions allauth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def user_password(password_factory):

@pytest.fixture
def user_factory(email_factory, db, user_password):
from allauth.mfa import totp

def factory(
email=None,
username=None,
Expand All @@ -45,6 +47,7 @@ def factory(
email_verified=True,
password=None,
with_emailaddress=True,
with_totp=False,
):
if not username:
username = uuid.uuid4().hex
Expand All @@ -54,7 +57,10 @@ def factory(

User = get_user_model()
user = User()
user.set_password(user_password if password is None else password)
if password == "!":
user.password = password
else:
user.set_password(user_password if password is None else password)
user_username(user, username)
user_email(user, email or "")
if commit:
Expand All @@ -63,7 +69,8 @@ def factory(
EmailAddress.objects.create(
user=user, email=email, verified=email_verified, primary=True
)

if with_totp:
totp.TOTP.activate(user, totp.generate_totp_secret())
return user

return factory
Expand Down
1 change: 1 addition & 0 deletions allauth/mfa/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
urlpatterns = [
path("", views.index, name="mfa_index"),
path("authenticate/", views.authenticate, name="mfa_authenticate"),
path("reauthenticate/", views.reauthenticate, name="mfa_reauthenticate"),
path(
"totp/",
include(
Expand Down
15 changes: 15 additions & 0 deletions allauth/mfa/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.decorators import reauthentication_required
from allauth.account.stages import LoginStageController
from allauth.account.views import BaseReauthenticateView
from allauth.mfa import app_settings, signals, totp
from allauth.mfa.adapter import get_adapter
from allauth.mfa.forms import ActivateTOTPForm, AuthenticateForm
Expand Down Expand Up @@ -48,6 +49,20 @@ def form_valid(self, form):
authenticate = AuthenticateView.as_view()


@method_decorator(login_required, name="dispatch")
class ReauthenticateView(BaseReauthenticateView):
form_class = AuthenticateForm
template_name = "mfa/reauthenticate." + account_settings.TEMPLATE_EXTENSION

def get_form_kwargs(self):
ret = super().get_form_kwargs()
ret["user"] = self.request.user
return ret


reauthenticate = ReauthenticateView.as_view()


@method_decorator(login_required, name="dispatch")
class IndexView(TemplateView):
template_name = "mfa/index." + account_settings.TEMPLATE_EXTENSION
Expand Down
27 changes: 27 additions & 0 deletions allauth/templates/account/base_reauthenticate.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{% extends "account/base_entrance.html" %}
{% load allauth %}
{% load i18n %}
{% block head_title %}
{% trans "Confirm Access" %}
{% endblock head_title %}
{% block content %}
{% element h1 %}
{% trans "Confirm Access" %}
{% endelement %}
<p>{% blocktranslate %}Please reauthenticate to safeguard your account.{% endblocktranslate %}</p>
{% block reauthenticate_content %}{% endblock %}
{% if reauthentication_alternatives %}
{% element hr %}
{% endelement %}
{% element h2 %}
{% translate "Alternative options" %}
{% endelement %}
<ul>
{% for alt in reauthentication_alternatives %}
<li>
<a href="{{ alt.url }}">{{ alt.description }}</a>
</li>
{% endfor %}
</ul>
{% endif %}
{% endblock content %}
18 changes: 5 additions & 13 deletions allauth/templates/account/reauthenticate.html
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
{% extends "account/base_manage.html" %}
{% extends "account/base_reauthenticate.html" %}
{% load allauth %}
{% load i18n %}
{% block head_title %}
{% trans "Confirm Access" %}
{% endblock head_title %}
{% block content %}
{% element h1 %}
{% trans "Confirm Access" %}
{% endelement %}
<p>
{% blocktranslate %}To safeguard the security of your account, please enter your password:{% endblocktranslate %}
</p>
{% block reauthenticate_content %}
<p>{% blocktranslate %}Enter your password:{% endblocktranslate %}</p>
{% url 'account_reauthenticate' as action_url %}
{% element form form=form method="post" action=action_url %}
{% slot body %}
{% csrf_token %}
{% element fields form=form %}
{% element fields form=form unlabeled=True %}
{% endelement %}
{% if redirect_field_value %}
<input type="hidden"
Expand All @@ -29,4 +21,4 @@
{% endelement %}
{% endslot %}
{% endelement %}
{% endblock content %}
{% endblock %}
1 change: 1 addition & 0 deletions allauth/templates/allauth/elements/hr.html
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
<hr>
24 changes: 24 additions & 0 deletions allauth/templates/mfa/reauthenticate.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{% extends "account/base_reauthenticate.html" %}
{% load i18n %}
{% load allauth %}
{% block reauthenticate_content %}
<p>{% blocktranslate %}Enter an authenticator code:{% endblocktranslate %}</p>
{% url 'mfa_reauthenticate' as action_url %}
{% element form form=form method="post" action=action_url %}
{% slot body %}
{% csrf_token %}
{% element fields form=form unlabeled=True %}
{% endelement %}
{% if redirect_field_value %}
<input type="hidden"
name="{{ redirect_field_name }}"
value="{{ redirect_field_value }}" />
{% endif %}
{% endslot %}
{% slot actions %}
{% element button type="submit" tags="primary,mfa,login" %}
{% trans "Confirm" %}
{% endelement %}
{% endslot %}
{% endelement %}
{% endblock %}
3 changes: 2 additions & 1 deletion allauth/templates/socialaccount/snippets/login.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
{% load socialaccount %}
{% get_providers as socialaccount_providers %}
{% if socialaccount_providers %}
<hr>
{% element hr %}
{% endelement %}
{% element h2 %}
{% translate "Or use a third-party" %}
{% endelement %}
Expand Down

0 comments on commit 19bc554

Please sign in to comment.