Skip to content

Commit

Permalink
feat: add captcha to login page (pypi#15701)
Browse files Browse the repository at this point in the history
  • Loading branch information
miketheman authored Apr 2, 2024
1 parent 96f48ba commit c577ab4
Show file tree
Hide file tree
Showing 7 changed files with 222 additions and 102 deletions.
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from warehouse.accounts import services as account_services
from warehouse.accounts.interfaces import ITokenService, IUserService
from warehouse.admin.flags import AdminFlag, AdminFlagValue
from warehouse.csp import CSPPolicy
from warehouse.email import services as email_services
from warehouse.email.interfaces import IEmailSender
from warehouse.macaroons import services as macaroon_services
Expand Down Expand Up @@ -154,6 +155,16 @@ def __init__(self):
def register_service(self, service_obj, iface=None, context=None, name=""):
self._services[iface][context][name] = service_obj

def register_service_factory(
self, config_or_request, iface=None, context=None, name=""
):
# TODO: This is a bit of a hack, but it's good enough for now.
# It doesn't implement the depth of logic that Pyramid's service
# factories do, but it's good enough for our tests.
# It's functionally equivalent to `register_service()`, but it
# Makes it clear that we're using a factory.
self._services[iface][context][name] = config_or_request

def find_service(self, iface=None, context=None, name=""):
return self._services[iface][context][name]

Expand Down Expand Up @@ -191,6 +202,8 @@ def pyramid_services(
activestate_oidc_service, IOIDCPublisherService, None, name="activestate"
)
services.register_service(macaroon_service, IMacaroonService, None, name="")
# CSP is a special case, as it's a service that's not registered globally.
services.register_service_factory(CSPPolicy({}), name="csp")

return services

Expand Down
53 changes: 50 additions & 3 deletions tests/unit/accounts/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,19 @@ def test_validate(self):
breach_service = pretend.stub(
check_password=pretend.call_recorder(lambda pw, tags: False)
)
captcha_service = pretend.stub(enabled=True, verify_response=lambda _: None)
form = forms.LoginForm(
MultiDict({"username": "user", "password": "password"}),
MultiDict(
{
"username": "user",
"password": "password",
"g_recaptcha_response": "success",
}
),
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
)

assert form.request is request
Expand All @@ -68,8 +76,12 @@ def test_validate_username_with_no_user(self):
find_userid=pretend.call_recorder(lambda userid: None)
)
breach_service = pretend.stub()
captcha_service = pretend.stub(enabled=False)
form = forms.LoginForm(
request=request, user_service=user_service, breach_service=breach_service
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
)
field = pretend.stub(data="my_username")

Expand All @@ -92,8 +104,12 @@ def test_validate_username_with_user(self, input_username, expected_username):
request = pretend.stub()
user_service = pretend.stub(find_userid=pretend.call_recorder(lambda userid: 1))
breach_service = pretend.stub()
captcha_service = pretend.stub(enabled=False)
form = forms.LoginForm(
request=request, user_service=user_service, breach_service=breach_service
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
)
field = pretend.stub(data=input_username)
form.validate_username(field)
Expand All @@ -111,11 +127,13 @@ def test_validate_password_no_user(self):
find_userid=pretend.call_recorder(lambda userid: None)
)
breach_service = pretend.stub()
captcha_service = pretend.stub(enabled=False)
form = forms.LoginForm(
formdata=MultiDict({"username": "my_username"}),
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
)
field = pretend.stub(data="password")

Expand All @@ -140,11 +158,13 @@ def test_validate_password_disabled_for_compromised_pw(self, db_session):
),
)
breach_service = pretend.stub(failure_message="Bad Password!")
captcha_service = pretend.stub(enabled=False)
form = forms.LoginForm(
formdata=MultiDict({"username": "my_username"}),
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
)
field = pretend.stub(data="pw")

Expand Down Expand Up @@ -174,11 +194,13 @@ def test_validate_password_ok(self):
breach_service = pretend.stub(
check_password=pretend.call_recorder(lambda pw, tags: False)
)
captcha_service = pretend.stub(enabled=False)
form = forms.LoginForm(
formdata=MultiDict({"username": "my_username"}),
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
check_password_metrics_tags=["bar"],
)
field = pretend.stub(data="pw")
Expand Down Expand Up @@ -216,11 +238,13 @@ def test_validate_password_notok(self, db_session):
is_disabled=pretend.call_recorder(lambda userid: (False, None)),
)
breach_service = pretend.stub()
captcha_service = pretend.stub(enabled=False)
form = forms.LoginForm(
formdata=MultiDict({"username": "my_username"}),
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
)
field = pretend.stub(data="pw")

Expand Down Expand Up @@ -257,11 +281,13 @@ def test_validate_password_too_many_failed(self):
is_disabled=pretend.call_recorder(lambda userid: (False, None)),
)
breach_service = pretend.stub()
captcha_service = pretend.stub(enabled=False)
form = forms.LoginForm(
formdata=MultiDict({"username": "my_username"}),
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
)
field = pretend.stub(data="pw")

Expand Down Expand Up @@ -297,12 +323,14 @@ def test_password_breached(self, monkeypatch):
breach_service = pretend.stub(
check_password=lambda pw, tags=None: True, failure_message="Bad Password!"
)
captcha_service = pretend.stub(enabled=False, verify_response=lambda _: None)

form = forms.LoginForm(
MultiDict({"password": "password"}),
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
)
assert not form.validate()
assert form.password.errors.pop() == "Bad Password!"
Expand Down Expand Up @@ -332,12 +360,14 @@ def test_validate_password_ok_ip_banned(self):
breach_service = pretend.stub(
check_password=pretend.call_recorder(lambda pw, tags: False)
)
captcha_service = pretend.stub(enabled=False)
form = forms.LoginForm(
formdata=MultiDict({"username": "my_username"}),
request=request,
user_service=user_service,
breach_service=breach_service,
check_password_metrics_tags=["bar"],
captcha_service=captcha_service,
)
field = pretend.stub(data="pw")

Expand Down Expand Up @@ -365,11 +395,13 @@ def test_validate_password_notok_ip_banned(self, db_session):
record_event=pretend.call_recorder(lambda *a, **kw: None),
)
breach_service = pretend.stub()
captcha_service = pretend.stub(enabled=False)
form = forms.LoginForm(
formdata=MultiDict({"username": "my_username"}),
request=request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
)
field = pretend.stub(data="pw")

Expand All @@ -380,6 +412,21 @@ def test_validate_password_notok_ip_banned(self, db_session):
assert user_service.is_disabled.calls == []
assert user_service.check_password.calls == []

@pytest.mark.parametrize("response_data", ["asd", None])
def test_recaptcha_error(self, response_data):
form = forms.LoginForm(
request=pretend.stub(),
formdata=MultiDict({"g_recaptcha_response": response_data}),
user_service=pretend.stub(),
captcha_service=pretend.stub(
verify_response=pretend.raiser(recaptcha.RecaptchaError),
enabled=True,
),
breach_service=pretend.stub(check_password=lambda pw, tags=None: False),
)
assert not form.validate()
assert form.g_recaptcha_response.errors.pop() == "Recaptcha error."


class TestRegistrationForm:
def test_validate(self):
Expand Down
40 changes: 34 additions & 6 deletions tests/unit/accounts/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,15 @@ class TestLogin:
def test_get_returns_form(self, pyramid_request, pyramid_services, next_url):
user_service = pretend.stub()
breach_service = pretend.stub()
captcha_service = pretend.stub(csp_policy={})

pyramid_services.register_service(user_service, IUserService, None)
pyramid_services.register_service(
breach_service, IPasswordBreachedService, None
)
pyramid_services.register_service(
captcha_service, ICaptchaService, name="captcha"
)

form_obj = pretend.stub()
form_class = pretend.call_recorder(lambda d, **kw: form_obj)
Expand All @@ -239,6 +243,7 @@ def test_get_returns_form(self, pyramid_request, pyramid_services, next_url):
request=pyramid_request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
check_password_metrics_tags=["method:auth", "auth_method:login_form"],
)
]
Expand All @@ -249,11 +254,15 @@ def test_post_invalid_returns_form(
):
user_service = pretend.stub()
breach_service = pretend.stub()
captcha_service = pretend.stub(csp_policy={})

pyramid_services.register_service(user_service, IUserService, None)
pyramid_services.register_service(
breach_service, IPasswordBreachedService, None
)
pyramid_services.register_service(
captcha_service, ICaptchaService, name="captcha"
)

pyramid_request.method = "POST"
if next_url is not None:
Expand All @@ -274,6 +283,7 @@ def test_post_invalid_returns_form(
request=pyramid_request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
check_password_metrics_tags=["method:auth", "auth_method:login_form"],
)
]
Expand All @@ -300,11 +310,15 @@ def test_post_validate_redirects(
get_password_timestamp=lambda userid: 0,
)
breach_service = pretend.stub(check_password=lambda password, tags=None: False)
captcha_service = pretend.stub(csp_policy={})

pyramid_services.register_service(user_service, IUserService, None)
pyramid_services.register_service(
breach_service, IPasswordBreachedService, None
)
pyramid_services.register_service(
captcha_service, ICaptchaService, name="captcha"
)

pyramid_request.method = "POST"
pyramid_request.session = pretend.stub(
Expand Down Expand Up @@ -351,6 +365,7 @@ def test_post_validate_redirects(
request=pyramid_request,
user_service=user_service,
breach_service=breach_service,
captcha_service=captcha_service,
check_password_metrics_tags=["method:auth", "auth_method:login_form"],
)
]
Expand Down Expand Up @@ -396,11 +411,15 @@ def test_post_validate_no_redirects(
get_password_timestamp=lambda userid: 0,
)
breach_service = pretend.stub(check_password=lambda password, tags=None: False)
captcha_service = pretend.stub(csp_policy={})

pyramid_services.register_service(user_service, IUserService, None)
pyramid_services.register_service(
breach_service, IPasswordBreachedService, None
)
pyramid_services.register_service(
captcha_service, ICaptchaService, name="captcha"
)

pyramid_request.method = "POST"
pyramid_request.POST["next"] = expected_next_url
Expand Down Expand Up @@ -448,7 +467,12 @@ def test_redirect_authenticated_user(self):

@pytest.mark.parametrize("redirect_url", ["test_redirect_url", None])
def test_two_factor_auth(
self, monkeypatch, pyramid_request, redirect_url, token_service
self,
monkeypatch,
pyramid_request,
pyramid_services,
redirect_url,
token_service,
):
token_service.dumps = lambda d: "fake_token"

Expand All @@ -464,12 +488,16 @@ def test_two_factor_auth(
)

breach_service = pretend.stub(check_password=lambda pw: False)
captcha_service = pretend.stub(csp_policy={})

pyramid_request.find_service = lambda interface, **kwargs: {
ITokenService: token_service,
IUserService: user_service,
IPasswordBreachedService: breach_service,
}[interface]
pyramid_services.register_service(
token_service, ITokenService, name="two_factor"
)
pyramid_services.register_service(user_service, IUserService)
pyramid_services.register_service(breach_service, IPasswordBreachedService)
pyramid_services.register_service(
captcha_service, ICaptchaService, name="captcha"
)

pyramid_request.method = "POST"
if redirect_url:
Expand Down
21 changes: 18 additions & 3 deletions warehouse/accounts/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import disposable_email_domains
import humanize
import markupsafe
import sentry_sdk
import wtforms
import wtforms.fields

Expand Down Expand Up @@ -366,17 +367,31 @@ def validate_g_recaptcha_response(self, field):
raise wtforms.validators.ValidationError("Recaptcha error.")
try:
self.captcha_service.verify_response(field.data)
except recaptcha.RecaptchaError:
# TODO: log error
except recaptcha.RecaptchaError as exc:
sentry_sdk.capture_exception(exc)
# don't want to provide the user with any detail
raise wtforms.validators.ValidationError("Recaptcha error.")


class LoginForm(PasswordMixin, UsernameMixin, forms.Form):
def __init__(self, *args, user_service, breach_service, **kwargs):
def __init__(self, *args, user_service, breach_service, captcha_service, **kwargs):
super().__init__(*args, **kwargs)
self.user_service = user_service
self.breach_service = breach_service
self.captcha_service = captcha_service

g_recaptcha_response = wtforms.StringField()

def validate_g_recaptcha_response(self, field):
# do required data validation here due to enabled flag being required
if self.captcha_service.enabled and not field.data:
raise wtforms.validators.ValidationError("Recaptcha error.")
try:
self.captcha_service.verify_response(field.data)
except recaptcha.RecaptchaError as exc:
sentry_sdk.capture_exception(exc)
# don't want to provide the user with any detail
raise wtforms.validators.ValidationError("Recaptcha error.")

def validate_password(self, field):
# Before we try to validate anything, first check to see if the IP is banned
Expand Down
Loading

0 comments on commit c577ab4

Please sign in to comment.