diff --git a/futurex_openedx_extensions/dashboard/serializers.py b/futurex_openedx_extensions/dashboard/serializers.py index b6ec60c..ff9a2f6 100644 --- a/futurex_openedx_extensions/dashboard/serializers.py +++ b/futurex_openedx_extensions/dashboard/serializers.py @@ -2,7 +2,7 @@ from __future__ import annotations import re -from typing import Any, Dict, Tuple +from typing import Any, Dict, List, Tuple from common.djangoapps.student.models import CourseEnrollment from django.conf import settings @@ -44,8 +44,8 @@ ) from futurex_openedx_extensions.helpers.tenants import ( get_all_tenants_info, - get_nafath_sites, get_org_to_tenant_map, + get_sso_sites, get_tenants_by_org, ) @@ -421,14 +421,14 @@ class LearnerEnrollmentSerializer( ): # pylint: disable=too-many-ancestors """Serializer for learner enrollments""" course_id = serializers.SerializerMethodField() - nafath_id = SerializerOptionalMethodField(field_tags=['nafath_id', 'csv_export']) + sso_external_id = SerializerOptionalMethodField(field_tags=['sso_external_id', 'csv_export']) class Meta: model = CourseEnrollment fields = ( LearnerBasicDetailsSerializer.Meta.fields + CourseScoreAndCertificateSerializer.Meta.fields + - ['course_id', 'nafath_id'] + ['course_id', 'sso_external_id'] ) def _get_course_id(self, obj: Any = None) -> CourseLocator | None: @@ -446,20 +446,51 @@ def get_course_id(self, obj: Any) -> str: """Get course id""" return str(self._get_course_id(obj)) - def get_nafath_id(self, obj: Any) -> str: # pylint: disable=no-self-use - """Get Nafath ID from social auth extra_data.""" - tenant_sites = get_all_tenants_info()['sites'] + @staticmethod + def get_sso_site_info(obj: Any) -> List[Dict[str, Any]]: + """Get SSO information of the tenant's site related to the course""" course_tenants = get_org_to_tenant_map().get(obj.course_id.org.lower(), []) + for tenant_id in course_tenants: + sso_site_info = get_sso_sites().get(get_all_tenants_info()['sites'][tenant_id]) + if sso_site_info: + return sso_site_info + + return [] + + def get_sso_external_id(self, obj: Any) -> str: + """Get the SSO external ID from social auth extra_data.""" + result = '' + + sso_site_info = self.get_sso_site_info(obj) + if not sso_site_info: + return result + + social_auth_records = obj.user.social_auth.filter(provider='tpa-saml') + user_auth_by_slug = {} + for record in social_auth_records: + if record.uid.count(':') == 1: + sso_slug, _ = record.uid.split(':') + user_auth_by_slug[sso_slug] = record + + if not user_auth_by_slug: + return result + + for entity_id, sso_info in settings.FX_SSO_INFO.items: + if not sso_info['external_id_field'] or not sso_info['external_id_extractor']: + continue + + for sso_links in sso_site_info: + if entity_id == sso_links['entity_id']: + user_auth_record = user_auth_by_slug.get(sso_links['slug']) + if not user_auth_record: + continue + + external_id_value = user_auth_record.extra_data.get(sso_info['external_id_field']) + if external_id_value: + result = str(sso_info['external_id_extractor'](external_id_value) or '') + break - if not any(tenant_sites[tenant] in get_nafath_sites() for tenant in course_tenants): - return '' - - drupal_social_auth = obj.user.social_auth.filter( - provider=settings.FX_NAFATH_AUTH_PROVIDER, - ).first() - - uid = drupal_social_auth.extra_data.get('uid') if drupal_social_auth else None - return uid[0] if isinstance(uid, list) and len(uid) == 1 else '' + return result class LearnerDetailsExtendedSerializer(LearnerDetailsSerializer): # pylint: disable=too-many-ancestors diff --git a/futurex_openedx_extensions/helpers/settings/common_production.py b/futurex_openedx_extensions/helpers/settings/common_production.py index 8817a9e..29a9250 100644 --- a/futurex_openedx_extensions/helpers/settings/common_production.py +++ b/futurex_openedx_extensions/helpers/settings/common_production.py @@ -59,18 +59,16 @@ def plugin_settings(settings: Any) -> None: }, ) - # Nafath Entry Id - settings.FX_NAFATH_ENTRY_ID = getattr( + # FX SSO Information + settings.FX_SSO_INFO = getattr( settings, - 'FX_NAFATH_ENTRY_ID', - '', - ) - - # Nafath Social Auth Provider - settings.FX_NAFATH_AUTH_PROVIDER = getattr( - settings, - 'FX_NAFATH_AUTH_PROVIDER', - 'tpa-saml', + 'FX_SSO_INFO', + { + 'dummy_entity_id': { + 'external_id_field': 'uid', + 'external_id_extractor': None, # should be a valid function or lambda + }, + }, ) # Default Tenant site diff --git a/futurex_openedx_extensions/helpers/tenants.py b/futurex_openedx_extensions/helpers/tenants.py index f9fd8ee..9c8bfd2 100644 --- a/futurex_openedx_extensions/helpers/tenants.py +++ b/futurex_openedx_extensions/helpers/tenants.py @@ -102,6 +102,18 @@ def get_all_tenants_info() -> Dict[str, str | dict | List[int]]: """ tenant_ids = list(get_all_tenants().values_list('id', flat=True)) info = TenantConfig.objects.filter(id__in=tenant_ids).values('id', 'route__domain', 'lms_configs') + sso_sites: Dict[str, List[Dict[str, str]]] = {} + for sso_site in SAMLProviderConfig.objects.current_set().filter( + entity_id__in=settings.FX_SSO_INFO, enabled=True, + ).values('site__domain', 'slug', 'entity_id'): + site_domain = sso_site['site__domain'] + if site_domain not in sso_sites: + sso_sites[site_domain] = [] + sso_sites[site_domain].append({ + 'slug': sso_site['slug'], + 'entity_id': sso_site['entity_id'], + }) + return { 'tenant_ids': tenant_ids, 'sites': { @@ -124,13 +136,7 @@ def get_all_tenants_info() -> Dict[str, str | dict | List[int]]: 'tenant_by_site': { tenant['route__domain']: tenant['id'] for tenant in info }, - 'special_info': { - 'nafath_sites': list( - SAMLProviderConfig.objects.filter( - entity_id=settings.FX_NAFATH_ENTRY_ID, enabled=True, - ).values_list('site__domain', flat=True) - ), - }, + 'sso_sites': sso_sites, } @@ -313,9 +319,9 @@ def get_tenants_sites(tenant_ids: List[int]) -> List[str]: return tenant_sites -def get_nafath_sites() -> List: - """Get all nafath sites""" - return get_all_tenants_info()['special_info']['nafath_sites'] +def get_sso_sites() -> Dict[str, List[Dict[str, int]]]: + """Get all SSO sites""" + return get_all_tenants_info()['sso_sites'] def generate_tenant_config(sub_domain: str, platform_name: str) -> dict: diff --git a/requirements/test-constraints-redwood.txt b/requirements/test-constraints-redwood.txt index eb8436c..1230545 100644 --- a/requirements/test-constraints-redwood.txt +++ b/requirements/test-constraints-redwood.txt @@ -9,6 +9,7 @@ eox-tenant