From 9614a07517853125ca9ff3a01fdc469db6e0409d Mon Sep 17 00:00:00 2001 From: shadinaif Date: Fri, 28 Feb 2025 20:39:23 +0300 Subject: [PATCH 1/2] chore: add django-config-models to test requirements --- requirements/test-constraints-redwood.txt | 1 + requirements/test-constraints-sumac.txt | 1 + requirements/test.in | 1 + 3 files changed, 3 insertions(+) diff --git a/requirements/test-constraints-redwood.txt b/requirements/test-constraints-redwood.txt index eb8436c..1230545 100644 --- a/requirements/test-constraints-redwood.txt +++ b/requirements/test-constraints-redwood.txt @@ -9,6 +9,7 @@ eox-tenant Date: Thu, 27 Feb 2025 09:39:33 +0300 Subject: [PATCH 2/2] feat: use SSO term instead of Nafath And allow fetching user data when multiple SSO configs are available for the same site --- .../dashboard/serializers.py | 63 ++++++++++++++----- .../helpers/settings/common_production.py | 20 +++--- futurex_openedx_extensions/helpers/tenants.py | 26 +++++--- .../fake_models/models.py | 6 +- test_utils/test_settings_common.py | 14 ++++- tests/conftest.py | 9 +++ tests/test_dashboard/test_serializers.py | 50 +++++++++------ tests/test_helpers/test_apps.py | 6 ++ tests/test_helpers/test_tenants.py | 48 ++++++++------ 9 files changed, 163 insertions(+), 79 deletions(-) diff --git a/futurex_openedx_extensions/dashboard/serializers.py b/futurex_openedx_extensions/dashboard/serializers.py index b6ec60c..ff9a2f6 100644 --- a/futurex_openedx_extensions/dashboard/serializers.py +++ b/futurex_openedx_extensions/dashboard/serializers.py @@ -2,7 +2,7 @@ from __future__ import annotations import re -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Tuple from common.djangoapps.student.models import CourseEnrollment from django.conf import settings @@ -44,8 +44,8 @@ ) from futurex_openedx_extensions.helpers.tenants import ( get_all_tenants_info, - get_nafath_sites, get_org_to_tenant_map, + get_sso_sites, get_tenants_by_org, ) @@ -421,14 +421,14 @@ class LearnerEnrollmentSerializer( ): # pylint: disable=too-many-ancestors """Serializer for learner enrollments""" course_id = serializers.SerializerMethodField() - nafath_id = SerializerOptionalMethodField(field_tags=['nafath_id', 'csv_export']) + sso_external_id = SerializerOptionalMethodField(field_tags=['sso_external_id', 'csv_export']) class Meta: model = CourseEnrollment fields = ( LearnerBasicDetailsSerializer.Meta.fields + CourseScoreAndCertificateSerializer.Meta.fields + - ['course_id', 'nafath_id'] + ['course_id', 'sso_external_id'] ) def _get_course_id(self, obj: Any = None) -> CourseLocator | None: @@ -446,20 +446,51 @@ def get_course_id(self, obj: Any) -> str: """Get course id""" return str(self._get_course_id(obj)) - def get_nafath_id(self, obj: Any) -> str: # pylint: disable=no-self-use - """Get Nafath ID from social auth extra_data.""" - tenant_sites = get_all_tenants_info()['sites'] + @staticmethod + def get_sso_site_info(obj: Any) -> List[Dict[str, Any]]: + """Get SSO information of the tenant's site related to the course""" course_tenants = get_org_to_tenant_map().get(obj.course_id.org.lower(), []) + for tenant_id in course_tenants: + sso_site_info = get_sso_sites().get(get_all_tenants_info()['sites'][tenant_id]) + if sso_site_info: + return sso_site_info + + return [] + + def get_sso_external_id(self, obj: Any) -> str: + """Get the SSO external ID from social auth extra_data.""" + result = '' + + sso_site_info = self.get_sso_site_info(obj) + if not sso_site_info: + return result + + social_auth_records = obj.user.social_auth.filter(provider='tpa-saml') + user_auth_by_slug = {} + for record in social_auth_records: + if record.uid.count(':') == 1: + sso_slug, _ = record.uid.split(':') + user_auth_by_slug[sso_slug] = record + + if not user_auth_by_slug: + return result + + for entity_id, sso_info in settings.FX_SSO_INFO.items: + if not sso_info['external_id_field'] or not sso_info['external_id_extractor']: + continue + + for sso_links in sso_site_info: + if entity_id == sso_links['entity_id']: + user_auth_record = user_auth_by_slug.get(sso_links['slug']) + if not user_auth_record: + continue + + external_id_value = user_auth_record.extra_data.get(sso_info['external_id_field']) + if external_id_value: + result = str(sso_info['external_id_extractor'](external_id_value) or '') + break - if not any(tenant_sites[tenant] in get_nafath_sites() for tenant in course_tenants): - return '' - - drupal_social_auth = obj.user.social_auth.filter( - provider=settings.FX_NAFATH_AUTH_PROVIDER, - ).first() - - uid = drupal_social_auth.extra_data.get('uid') if drupal_social_auth else None - return uid[0] if isinstance(uid, list) and len(uid) == 1 else '' + return result class LearnerDetailsExtendedSerializer(LearnerDetailsSerializer): # pylint: disable=too-many-ancestors diff --git a/futurex_openedx_extensions/helpers/settings/common_production.py b/futurex_openedx_extensions/helpers/settings/common_production.py index 8817a9e..29a9250 100644 --- a/futurex_openedx_extensions/helpers/settings/common_production.py +++ b/futurex_openedx_extensions/helpers/settings/common_production.py @@ -59,18 +59,16 @@ def plugin_settings(settings: Any) -> None: }, ) - # Nafath Entry Id - settings.FX_NAFATH_ENTRY_ID = getattr( + # FX SSO Information + settings.FX_SSO_INFO = getattr( settings, - 'FX_NAFATH_ENTRY_ID', - '', - ) - - # Nafath Social Auth Provider - settings.FX_NAFATH_AUTH_PROVIDER = getattr( - settings, - 'FX_NAFATH_AUTH_PROVIDER', - 'tpa-saml', + 'FX_SSO_INFO', + { + 'dummy_entity_id': { + 'external_id_field': 'uid', + 'external_id_extractor': None, # should be a valid function or lambda + }, + }, ) # Default Tenant site diff --git a/futurex_openedx_extensions/helpers/tenants.py b/futurex_openedx_extensions/helpers/tenants.py index f9fd8ee..9c8bfd2 100644 --- a/futurex_openedx_extensions/helpers/tenants.py +++ b/futurex_openedx_extensions/helpers/tenants.py @@ -102,6 +102,18 @@ def get_all_tenants_info() -> Dict[str, str | dict | List[int]]: """ tenant_ids = list(get_all_tenants().values_list('id', flat=True)) info = TenantConfig.objects.filter(id__in=tenant_ids).values('id', 'route__domain', 'lms_configs') + sso_sites: Dict[str, List[Dict[str, str]]] = {} + for sso_site in SAMLProviderConfig.objects.current_set().filter( + entity_id__in=settings.FX_SSO_INFO, enabled=True, + ).values('site__domain', 'slug', 'entity_id'): + site_domain = sso_site['site__domain'] + if site_domain not in sso_sites: + sso_sites[site_domain] = [] + sso_sites[site_domain].append({ + 'slug': sso_site['slug'], + 'entity_id': sso_site['entity_id'], + }) + return { 'tenant_ids': tenant_ids, 'sites': { @@ -124,13 +136,7 @@ def get_all_tenants_info() -> Dict[str, str | dict | List[int]]: 'tenant_by_site': { tenant['route__domain']: tenant['id'] for tenant in info }, - 'special_info': { - 'nafath_sites': list( - SAMLProviderConfig.objects.filter( - entity_id=settings.FX_NAFATH_ENTRY_ID, enabled=True, - ).values_list('site__domain', flat=True) - ), - }, + 'sso_sites': sso_sites, } @@ -313,9 +319,9 @@ def get_tenants_sites(tenant_ids: List[int]) -> List[str]: return tenant_sites -def get_nafath_sites() -> List: - """Get all nafath sites""" - return get_all_tenants_info()['special_info']['nafath_sites'] +def get_sso_sites() -> Dict[str, List[Dict[str, int]]]: + """Get all SSO sites""" + return get_all_tenants_info()['sso_sites'] def generate_tenant_config(sub_domain: str, platform_name: str) -> dict: diff --git a/test_utils/edx_platform_mocks_shared/fake_models/models.py b/test_utils/edx_platform_mocks_shared/fake_models/models.py index 8875c87..95472ec 100644 --- a/test_utils/edx_platform_mocks_shared/fake_models/models.py +++ b/test_utils/edx_platform_mocks_shared/fake_models/models.py @@ -1,6 +1,7 @@ """edx-platform models mocks for testing purposes.""" import re +from config_models.models import ConfigurationModel from django import forms from django.contrib.auth import get_user_model from django.contrib.sites.models import Site @@ -322,12 +323,15 @@ class Meta: role = forms.ChoiceField(choices=COURSE_ACCESS_ROLES) -class SAMLProviderConfig(models.Model): +class SAMLProviderConfig(ConfigurationModel): """Mock""" + KEY_FIELDS = ('slug',) + site = models.ForeignKey(Site, default=1, on_delete=models.CASCADE) name = models.CharField(max_length=255) enabled = models.BooleanField(default=False) entity_id = models.CharField(max_length=255) + slug = models.SlugField(max_length=30, db_index=True, default='default') class Meta: app_label = 'fake_models' diff --git a/test_utils/test_settings_common.py b/test_utils/test_settings_common.py index ded3549..f0c54eb 100644 --- a/test_utils/test_settings_common.py +++ b/test_utils/test_settings_common.py @@ -111,8 +111,18 @@ def root(*args): FX_DEFAULT_COURSE_EFFORT = 20 -FX_NAFATH_ENTRY_ID = 'abc.com' -FX_NAFATH_AUTH_PROVIDER = 'dummy-provider' +FX_SSO_INFO = { + 'testing_entity_id1': { + 'external_id_field': 'test_uid', + 'external_id_extractor': lambda value: ( + value[0] if isinstance(value, list) and len(value) == 1 else '' if isinstance(value, list) else value + ) + }, + 'testing_entity_id2': { + 'external_id_field': 'test_uid2', + 'external_id_extractor': lambda value: value, + }, +} FX_DEFAULT_TENANT_SITE = 'default.example.com' FX_TENANTS_BASE_DOMAIN = 'local.overhang.io' diff --git a/tests/conftest.py b/tests/conftest.py index f391bdc..709036f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,6 +7,7 @@ from common.djangoapps.student.models import CourseAccessRole, CourseEnrollment, UserSignupSource from custom_reg_form.models import ExtraInfo from django.contrib.auth import get_user_model +from django.contrib.sites.models import Site from django.core.cache import cache from django.test import override_settings from django.utils import timezone @@ -265,6 +266,13 @@ def _create_certificates(): if GeneratedCertificate.objects.count() % 2 == 0: created_date -= datetime.timedelta(days=11) + def _create_sites(): + """Create Sites.""" + for _, tenant_config in _base_data['tenant_config'].items(): + site_domain = tenant_config['lms_configs'].get('LMS_BASE') + if site_domain: + Site.objects.get_or_create(domain=site_domain) + with django_db_blocker.unblock(): _create_users() _create_tenants() @@ -275,3 +283,4 @@ def _create_certificates(): _create_ignored_course_access_roles() _create_course_enrollments() _create_certificates() + _create_sites() diff --git a/tests/test_dashboard/test_serializers.py b/tests/test_dashboard/test_serializers.py index a276d5a..ff82a82 100644 --- a/tests/test_dashboard/test_serializers.py +++ b/tests/test_dashboard/test_serializers.py @@ -306,14 +306,16 @@ def test_learner_enrollments_serializer(mock_collect, base_data,): # pylint: di @pytest.mark.django_db -@patch('futurex_openedx_extensions.dashboard.serializers.get_nafath_sites') -def test_learner_enrollments_serializer_for_nafath_id(mocked_get_nafath_sites): - """Ensure LearnerEnrollmentSerializer correctly processes nafath_id based on social auth conditions.""" - site, _ = Site.objects.update_or_create( - domain='s2.sample.com', - defaults={'name': 's2.sample.com'} - ) - mocked_get_nafath_sites.return_value = [site.domain] +@patch('futurex_openedx_extensions.dashboard.serializers.get_sso_sites') +def test_learner_enrollments_serializer_for_sso_external_id(mocked_get_sso_sites): + """Ensure LearnerEnrollmentSerializer correctly processes sso_external_id based on social auth conditions.""" + site = Site.objects.get(domain='s1.sample.com') + mocked_get_sso_sites.return_value = { + site.domain: [{ + 'slug': 'site_slug', + 'entity_id': 'testing_entity_id1', + }] + } queryset = CourseEnrollment.objects.filter(user_id=10, course_id='course-v1:ORG3+1+1').annotate( certificate_available=Value(True), course_score=Value(0.67), @@ -321,26 +323,34 @@ def test_learner_enrollments_serializer_for_nafath_id(mocked_get_nafath_sites): ) context = { 'course_id': 'course-v1:ORG3+1+1', - 'requested_optional_field_tags': ['nafath_id'] + 'requested_optional_field_tags': ['sso_external_id'] } - def assert_nafath_id(expected, msg): - """Helper to serialize and assert nafath_id.""" + def assert_sso_external_id(expected, msg): + """Helper to serialize and assert sso_external_id.""" serializer = LearnerEnrollmentSerializer(queryset, context=context, many=True) - assert serializer.data[0].get('nafath_id') == expected, msg + assert serializer.data[0].get('sso_external_id') == expected, msg - assert_nafath_id('', 'nafath_id should be empty when no social auth exists') + assert_sso_external_id('', 'sso_external_id should be empty when no social auth exists') - queryset[0].user.social_auth.create(provider='other_provider', extra_data={'uid': ['12345']}) - assert_nafath_id('', 'nafath_id should be empty for an incorrect provider') + queryset[0].user.social_auth.create( + provider='other_provider', + uid='testing_entity_id1:whatever', + extra_data={'test_uid': ['12345']} + ) + assert_sso_external_id('', 'sso_external_id should be empty for an incorrect provider') - queryset[0].user.social_auth.create(provider=settings.FX_NAFATH_AUTH_PROVIDER, extra_data={'uid': ['12345']}) - assert_nafath_id('12345', 'nafath_id should be returned when the correct provider and single UID exist') + queryset[0].user.social_auth.create( + provider='tpa_saml', + uid='testing_entity_id1:whatever', + extra_data={'test_uid': ['12345']} + ) + assert_sso_external_id('12345', 'sso_external_id should be returned when the correct provider and single UID exist') - social_auth = queryset[0].user.social_auth.get(provider=settings.FX_NAFATH_AUTH_PROVIDER) - social_auth.extra_data = {'uid': ['12345', 'another-id']} + social_auth = queryset[0].user.social_auth.get(provider='tpa_saml') + social_auth.extra_data = {'test_uid': ['12345', 'another-id']} social_auth.save() - assert_nafath_id('', 'nafath_id should be empty when multiple UIDs are present') + assert_sso_external_id('', 'sso_external_id should be empty when multiple UIDs are present') @pytest.mark.django_db diff --git a/tests/test_helpers/test_apps.py b/tests/test_helpers/test_apps.py index 33cf865..26f2688 100644 --- a/tests/test_helpers/test_apps.py +++ b/tests/test_helpers/test_apps.py @@ -21,6 +21,12 @@ 'quarter': 4, 'year': 1, }), # Max Period Chunks + ('FX_SSO_INFO', { + 'dummy_entity_id': { + 'external_id_field': 'uid', + 'external_id_extractor': None, + }, + }), ] diff --git a/tests/test_helpers/test_tenants.py b/tests/test_helpers/test_tenants.py index c413ac9..1211877 100644 --- a/tests/test_helpers/test_tenants.py +++ b/tests/test_helpers/test_tenants.py @@ -24,25 +24,6 @@ def expected_exclusion(): } -@pytest.mark.django_db -def test_get_nafath_sites(): - """Test get_nafath_sites returns correct enabled sites for FX_NAFATH_ENTRY_ID.""" - site1 = Site.objects.create(domain='nafath-site1.com') - site2 = Site.objects.create(domain='nafath-site2.com') - site3 = Site.objects.create(domain='non-matching-site.com') - SAMLProviderConfig.objects.create(site=site1, entity_id=settings.FX_NAFATH_ENTRY_ID, enabled=True) - SAMLProviderConfig.objects.create(site=site2, entity_id=settings.FX_NAFATH_ENTRY_ID, enabled=True) - SAMLProviderConfig.objects.create(site=site3, entity_id='other-entry-id', enabled=True) - SAMLProviderConfig.objects.create(site=site1, entity_id=settings.FX_NAFATH_ENTRY_ID, enabled=False) - - nafath_sites = tenants.get_nafath_sites() - - assert len(nafath_sites) == 2, 'Only enabled sites with the correct entity_id should be returned' - assert site1.domain in nafath_sites, 'Expected site1 to be in the list' - assert site2.domain in nafath_sites, 'Expected site2 to be in the list' - assert site3.domain not in nafath_sites, 'Non-matching entity_id should be ignored' - - @pytest.mark.django_db def test_get_excluded_tenant_ids( base_data, expected_exclusion, @@ -263,6 +244,35 @@ def test_get_all_tenants_info_is_being_cached(cache_testing): # pylint: disable cache.set(cs.CACHE_NAME_ALL_TENANTS_INFO, None) +@pytest.mark.django_db +def test_get_sso_sites(base_data): + """Verify that get_sso_sites works as expected""" + assert not tenants.get_sso_sites(), 'bad test data' + assert isinstance(tenants.get_sso_sites(), dict), 'bad test data' + + test_data = [ + ('testing_entity_id1', 'slug1'), + ('testing_entity_id2', 'slug2'), + ('other-entry-id', 'slug3'), + ] + site = Site.objects.get(domain='s1.sample.com') + for entity_id, slug in test_data: + SAMLProviderConfig.objects.create(site=site, entity_id=entity_id, slug=slug, enabled=True) + assert tenants.get_sso_sites() == { + 's1.sample.com': [ + {'entity_id': 'testing_entity_id1', 'slug': 'slug1'}, + {'entity_id': 'testing_entity_id2', 'slug': 'slug2'}, + ] + } + + SAMLProviderConfig.objects.create(site=site, entity_id=test_data[0][0], slug=test_data[0][1], enabled=False) + assert tenants.get_sso_sites() == { + 's1.sample.com': [ + {'entity_id': 'testing_entity_id2', 'slug': 'slug2'}, + ] + } + + @pytest.mark.django_db @pytest.mark.parametrize('tenant_id, expected', [ (1, 's1.sample.com'),