diff --git a/fastapi_opa/auth/auth_saml.py b/fastapi_opa/auth/auth_saml.py index 61d3815..dc7a07d 100644 --- a/fastapi_opa/auth/auth_saml.py +++ b/fastapi_opa/auth/auth_saml.py @@ -44,7 +44,7 @@ async def authenticate( elif 'sso2' in request.query_params: logger.debug(datetime.utcnow(), '--sso2--') return_to = '%sattrs/' % request.base_url - return RedirectResponse(auth.login(return_to)) + return await self.single_sign_on(auth, return_to) elif "acs" in request.query_params: logger.debug(datetime.utcnow(), '--acs--') return await self.assertion_consumer_service(auth, request_args, request) @@ -93,8 +93,8 @@ async def single_log_out(auth: OneLogin_Saml2_Auth) -> RedirectResponse: spnq=name_id_spnq)) @staticmethod - async def single_sign_on(auth: OneLogin_Saml2_Auth) -> RedirectResponse: - redirect_url = auth.login() + async def single_sign_on(auth: OneLogin_Saml2_Auth, url: str = None) -> RedirectResponse: + redirect_url = auth.login(url) return RedirectResponse(redirect_url) @staticmethod diff --git a/tests/test_saml_auth.py b/tests/test_saml_auth.py index 3bae4e9..0ae92a7 100644 --- a/tests/test_saml_auth.py +++ b/tests/test_saml_auth.py @@ -15,12 +15,30 @@ async def test_single_sign_on(): saml_auth_mock = Mock() saml_auth_mock.login.return_value = "http://idp.com/cryptic-stuff" - response = await saml_auth.single_sign_on(saml_auth_mock) + url = r"http://idp.com/cryptic-stuff/attrs" + response = await saml_auth.single_sign_on(saml_auth_mock, url) assert isinstance(response, RedirectResponse) assert response.headers.get("location") == "http://idp.com/cryptic-stuff" +@pytest.mark.asyncio +async def test_single_sign_on_with_parameter(): + saml_conf = SAMLConfig(settings_directory="./tests/test_data/saml") + saml_auth = SAMLAuthentication(saml_conf) + + def side_effect(url): + return url + + saml_auth_mock = Mock() + saml_auth_mock.login = Mock(side_effect=side_effect) + attr_url = "http://idp.com/cryptic-stuff/attrs" + response = await saml_auth.single_sign_on(saml_auth_mock, attr_url) + + assert isinstance(response, RedirectResponse) + assert response.headers.get("location") == attr_url + + @pytest.mark.asyncio @patch("fastapi_opa.auth.auth_saml.OneLogin_Saml2_Utils") async def test_assertion_consumer_service(saml_util_mock):