-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add a custom google backend for sso (#27082)
- Loading branch information
1 parent
c3b1949
commit ea86a3c
Showing
4 changed files
with
126 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,12 +12,14 @@ | |
from freezegun.api import freeze_time | ||
from rest_framework import status | ||
from social_core.exceptions import AuthFailed, AuthMissingParameter | ||
from social_django.models import UserSocialAuth | ||
|
||
from ee.api.test.base import APILicensedTest | ||
from ee.models.license import License | ||
from posthog.constants import AvailableFeature | ||
from posthog.models import OrganizationMembership, User | ||
from posthog.models.organization_domain import OrganizationDomain | ||
from ee.api.authentication import CustomGoogleOAuth2 | ||
|
||
SAML_MOCK_SETTINGS = { | ||
"SOCIAL_AUTH_SAML_SECURITY_CONFIG": { | ||
|
@@ -707,3 +709,62 @@ def test_xmlsec_and_lxml(self): | |
|
||
assert "1.3.13" == xmlsec.__version__ | ||
assert "4.9.4" == lxml.__version__ | ||
|
||
|
||
class TestCustomGoogleOAuth2(APILicensedTest): | ||
def setUp(self): | ||
super().setUp() | ||
self.google_oauth = CustomGoogleOAuth2() | ||
self.details = {"email": "[email protected]"} | ||
self.sub = "google-oauth2|123456789" | ||
|
||
def test_get_user_id_existing_user_with_sub(self): | ||
"""Test that a user with sub as uid continues using that sub.""" | ||
# Create user with sub as uid | ||
UserSocialAuth.objects.create(provider="google-oauth2", uid=self.sub, user=self.user) | ||
|
||
response = {"email": "[email protected]", "sub": self.sub} | ||
|
||
uid = self.google_oauth.get_user_id(self.details, response) | ||
|
||
self.assertEqual(uid, self.sub) | ||
# Verify no migration occurred (count should be 1) | ||
self.assertEqual(UserSocialAuth.objects.filter(provider="google-oauth2").count(), 1) | ||
# Verify uid is still sub | ||
self.assertEqual(UserSocialAuth.objects.get(provider="google-oauth2").uid, self.sub) | ||
|
||
def test_get_user_id_migrates_email_to_sub(self): | ||
"""Test that a user with email as uid gets migrated to using sub.""" | ||
# Create user with email as uid (legacy format) | ||
social_auth = UserSocialAuth.objects.create(provider="google-oauth2", uid="[email protected]", user=self.user) | ||
|
||
response = {"email": "[email protected]", "sub": self.sub} | ||
|
||
uid = self.google_oauth.get_user_id(self.details, response) | ||
|
||
self.assertEqual(uid, self.sub) | ||
# Verify the uid was updated | ||
social_auth.refresh_from_db() | ||
self.assertEqual(social_auth.uid, self.sub) | ||
|
||
def test_get_user_id_new_user_uses_sub(self): | ||
"""Test that a new user gets sub as uid.""" | ||
response = {"email": "[email protected]", "sub": self.sub} | ||
|
||
uid = self.google_oauth.get_user_id(self.details, response) | ||
|
||
self.assertEqual(uid, self.sub) | ||
# Verify no UserSocialAuth objects were created | ||
self.assertEqual(UserSocialAuth.objects.filter(provider="google-oauth2").count(), 0) | ||
|
||
def test_get_user_id_missing_sub_raises_error(self): | ||
"""Test that missing sub in response raises ValueError.""" | ||
response = { | ||
"email": "[email protected]", | ||
# no sub provided | ||
} | ||
|
||
with self.assertRaises(ValueError) as e: | ||
self.google_oauth.get_user_id(self.details, response) | ||
|
||
self.assertEqual(str(e.exception), "Google OAuth response missing 'sub' claim") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -655,6 +655,7 @@ def run_test_for_allowed_domain( | |
mock_request.return_value.json.return_value = { | ||
"access_token": "123", | ||
"email": "[email protected]", | ||
"sub": "123", | ||
} | ||
|
||
response = self.client.get(url, follow=True) | ||
|
@@ -789,6 +790,7 @@ def test_social_signup_with_allowed_domain_on_cloud_reverse(self, mock_sso_provi | |
mock_request.return_value.json.return_value = { | ||
"access_token": "123", | ||
"email": "[email protected]", | ||
"sub": "123", | ||
} | ||
|
||
response = self.client.get(url, follow=True) | ||
|
@@ -828,6 +830,7 @@ def test_cannot_social_signup_with_allowed_but_jit_provisioning_disabled(self, m | |
mock_request.return_value.json.return_value = { | ||
"access_token": "123", | ||
"email": "[email protected]", | ||
"sub": "123", | ||
} | ||
|
||
response = self.client.get(url, follow=True) | ||
|
@@ -857,6 +860,7 @@ def test_cannot_social_signup_with_allowed_but_unverified_domain(self, mock_sso_ | |
mock_request.return_value.json.return_value = { | ||
"access_token": "123", | ||
"email": "[email protected]", | ||
"sub": "123", | ||
} | ||
|
||
response = self.client.get(url, follow=True) | ||
|
@@ -886,6 +890,7 @@ def test_api_cannot_use_allow_list_for_different_domain(self, mock_sso_providers | |
mock_request.return_value.json.return_value = { | ||
"access_token": "123", | ||
"email": "[email protected]", | ||
"sub": "123", | ||
} # note evil.com | ||
|
||
response = self.client.get(url, follow=True) | ||
|
@@ -911,6 +916,7 @@ def test_social_signup_to_existing_org_without_allowed_domain_on_cloud(self, moc | |
mock_request.return_value.json.return_value = { | ||
"access_token": "123", | ||
"email": "[email protected]", | ||
"sub": "123", | ||
} | ||
response = self.client.get(url, follow=True) | ||
|
||
|