Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mfa): Reauthenticate #3517

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions allauth/account/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,32 @@ def get_login_stages(self):
ret.append("allauth.mfa.stages.AuthenticateStage")
return ret

def get_reauthentication_methods(self, user):
"""The order of the methods returned matters. The first method is the
default when using the `@reauthentication_required` decorator.
"""
ret = []
if not user.is_authenticated:
return ret
if user.has_usable_password():
ret.append(
{
"description": _("Use your password"),
"url": reverse("account_reauthenticate"),
}
)
if allauth_app_settings.MFA_ENABLED:
from allauth.mfa.utils import is_mfa_enabled

if is_mfa_enabled(user):
ret.append(
{
"description": _("Use your authenticator app"),
"url": reverse("mfa_reauthenticate"),
}
)
return ret


def get_adapter(request=None):
return import_attribute(app_settings.ADAPTER)(request)
10 changes: 5 additions & 5 deletions allauth/account/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.shortcuts import render
from django.urls import reverse

from allauth.account.adapter import get_adapter
from allauth.account.models import EmailAddress
from allauth.account.reauthentication import (
did_recently_authenticate,
Expand Down Expand Up @@ -60,11 +61,10 @@ def _wrapper_view(request, *args, **kwargs):
)
if ena and not pass_method:
if request.user.is_anonymous or not did_recently_authenticate(request):
redirect_url = reverse(
"account_login"
if request.user.is_anonymous
else "account_reauthenticate"
)
redirect_url = reverse("account_login")
methods = get_adapter().get_reauthentication_methods(request.user)
if methods:
redirect_url = methods[0]["url"]
return suspend_request(request, redirect_url)
return view_func(request, *args, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion allauth/account/reauthentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.utils.http import urlencode

from allauth.account import app_settings
from allauth.account.adapter import get_adapter
from allauth.account.utils import get_next_redirect_url
from allauth.core.internal.http import deserialize_request, serialize_request
from allauth.utils import import_callable
Expand Down Expand Up @@ -60,7 +61,7 @@ def reauthenticate_then_callback(request, serialize_state, callback):
def did_recently_authenticate(request):
if request.user.is_anonymous:
return False
if not request.user.has_usable_password():
if not get_adapter().get_reauthentication_methods(request.user):
# TODO: This user only has social accounts attached. Now, ideally, you
# would want to reauthenticate over at the social account provider. For
# now, this is not implemented. Although definitely suboptimal, this
Expand Down
33 changes: 33 additions & 0 deletions allauth/account/tests/test_reauthentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from django.urls import reverse

import pytest

from allauth.account.adapter import get_adapter


@pytest.mark.parametrize(
"with_totp,with_password,expected_method_urlnames",
[
(False, True, ["account_reauthenticate"]),
(True, True, ["account_reauthenticate", "mfa_reauthenticate"]),
(True, False, ["mfa_reauthenticate"]),
],
)
def test_user_with_mfa_only(
user_factory, with_totp, with_password, expected_method_urlnames, client
):
user = user_factory(with_totp=with_totp, password=None if with_password else "!")
assert user.has_usable_password() == with_password
client.force_login(user)
methods = get_adapter().get_reauthentication_methods(user)
assert len(methods) == len(expected_method_urlnames)
assert set([m["url"] for m in methods]) == set(
map(reverse, expected_method_urlnames)
)
for urlname in ["account_reauthenticate", "mfa_reauthenticate"]:
resp = client.get(reverse(urlname) + "?next=/foo")
if urlname in expected_method_urlnames:
assert resp.status_code == 200
else:
assert resp.status_code == 302
assert "next=%2Ffoo" in resp["location"]
68 changes: 53 additions & 15 deletions allauth/account/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from django.contrib.auth import REDIRECT_FIELD_NAME
from django.contrib.auth.decorators import login_required
from django.contrib.sites.shortcuts import get_current_site
from django.core.exceptions import PermissionDenied
from django.core.validators import validate_email
from django.forms import ValidationError
from django.http import (
Expand Down Expand Up @@ -981,35 +982,44 @@ class EmailVerificationSentView(TemplateView):
email_verification_sent = EmailVerificationSentView.as_view()


class ReauthenticateView(FormView):
form_class = ReauthenticateForm
template_name = "account/reauthenticate." + app_settings.TEMPLATE_EXTENSION
class BaseReauthenticateView(FormView):
redirect_field_name = REDIRECT_FIELD_NAME

def dispatch(self, request, *args, **kwargs):
r429 = ratelimit.consume_or_429(
resp = self._check_reauthentication_method_available(request)
if resp:
return resp
resp = self._check_ratelimit(request)
if resp:
return resp
return super().dispatch(request, *args, **kwargs)

def _check_ratelimit(self, request):
return ratelimit.consume_or_429(
self.request,
action="reauthenticate",
user=self.request.user,
)
if r429:
return r429
return super().dispatch(request, *args, **kwargs)

def get_form_class(self):
return get_form_class(app_settings.FORMS, "reauthenticate", self.form_class)
def _check_reauthentication_method_available(self, request):
methods = get_adapter().get_reauthentication_methods(self.request.user)
if any([m["url"] == request.path for m in methods]):
# Method is available
return None
if not methods:
# Reauthentication not available
raise PermissionDenied("Reauthentication not available")
url = passthrough_next_redirect_url(
request, methods[0]["url"], self.redirect_field_name
)
return HttpResponseRedirect(url)

def get_success_url(self):
url = get_next_redirect_url(self.request, self.redirect_field_name)
if not url:
url = get_adapter(self.request).get_login_redirect_url(self.request)
return url

def get_form_kwargs(self):
ret = super().get_form_kwargs()
ret["user"] = self.request.user
return ret

def form_valid(self, form):
record_authentication(self.request, self.request.user)
response = resume_request(self.request)
Expand All @@ -1018,15 +1028,43 @@ def form_valid(self, form):
return super().form_valid(form)

def get_context_data(self, **kwargs):
ret = super(ReauthenticateView, self).get_context_data(**kwargs)
ret = super().get_context_data(**kwargs)
redirect_field_value = get_request_param(self.request, self.redirect_field_name)
ret.update(
{
"redirect_field_name": self.redirect_field_name,
"redirect_field_value": redirect_field_value,
"reauthentication_alternatives": self.get_reauthentication_alternatives(),
}
)
return ret

def get_reauthentication_alternatives(self):
methods = get_adapter().get_reauthentication_methods(self.request.user)
alts = []
for method in methods:
alt = dict(method)
if self.request.path == alt["url"]:
continue
alt["url"] = passthrough_next_redirect_url(
self.request, alt["url"], self.redirect_field_name
)
alts.append(alt)
alts = sorted(alts, key=lambda alt: alt["description"])
return alts


class ReauthenticateView(BaseReauthenticateView):
form_class = ReauthenticateForm
template_name = "account/reauthenticate." + app_settings.TEMPLATE_EXTENSION

def get_form_class(self):
return get_form_class(app_settings.FORMS, "reauthenticate", self.form_class)

def get_form_kwargs(self):
ret = super().get_form_kwargs()
ret["user"] = self.request.user
return ret


reauthenticate = login_required(ReauthenticateView.as_view())
11 changes: 9 additions & 2 deletions allauth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def user_password(password_factory):

@pytest.fixture
def user_factory(email_factory, db, user_password):
from allauth.mfa import totp

def factory(
email=None,
username=None,
Expand All @@ -45,6 +47,7 @@ def factory(
email_verified=True,
password=None,
with_emailaddress=True,
with_totp=False,
):
if not username:
username = uuid.uuid4().hex
Expand All @@ -54,7 +57,10 @@ def factory(

User = get_user_model()
user = User()
user.set_password(user_password if password is None else password)
if password == "!":
user.password = password
else:
user.set_password(user_password if password is None else password)
user_username(user, username)
user_email(user, email or "")
if commit:
Expand All @@ -63,7 +69,8 @@ def factory(
EmailAddress.objects.create(
user=user, email=email, verified=email_verified, primary=True
)

if with_totp:
totp.TOTP.activate(user, totp.generate_totp_secret())
return user

return factory
Expand Down
1 change: 1 addition & 0 deletions allauth/mfa/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
urlpatterns = [
path("", views.index, name="mfa_index"),
path("authenticate/", views.authenticate, name="mfa_authenticate"),
path("reauthenticate/", views.reauthenticate, name="mfa_reauthenticate"),
path(
"totp/",
include(
Expand Down
15 changes: 15 additions & 0 deletions allauth/mfa/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from allauth.account.adapter import get_adapter as get_account_adapter
from allauth.account.decorators import reauthentication_required
from allauth.account.stages import LoginStageController
from allauth.account.views import BaseReauthenticateView
from allauth.mfa import app_settings, signals, totp
from allauth.mfa.adapter import get_adapter
from allauth.mfa.forms import ActivateTOTPForm, AuthenticateForm
Expand Down Expand Up @@ -48,6 +49,20 @@ def form_valid(self, form):
authenticate = AuthenticateView.as_view()


@method_decorator(login_required, name="dispatch")
class ReauthenticateView(BaseReauthenticateView):
form_class = AuthenticateForm
template_name = "mfa/reauthenticate." + account_settings.TEMPLATE_EXTENSION

def get_form_kwargs(self):
ret = super().get_form_kwargs()
ret["user"] = self.request.user
return ret


reauthenticate = ReauthenticateView.as_view()


@method_decorator(login_required, name="dispatch")
class IndexView(TemplateView):
template_name = "mfa/index." + account_settings.TEMPLATE_EXTENSION
Expand Down
27 changes: 27 additions & 0 deletions allauth/templates/account/base_reauthenticate.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{% extends "account/base_entrance.html" %}
{% load allauth %}
{% load i18n %}
{% block head_title %}
{% trans "Confirm Access" %}
{% endblock head_title %}
{% block content %}
{% element h1 %}
{% trans "Confirm Access" %}
{% endelement %}
<p>{% blocktranslate %}Please reauthenticate to safeguard your account.{% endblocktranslate %}</p>
{% block reauthenticate_content %}{% endblock %}
{% if reauthentication_alternatives %}
{% element hr %}
{% endelement %}
{% element h2 %}
{% translate "Alternative options" %}
{% endelement %}
<ul>
{% for alt in reauthentication_alternatives %}
<li>
<a href="{{ alt.url }}">{{ alt.description }}</a>
</li>
{% endfor %}
</ul>
{% endif %}
{% endblock content %}
18 changes: 5 additions & 13 deletions allauth/templates/account/reauthenticate.html
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
{% extends "account/base_manage.html" %}
{% extends "account/base_reauthenticate.html" %}
{% load allauth %}
{% load i18n %}
{% block head_title %}
{% trans "Confirm Access" %}
{% endblock head_title %}
{% block content %}
{% element h1 %}
{% trans "Confirm Access" %}
{% endelement %}
<p>
{% blocktranslate %}To safeguard the security of your account, please enter your password:{% endblocktranslate %}
</p>
{% block reauthenticate_content %}
<p>{% blocktranslate %}Enter your password:{% endblocktranslate %}</p>
{% url 'account_reauthenticate' as action_url %}
{% element form form=form method="post" action=action_url %}
{% slot body %}
{% csrf_token %}
{% element fields form=form %}
{% element fields form=form unlabeled=True %}
{% endelement %}
{% if redirect_field_value %}
<input type="hidden"
Expand All @@ -29,4 +21,4 @@
{% endelement %}
{% endslot %}
{% endelement %}
{% endblock content %}
{% endblock %}
1 change: 1 addition & 0 deletions allauth/templates/allauth/elements/hr.html
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
<hr>
24 changes: 24 additions & 0 deletions allauth/templates/mfa/reauthenticate.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
{% extends "account/base_reauthenticate.html" %}
{% load i18n %}
{% load allauth %}
{% block reauthenticate_content %}
<p>{% blocktranslate %}Enter an authenticator code:{% endblocktranslate %}</p>
{% url 'mfa_reauthenticate' as action_url %}
{% element form form=form method="post" action=action_url %}
{% slot body %}
{% csrf_token %}
{% element fields form=form unlabeled=True %}
{% endelement %}
{% if redirect_field_value %}
<input type="hidden"
name="{{ redirect_field_name }}"
value="{{ redirect_field_value }}" />
{% endif %}
{% endslot %}
{% slot actions %}
{% element button type="submit" tags="primary,mfa,login" %}
{% trans "Confirm" %}
{% endelement %}
{% endslot %}
{% endelement %}
{% endblock %}
3 changes: 2 additions & 1 deletion allauth/templates/socialaccount/snippets/login.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
{% load socialaccount %}
{% get_providers as socialaccount_providers %}
{% if socialaccount_providers %}
<hr>
{% element hr %}
{% endelement %}
{% element h2 %}
{% translate "Or use a third-party" %}
{% endelement %}
Expand Down