diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml new file mode 100644 index 0000000..b3cf1a0 --- /dev/null +++ b/.github/workflows/unit-tests.yml @@ -0,0 +1,42 @@ +name: Unit tests +on: + pull_request: + types: + - opened + - edited + - reopened + - synchronize +jobs: + python-tests: + name: Run python tests + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 + with: + fetch-depth: 0 + - name: Set up Python + uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1 + with: + python-version: 3.11 + - name: Install build dependencies + run: | + python -m pip install --upgrade pip + pip install tox + tox -e py + - name: Build coverage file + run: | + tox -e py + pytest --cache-clear --junitxml=coverage.xml --cov-report=term-missing:skip-covered --cov=mlflow_oidc_auth > pytest-coverage.txt + - name: Override Coverage Source Path for Sonar + run: sed -i "s@/home/runner/work/mlflow-oidc-auth/mlflow-oidc-auth@/github/workspace@g" /home/runner/work/mlflow-oidc-auth/mlflow-oidc-auth/coverage.xml + - name: debug cov + run: | + pwd + ls -alh + head -n50 coverage.xml + - name: SonarCloud Scan + uses: SonarSource/sonarcloud-github-action@e44258b109568baa0df60ed515909fc6c72cba92 # v2.3.0 + env: + SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }} + # todo: remove GH App diff --git a/.gitignore b/.gitignore index d214a4d..96ff8bc 100644 --- a/.gitignore +++ b/.gitignore @@ -176,3 +176,4 @@ flask_session/ node_modules/ .angular/ mlflow_oidc_auth/ui +pytest-coverage.txt diff --git a/mlflow_oidc_auth/auth.py b/mlflow_oidc_auth/auth.py index 1f1376f..2563226 100644 --- a/mlflow_oidc_auth/auth.py +++ b/mlflow_oidc_auth/auth.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Union, Optional import requests from authlib.integrations.flask_client import OAuth @@ -6,21 +6,33 @@ from flask import Response, request from werkzeug.datastructures import Authorization -from mlflow_oidc_auth.app import app from mlflow_oidc_auth.config import config from mlflow_oidc_auth.store import store -oauth = OAuth(app) -oauth.register( - name="oidc", - client_id=config.OIDC_CLIENT_ID, - client_secret=config.OIDC_CLIENT_SECRET, - server_metadata_url=config.OIDC_DISCOVERY_URL, - client_kwargs={"scope": config.OIDC_SCOPE}, -) + +_oauth_instance: Optional[OAuth] = None + + +def get_oauth_instance(app) -> OAuth: + # returns a singleton instance of OAuth + # to avoid circular imports + global _oauth_instance + + if _oauth_instance is None: + _oauth_instance = OAuth(app) + _oauth_instance.register( + name="oidc", + client_id=config.OIDC_CLIENT_ID, + client_secret=config.OIDC_CLIENT_SECRET, + server_metadata_url=config.OIDC_DISCOVERY_URL, + client_kwargs={"scope": config.OIDC_SCOPE}, + ) + return _oauth_instance + def _get_oidc_jwks(): - from mlflow_oidc_auth.app import cache + from mlflow_oidc_auth.app import cache, app + jwks = cache.get("jwks") if jwks: app.logger.debug("JWKS cache hit") @@ -41,6 +53,8 @@ def validate_token(token): def authenticate_request_basic_auth() -> Union[Authorization, Response]: + from mlflow_oidc_auth.app import app + username = request.authorization.username password = request.authorization.password app.logger.debug("Authenticating user %s", username) @@ -53,6 +67,8 @@ def authenticate_request_basic_auth() -> Union[Authorization, Response]: def authenticate_request_bearer_token() -> Union[Authorization, Response]: + from mlflow_oidc_auth.app import app + token = request.authorization.token try: user = validate_token(token) diff --git a/mlflow_oidc_auth/sqlalchemy_store.py b/mlflow_oidc_auth/sqlalchemy_store.py index 5f26067..ec690db 100644 --- a/mlflow_oidc_auth/sqlalchemy_store.py +++ b/mlflow_oidc_auth/sqlalchemy_store.py @@ -113,6 +113,7 @@ def delete_user(self, username: str): with self.ManagedSessionMaker() as session: user = self._get_user(session, username) session.delete(user) + session.flush() def create_experiment_permission(self, experiment_id: str, username: str, permission: str) -> ExperimentPermission: _validate_permission(permission) @@ -341,6 +342,7 @@ def delete_registered_model_permission(self, name: str, username: str): with self.ManagedSessionMaker() as session: perm = self._get_registered_model_permission(session, name, username) session.delete(perm) + session.flush() def list_experiment_permissions_for_experiment(self, experiment_id: str): with self.ManagedSessionMaker() as session: diff --git a/mlflow_oidc_auth/tests/__init__.py b/mlflow_oidc_auth/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mlflow_oidc_auth/tests/test_auth.py b/mlflow_oidc_auth/tests/test_auth.py new file mode 100644 index 0000000..9e534d9 --- /dev/null +++ b/mlflow_oidc_auth/tests/test_auth.py @@ -0,0 +1,88 @@ +from unittest.mock import patch, MagicMock +from mlflow_oidc_auth.auth import ( + get_oauth_instance, + _get_oidc_jwks, + validate_token, + authenticate_request_basic_auth, + authenticate_request_bearer_token, +) + + +class TestAuth: + @patch("mlflow_oidc_auth.auth.OAuth") + @patch("mlflow_oidc_auth.auth.config") + def test_get_oauth_instance(self, mock_config, mock_oauth): + mock_app = MagicMock() + mock_oauth_instance = MagicMock() + mock_oauth.return_value = mock_oauth_instance + + mock_config.OIDC_CLIENT_ID = "mock_client_id" + mock_config.OIDC_CLIENT_SECRET = "mock_client_secret" + mock_config.OIDC_DISCOVERY_URL = "mock_discovery_url" + mock_config.OIDC_SCOPE = "mock_scope" + + result = get_oauth_instance(mock_app) + + mock_oauth.assert_called_once_with(mock_app) + mock_oauth_instance.register.assert_called_once_with( + name="oidc", + client_id="mock_client_id", + client_secret="mock_client_secret", + server_metadata_url="mock_discovery_url", + client_kwargs={"scope": "mock_scope"}, + ) + assert result == mock_oauth_instance + + @patch("mlflow_oidc_auth.auth._get_oidc_jwks") + @patch("mlflow_oidc_auth.auth.jwt.decode") + def test_validate_token(self, mock_jwt_decode, mock_get_oidc_jwks): + mock_jwks = {"keys": "mock_keys"} + mock_get_oidc_jwks.return_value = mock_jwks + mock_payload = MagicMock() + mock_jwt_decode.return_value = mock_payload + + token = "mock_token" + result = validate_token(token) + + mock_get_oidc_jwks.assert_called_once() + mock_jwt_decode.assert_called_once_with(token, mock_jwks) + mock_payload.validate.assert_called_once() + assert result == mock_payload + + @patch("mlflow_oidc_auth.auth.store") + def test_authenticate_request_basic_auth_uses_authenticate_user(self, mock_store): + mock_request = MagicMock() + mock_request.authorization.username = "mock_username" + mock_request.authorization.password = "mock_password" + mock_store.authenticate_user.return_value = True + + with patch("mlflow_oidc_auth.auth.request", mock_request): + # for some reason decorator doesn't mock flask + result = authenticate_request_basic_auth() + + mock_store.authenticate_user.assert_called_once_with("mock_username", "mock_password") + assert result == True + + @patch("mlflow_oidc_auth.auth.validate_token") + def test_authenticate_request_bearer_token_uses_validate_token(self, mock_validate_token): + mock_request = MagicMock() + mock_request.authorization.token = "mock_token" + mock_validate_token.return_value = MagicMock() + with patch("mlflow_oidc_auth.auth.request", mock_request): + # for some reason decorator doesn't mock flask + result = authenticate_request_bearer_token() + + mock_validate_token.assert_called_once_with("mock_token") + assert result == True + + @patch("mlflow_oidc_auth.auth.validate_token") + def test_authenticate_request_bearer_token_exception_returns_false(self, mock_validate_token): + mock_request = MagicMock() + mock_request.authorization.token = "mock_token" + mock_validate_token.side_effect = Exception() + with patch("mlflow_oidc_auth.auth.request", mock_request): + # for some reason decorator doesn't mock flask + result = authenticate_request_bearer_token() + + mock_validate_token.assert_called_once_with("mock_token") + assert result == False diff --git a/mlflow_oidc_auth/tests/test_client.py b/mlflow_oidc_auth/tests/test_client.py new file mode 100644 index 0000000..5da985a --- /dev/null +++ b/mlflow_oidc_auth/tests/test_client.py @@ -0,0 +1,272 @@ +import pytest +from unittest.mock import patch +from mlflow_oidc_auth.client import AuthServiceClient +from mlflow_oidc_auth.routes import ( + CREATE_EXPERIMENT_PERMISSION, + CREATE_REGISTERED_MODEL_PERMISSION, + CREATE_USER, + DELETE_EXPERIMENT_PERMISSION, + DELETE_REGISTERED_MODEL_PERMISSION, + DELETE_USER, + GET_EXPERIMENT_PERMISSION, + GET_REGISTERED_MODEL_PERMISSION, + GET_USER, + UPDATE_EXPERIMENT_PERMISSION, + UPDATE_REGISTERED_MODEL_PERMISSION, + UPDATE_USER_ADMIN, + UPDATE_USER_PASSWORD, +) + + +@pytest.fixture +def client(): + return AuthServiceClient(tracking_uri="http://test") + + +class TestClientExceptions: + def test_create_user_no_password(self, client): + with pytest.raises(ValueError): + client.create_user("test_user", None) + + def test_create_user_with_no_username(self, client): + with pytest.raises(ValueError): + client.create_user(None, "password") + + def test_get_user_with_no_username(self, client): + with pytest.raises(ValueError): + client.get_user(None) + + def test_update_user_password_with_no_username(self, client): + with pytest.raises(ValueError): + client.update_user_password(None, "password") + + def test_update_user_password_with_no_password(self, client): + with pytest.raises(ValueError): + client.update_user_password("test_user", None) + + def test_update_user_admin_with_no_username(self, client): + with pytest.raises(ValueError): + client.update_user_admin(None, True) + + def test_update_user_admin_with_no_boolean(self, client): + # TODO: probably, use mypy to check typing + with pytest.raises(ValueError): + client.update_user_admin("test_user", "this is not boolean value") + + def test_delete_user_with_no_username(self, client): + with pytest.raises(ValueError): + client.delete_user(None) + + def test_create_experiment_permission_with_no_experiment_id(self, client): + with pytest.raises(ValueError): + client.create_experiment_permission(None, "test_user", "READ") + + def test_create_experiment_permission_with_no_username(self, client): + with pytest.raises(ValueError): + client.create_experiment_permission("1", None, "READ") + + def test_create_experiment_permission_with_invalid_permission(self, client): + with pytest.raises(ValueError): + client.create_experiment_permission("1", "test_user", "THIS_PERMISSION_DOES_NOT_EXIST") + + def test_get_experiment_permission_with_no_experiment_id(self, client): + with pytest.raises(ValueError): + client.get_experiment_permission(None, "test_user") + + def test_get_experiment_permission_with_no_username(self, client): + with pytest.raises(ValueError): + client.get_experiment_permission("1", None) + + def test_update_experiment_permission_with_no_experiment_id(self, client): + with pytest.raises(ValueError): + client.update_experiment_permission(None, "test_user", "READ") + + def test_update_experiment_permission_with_no_username(self, client): + with pytest.raises(ValueError): + client.update_experiment_permission("1", None, "READ") + + def test_update_experiment_permission_with_invalid_permission(self, client): + with pytest.raises(ValueError): + client.update_experiment_permission("1", "test_user", "THIS_PERMISSION_DOES_NOT_EXIST") + + def test_delete_experiment_permission_with_no_experiment_id(self, client): + with pytest.raises(ValueError): + client.delete_experiment_permission(None, "test_user") + + def test_delete_experiment_permission_with_no_username(self, client): + with pytest.raises(ValueError): + client.delete_experiment_permission("1", None) + + def test_create_registered_model_permission_with_no_name(self, client): + with pytest.raises(ValueError): + client.create_registered_model_permission(None, "test_user", "READ") + + def test_create_registered_model_permission_with_no_username(self, client): + with pytest.raises(ValueError): + client.create_registered_model_permission("model", None, "READ") + + def test_create_registered_model_permission_with_invalid_permission(self, client): + with pytest.raises(ValueError): + client.create_registered_model_permission("model", "test_user", "THIS_PERMISSION_DOES_NOT_EXIST") + + def test_get_registered_model_permission_with_no_name(self, client): + with pytest.raises(ValueError): + client.get_registered_model_permission(None, "test_user") + + def test_get_registered_model_permission_with_no_username(self, client): + with pytest.raises(ValueError): + client.get_registered_model_permission("model", None) + + def test_update_registered_model_permission_with_no_name(self, client): + with pytest.raises(ValueError): + client.update_registered_model_permission(None, "test_user", "READ") + + def test_update_registered_model_permission_with_no_username(self, client): + with pytest.raises(ValueError): + client.update_registered_model_permission("model", None, "READ") + + def test_update_registered_model_permission_with_invalid_permission(self, client): + with pytest.raises(ValueError): + client.update_registered_model_permission("model", "test_user", "THIS_PERMISSION_DOES_NOT_EXIST") + + def test_delete_registered_model_permission_with_no_name(self, client): + with pytest.raises(ValueError): + client.delete_registered_model_permission(None, "test_user") + + def test_delete_registered_model_permission_with_no_username(self, client): + with pytest.raises(ValueError): + client.delete_registered_model_permission("model", None) + + +@pytest.fixture( + scope="function", + autouse=True, +) +def mock_request(): + with patch("mlflow_oidc_auth.client.AuthServiceClient._request") as mock: + yield mock + + +class TestClient: + def test_create_user(self, mock_request, client): + mock_request.return_value = { + "user": { + "id": 123, + "username": "test_user", + "display_name": "Test User", + "is_admin": False, + "experiment_permissions": [], + "registered_model_permissions": [], + "groups": [], + } + } + user = client.create_user("test_user", "password") + assert user.username == "test_user" + mock_request.assert_called_once_with(CREATE_USER, "POST", json={"username": "test_user", "password": "password"}) + + def test_get_user(self, mock_request, client): + mock_request.return_value = { + "user": { + "id": 123, + "username": "test_user", + "display_name": "Test User", + "is_admin": False, + "experiment_permissions": [], + "registered_model_permissions": [], + "groups": [], + } + } + user = client.get_user("test_user") + assert user.username == "test_user" + mock_request.assert_called_once_with(GET_USER, "GET", params={"username": "test_user"}) + + def test_update_user_password(self, mock_request, client): + client.update_user_password("test_user", "new_password") + mock_request.assert_called_once_with( + UPDATE_USER_PASSWORD, "PATCH", json={"username": "test_user", "password": "new_password"} + ) + + def test_update_user_admin(self, mock_request, client): + mock_request.return_value = { + "user": { + "id": 123, + "username": "test_user", + "display_name": "Test User", + "is_admin": True, + "experiment_permissions": [], + "registered_model_permissions": [], + "groups": [], + } + } + client.update_user_admin("test_user", True) + mock_request.assert_called_once_with(UPDATE_USER_ADMIN, "PATCH", json={"username": "test_user", "is_admin": True}) + + def test_delete_user(self, mock_request, client): + client.delete_user("test_user") + mock_request.assert_called_once_with(DELETE_USER, "DELETE", json={"username": "test_user"}) + + def test_create_experiment_permission(self, mock_request, client): + mock_request.return_value = {"experiment_permission": {"experiment_id": 1, "permission": "READ", "user_id": 123}} + permission = client.create_experiment_permission("1", "test_user", "READ") + assert permission.permission == "READ" + assert permission.experiment_id == 1 + assert permission.user_id == 123 + mock_request.assert_called_once_with( + CREATE_EXPERIMENT_PERMISSION, "POST", json={"experiment_id": "1", "username": "test_user", "permission": "READ"} + ) + + def test_get_experiment_permission(self, mock_request, client): + mock_request.return_value = {"experiment_permission": {"experiment_id": 1, "permission": "READ", "user_id": 123}} + permission = client.get_experiment_permission("1", "test_user") + assert permission.permission == "READ" + assert permission.experiment_id == 1 + assert permission.user_id == 123 + mock_request.assert_called_once_with( + GET_EXPERIMENT_PERMISSION, "GET", params={"experiment_id": "1", "username": "test_user"} + ) + + def test_update_experiment_permission(self, mock_request, client): + mock_request.return_value = {"experiment_permission": {"experiment_id": 1, "permission": "READ", "user_id": 123}} + client.update_experiment_permission("1", "test_user", "READ") + mock_request.assert_called_once_with( + UPDATE_EXPERIMENT_PERMISSION, "PATCH", json={"experiment_id": "1", "username": "test_user", "permission": "READ"} + ) + + def test_delete_experiment_permission(self, mock_request, client): + client.delete_experiment_permission("1", "test_user") + mock_request.assert_called_once_with( + DELETE_EXPERIMENT_PERMISSION, "DELETE", json={"experiment_id": "1", "username": "test_user"} + ) + + def test_create_registered_model_permission(self, mock_request, client): + mock_request.return_value = {"registered_model_permission": {"name": "model", "permission": "READ", "user_id": 123}} + permission = client.create_registered_model_permission("model", "test_user", "READ") + assert permission.permission == "READ" + assert permission.name == "model" + assert permission.user_id == 123 + mock_request.assert_called_once_with( + CREATE_REGISTERED_MODEL_PERMISSION, "POST", json={"name": "model", "username": "test_user", "permission": "READ"} + ) + + def test_get_registered_model_permission(self, mock_request, client): + mock_request.return_value = {"registered_model_permission": {"name": "model", "permission": "READ", "user_id": 123}} + permission = client.get_registered_model_permission("model", "test_user") + assert permission.permission == "READ" + assert permission.name == "model" + assert permission.user_id == 123 + mock_request.assert_called_once_with( + GET_REGISTERED_MODEL_PERMISSION, "GET", params={"name": "model", "username": "test_user"} + ) + + def test_update_registered_model_permission(self, mock_request, client): + mock_request.return_value = {"registered_model_permission": {"name": "model", "permission": "READ", "user_id": 123}} + client.update_registered_model_permission("model", "test_user", "READ") + mock_request.assert_called_once_with( + UPDATE_REGISTERED_MODEL_PERMISSION, "PATCH", json={"name": "model", "username": "test_user", "permission": "READ"} + ) + + def test_delete_registered_model_permission(self, mock_request, client): + client.delete_registered_model_permission("model", "test_user") + mock_request.assert_called_once_with( + DELETE_REGISTERED_MODEL_PERMISSION, "DELETE", json={"name": "model", "username": "test_user"} + ) diff --git a/mlflow_oidc_auth/tests/test_db_utils.py b/mlflow_oidc_auth/tests/test_db_utils.py new file mode 100644 index 0000000..d0e84b7 --- /dev/null +++ b/mlflow_oidc_auth/tests/test_db_utils.py @@ -0,0 +1,15 @@ +from unittest.mock import patch +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +from mlflow_oidc_auth.db.utils import migrate + + +class TestMigrate: + @patch("mlflow_oidc_auth.db.utils.upgrade") + def test_migrate(self, mock_upgrade): + engine = create_engine("sqlite:///:memory:") + with sessionmaker(bind=engine)(): + migrate(engine, "head") + + mock_upgrade.assert_called_once() diff --git a/mlflow_oidc_auth/tests/test_entities.py b/mlflow_oidc_auth/tests/test_entities.py new file mode 100644 index 0000000..55a4bd1 --- /dev/null +++ b/mlflow_oidc_auth/tests/test_entities.py @@ -0,0 +1,56 @@ +import unittest +from mlflow_oidc_auth.entities import User, ExperimentPermission, RegisteredModelPermission, Group + + +class TestUser(unittest.TestCase): + def test_user_to_json(self): + user = User( + id_="123", + username="test_user", + password_hash="password", + is_admin=True, + display_name="Test User", + experiment_permissions=[ExperimentPermission("exp1", "read")], + registered_model_permissions=[RegisteredModelPermission("model1", "write")], + groups=[Group("group1", "Group 1")], + ) + + expected_json = { + "id": "123", + "username": "test_user", + "is_admin": True, + "display_name": "Test User", + "experiment_permissions": [{"experiment_id": "exp1", "permission": "read", "user_id": None, "group_id": None}], + "registered_model_permissions": [{"name": "model1", "permission": "write", "user_id": None, "group_id": None}], + "groups": [{"id": "group1", "group_name": "Group 1"}], + } + + self.assertEqual(user.to_json(), expected_json) + + def test_user_from_json(self): + json_data = { + "id": "123", + "username": "test_user", + "is_admin": True, + "display_name": "Test User", + "experiment_permissions": [{"experiment_id": "exp1", "permission": "read", "user_id": None, "group_id": None}], + "registered_model_permissions": [{"name": "model1", "permission": "write", "user_id": None, "group_id": None}], + "groups": [{"id": "group1", "group_name": "Group 1"}], + } + + user = User.from_json(json_data) + + self.assertEqual(user.id, "123") + self.assertEqual(user.username, "test_user") + self.assertEqual(user.password_hash, "REDACTED") + self.assertTrue(user.is_admin) + self.assertEqual(user.display_name, "Test User") + self.assertEqual(len(user.experiment_permissions), 1) + self.assertEqual(user.experiment_permissions[0].experiment_id, "exp1") + self.assertEqual(user.experiment_permissions[0].permission, "read") + self.assertEqual(len(user.registered_model_permissions), 1) + self.assertEqual(user.registered_model_permissions[0].name, "model1") + self.assertEqual(user.registered_model_permissions[0].permission, "write") + self.assertEqual(len(user.groups), 1) + self.assertEqual(user.groups[0].id, "group1") + self.assertEqual(user.groups[0].group_name, "Group 1") diff --git a/mlflow_oidc_auth/tests/test_plugins_group_detecntion_msft_entra_id.py b/mlflow_oidc_auth/tests/test_plugins_group_detecntion_msft_entra_id.py new file mode 100644 index 0000000..fbb40ba --- /dev/null +++ b/mlflow_oidc_auth/tests/test_plugins_group_detecntion_msft_entra_id.py @@ -0,0 +1,31 @@ +import unittest +from unittest.mock import Mock, patch +from mlflow_oidc_auth.plugins.group_detection_microsoft_entra_id import get_user_groups + + +class TestGetUserGroups(unittest.TestCase): + @patch("mlflow_oidc_auth.plugins.group_detection_microsoft_entra_id.requests.get") + def test_get_user_groups(self, mock_get): + mock_response = Mock() + mock_response.json.return_value = { + "value": [ + {"displayName": "Group 1"}, + {"displayName": "Group 2"}, + {"displayName": "Group 3"}, + ] + } + mock_get.return_value = mock_response + + access_token = "D34DB33F" + groups = get_user_groups(access_token) + + mock_get.assert_called_once_with( + "https://graph.microsoft.com/v1.0/me/memberOf", + headers={ + "Authorization": f"Bearer {access_token}", + "Content-Type": "application/json", + }, + ) + + expected_groups = ["Group 1", "Group 2", "Group 3"] + self.assertEqual(groups, expected_groups) diff --git a/mlflow_oidc_auth/tests/test_routes.py b/mlflow_oidc_auth/tests/test_routes.py new file mode 100644 index 0000000..df8e45c --- /dev/null +++ b/mlflow_oidc_auth/tests/test_routes.py @@ -0,0 +1,55 @@ +from unittest import mock +from mlflow_oidc_auth import routes + +""" +`routes` contains mutiple routes definitions. +This test is to ensure that all expected routes are present. +""" + + +class TestRoutes: + def test_routes_presented(self): + assert all( + route is not None + for route in [ + routes.HOME, + routes.LOGIN, + routes.LOGOUT, + routes.CALLBACK, + routes.STATIC, + routes.UI, + routes.UI_ROOT, + routes.GET_ACCESS_TOKEN, + routes.GET_CURRENT_USER, + routes.GET_EXPERIMENTS, + routes.GET_MODELS, + routes.GET_USERS, + routes.GET_USER_EXPERIMENTS, + routes.GET_USER_MODELS, + routes.GET_EXPERIMENT_USERS, + routes.GET_MODEL_USERS, + routes.CREATE_USER, + routes.GET_USER, + routes.UPDATE_USER_PASSWORD, + routes.UPDATE_USER_ADMIN, + routes.DELETE_USER, + routes.CREATE_EXPERIMENT_PERMISSION, + routes.GET_EXPERIMENT_PERMISSION, + routes.UPDATE_EXPERIMENT_PERMISSION, + routes.DELETE_EXPERIMENT_PERMISSION, + routes.CREATE_REGISTERED_MODEL_PERMISSION, + routes.GET_REGISTERED_MODEL_PERMISSION, + routes.UPDATE_REGISTERED_MODEL_PERMISSION, + routes.DELETE_REGISTERED_MODEL_PERMISSION, + routes.GET_GROUPS, + routes.GET_GROUP_USERS, + routes.GET_GROUP_EXPERIMENTS_PERMISSION, + routes.CREATE_GROUP_EXPERIMENT_PERMISSION, + routes.DELETE_GROUP_EXPERIMENT_PERMISSION, + routes.UPDATE_GROUP_EXPERIMENT_PERMISSION, + routes.GET_GROUP_MODELS_PERMISSION, + routes.CREATE_GROUP_MODEL_PERMISSION, + routes.DELETE_GROUP_MODEL_PERMISSION, + routes.UPDATE_GROUP_MODEL_PERMISSION, + ] + ) diff --git a/mlflow_oidc_auth/tests/test_sqlalchemy_store.py b/mlflow_oidc_auth/tests/test_sqlalchemy_store.py new file mode 100644 index 0000000..4fff432 --- /dev/null +++ b/mlflow_oidc_auth/tests/test_sqlalchemy_store.py @@ -0,0 +1,195 @@ +import pytest +from unittest.mock import patch, MagicMock, PropertyMock +from mlflow.exceptions import MlflowException +from mlflow_oidc_auth.sqlalchemy_store import SqlAlchemyStore +from sqlalchemy.exc import IntegrityError, NoResultFound +from mlflow_oidc_auth.db.models import SqlRegisteredModelPermission + + +@pytest.fixture +@patch("mlflow_oidc_auth.sqlalchemy_store.dbutils.migrate_if_needed") +def store(_mock_migrate_if_needed): + store = SqlAlchemyStore() + store.init_db("sqlite:///:memory:") + return store + + +class TestSqlAlchemyStore: + @patch("mlflow_oidc_auth.sqlalchemy_store.SqlAlchemyStore._get_user") + @patch("mlflow_oidc_auth.sqlalchemy_store.check_password_hash", return_value=True) + def test_authenticate_user(self, _mock_check_password_hash, mock_get_user, store): + mock_get_user.return_value = MagicMock(password_hash="hashed_password") + auth_result = store.authenticate_user("test_user", "password") + mock_get_user.assert_called_once() + assert mock_get_user.call_args[0][1] == "test_user" + assert auth_result is True + + @patch("mlflow_oidc_auth.sqlalchemy_store.SqlAlchemyStore._get_user") + @patch("mlflow_oidc_auth.sqlalchemy_store.check_password_hash", return_value=False) + def test_authenticate_user_failure(self, _mock_check_password_hash, mock_get_user, store): + mock_get_user.return_value = MagicMock(password_hash="hashed_password") + auth_result = store.authenticate_user("test_user", "password") + mock_get_user.assert_called_once() + assert mock_get_user.call_args[0][1] == "test_user" + assert auth_result is False + + @patch("mlflow_oidc_auth.sqlalchemy_store.generate_password_hash", return_value="hashed_password") + def test_create_user(self, _generate_password_hash, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + mock_session.flush = MagicMock() + mock_session.add = MagicMock() + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + + user = store.create_user("test_user", "password", "Test User") + + # display_name="Test User", is_admin=False) + mock_session.add.assert_called_once() + assert mock_session.add.call_args[0][0].username == "test_user" + assert mock_session.add.call_args[0][0].password_hash == "hashed_password" + assert mock_session.add.call_args[0][0].display_name == "Test User" + assert mock_session.add.call_args[0][0].is_admin is False + + mock_session.flush.assert_called_once() + assert user.username == "test_user" + assert user.display_name == "Test User" + assert user.is_admin is False + + @patch("mlflow_oidc_auth.sqlalchemy_store.generate_password_hash", return_value="hashed_password") + def test_create_admin_user(self, _generate_password_hash, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + mock_session.flush = MagicMock() + mock_session.add = MagicMock() + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + + admin_user = store.create_user("admin_user", "password", "Admin User", is_admin=True) + + assert mock_session.add.call_args[0][0].username == "admin_user" + assert mock_session.add.call_args[0][0].password_hash == "hashed_password" + assert mock_session.add.call_args[0][0].display_name == "Admin User" + assert mock_session.add.call_args[0][0].is_admin is True + assert admin_user.is_admin is True + + def test_create_user_existing(self, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + mock_session.flush = MagicMock() + mock_session.add = MagicMock(side_effect=IntegrityError("", {}, Exception)) + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + + with pytest.raises(MlflowException): + store.create_user("test_user", "password", "Test User") + + def test_get_user_not_found(self, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + mock_session.query = MagicMock(side_effect=NoResultFound("", {}, Exception)) + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + + with pytest.raises(MlflowException): + store.get_user("non_existent_user") + + @patch("mlflow_oidc_auth.sqlalchemy_store.generate_password_hash", return_value="hashed_password") + def test_update_user(self, _generate_password_hash, store): + retrieved_user = MagicMock() + retrieved_user.is_admin = PropertyMock() + retrieved_user.password_hash = PropertyMock() + store._get_user = MagicMock(return_value=retrieved_user) + store.update_user("test_user", password="new_password", is_admin=True) + assert retrieved_user.is_admin == True + assert retrieved_user.password_hash == "hashed_password" + + def test_delete_user(self, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + mock_session.flush = MagicMock() + mock_session.delete = MagicMock() + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + store._get_user = MagicMock(return_value=MagicMock()) + + store.delete_user("test_user") + mock_session.delete.assert_called_once() + mock_session.flush.assert_called_once() + + def test_create_experiment_permission_validates_permission(self, store): + with pytest.raises(MlflowException): + store.create_experiment_permission("1", "test_user", "INVALID_PERMISSION") + + @patch("mlflow_oidc_auth.sqlalchemy_store.SqlAlchemyStore._get_user") + def test_create_experiment_permission(self, mock_get_user, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + mock_session.flush = MagicMock() + mock_session.add = MagicMock + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + + mock_user = MagicMock() + mock_user.id = 1 + store._get_user.return_value = mock_user # = MagicMock(return_value=mock_user) + mock_get_user.return_value = mock_user + + permission = store.create_experiment_permission("1", "test_user", "READ") + assert permission.experiment_id == "1" + assert permission.permission == "READ" + assert permission.user_id == 1 + + def test_create_registered_model_permission_validates_permission(self, store): + with pytest.raises(MlflowException): + store.create_registered_model_permission("model", "test_user", "INVALID_PERMISSION") + + def test_create_registered_model_permission_fails_on_duplicate(self, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + mock_session.flush = MagicMock() + mock_session.add = MagicMock(side_effect=IntegrityError("", {}, Exception)) + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + with pytest.raises(MlflowException): + store.create_registered_model_permission("model", "test_user", "READ") + + def test_update_registered_model_permission_validates_permission(self, store): + with pytest.raises(MlflowException): + store.update_registered_model_permission("model", "test_user", "INVALID_PERMISSION") + + def test_update_registered_model_permission(self, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + + store._get_registered_model_permission = MagicMock( + return_value=SqlRegisteredModelPermission(name="model", user_id=1, permission=PropertyMock(return_value="READ")) + ) + + permission = store.update_registered_model_permission("model", "test_user", "EDIT") + assert permission.name == "model" + assert permission.permission == "EDIT" + assert permission.user_id == 1 + + def test_delete_registered_model_permission(self, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + mock_session.flush = MagicMock() + mock_session.delete = MagicMock() + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + store._get_registered_model_permission = MagicMock( + return_value=SqlRegisteredModelPermission(name="model", user_id=1, permission="READ") + ) + + store.delete_registered_model_permission("model", "test_user") + mock_session.delete.assert_called_once() + mock_session.flush.assert_called_once() + + def test_populate_groups_is_idempotent(self, store): + store.ManagedSessionMaker = MagicMock() + mock_session = MagicMock() + mock_session.add = MagicMock() + mock_session.query.return_value.filter.return_value.first.return_value = None + store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session + + store.populate_groups(["Group 1"]) + mock_session.add.assert_called() + + mock_session.add.reset_mock() + mock_session.query.return_value.filter.return_value.first.return_value = "Group 1" + store.populate_groups(["Group 1"]) + assert mock_session.add.call_count == 0 diff --git a/mlflow_oidc_auth/views/authentication.py b/mlflow_oidc_auth/views/authentication.py index 08ab3c0..70563fa 100644 --- a/mlflow_oidc_auth/views/authentication.py +++ b/mlflow_oidc_auth/views/authentication.py @@ -3,7 +3,7 @@ from flask import redirect, session, url_for import mlflow_oidc_auth.utils as utils -from mlflow_oidc_auth.auth import oauth +from mlflow_oidc_auth.auth import get_oauth_instance from mlflow_oidc_auth.app import app from mlflow_oidc_auth.config import config from mlflow_oidc_auth.user import create_user, populate_groups, update_user @@ -12,7 +12,7 @@ def login(): state = secrets.token_urlsafe(16) session["oauth_state"] = state - return oauth.oidc.authorize_redirect(config.OIDC_REDIRECT_URI, state=state) + return get_oauth_instance(app).oidc.authorize_redirect(config.OIDC_REDIRECT_URI, state=state) def logout(): @@ -26,7 +26,7 @@ def callback(): if "oauth_state" not in session or utils.get_request_param("state") != session["oauth_state"]: return "Invalid state parameter", 401 - token = oauth.oidc.authorize_access_token() + token = get_oauth_instance(app).oidc.authorize_access_token() app.logger.debug(f"Token: {token}") session["user"] = token["userinfo"] @@ -40,9 +40,7 @@ def callback(): if config.OIDC_GROUP_DETECTION_PLUGIN: import importlib - user_groups = importlib.import_module(config.OIDC_GROUP_DETECTION_PLUGIN).get_user_groups( - token["access_token"] - ) + user_groups = importlib.import_module(config.OIDC_GROUP_DETECTION_PLUGIN).get_user_groups(token["access_token"]) else: user_groups = token["userinfo"][config.OIDC_GROUPS_ATTRIBUTE] diff --git a/pyproject.toml b/pyproject.toml index 76e21df..82916dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,26 +23,20 @@ classifiers = [ requires-python = ">=3.8" dependencies = [ "cachelib<1", - "mlflow-skinny<3,>=2.11.1", + "mlflow<3,>=2.11.1", + "oauthlib<4", "python-dotenv<2", "requests<3,>=2.31.0", "sqlalchemy<3,>=1.4.0", "Flask<4", "Flask-Session>=0.7.0", - "gunicorn<24; platform_system != 'Windows'", - "alembic<2,!=1.10.0", - "Authlib<2", - "Flask-Caching<3" + "authlib>=1.3.2", + "flask-caching>=2.3.0" ] -[project.optional-dependencies] -full = ["mlflow<3,>=2.11.1"] -caching-redis = ["redis[hiredis]<6"] -dev = ["black<25", "pre-commit<5"] - [[project.maintainers]] -name = "Alexander Kharkevich" -email = "alexander_kharkevich@outlook.com" +name = "Data Platform folks" +email = "noreply@example.com" [project.license] file = "LICENSE" @@ -74,7 +68,18 @@ include = ["mlflow_oidc_auth", "mlflow_oidc_auth.*"] exclude = ["tests", "tests.*"] [tool.setuptools.dynamic] -version = { attr = "mlflow_oidc_auth.version" } +version = {attr = "mlflow_oidc_auth.version"} [tool.black] line-length = 128 + +[project.optional-dependencies] +dev = [ + "black==24.8.0", + "pytest==8.3.2", + "pre-commit==3.5.0", +] +test = [ + "pytest==8.3.2", + "pytest-cov==5.0.0", +] diff --git a/scripts/run-dev-server.sh b/scripts/run-dev-server.sh index 143028b..58671f5 100755 --- a/scripts/run-dev-server.sh +++ b/scripts/run-dev-server.sh @@ -17,6 +17,29 @@ python_preconfigure() { fi } +check_yarn_and_node_version() { + if ! command -v node &> /dev/null; then + echo "node is not installed. Please install node to continue." + exit 1 + fi + + if ! command -v yarn &> /dev/null; then + echo "yarn is not installed. Please install yarn to continue." + exit 1 + fi + + node_version=$(node --version) + + major=$(echo $node_version | cut -d. -f1 | tr -d 'v') + minor=$(echo $node_version | cut -d. -f2) + patch=$(echo $node_version | cut -d. -f3) + + if ! { [ "$major" -eq 14 ] && [ "$minor" -eq 15 ] && [ "$patch" -eq 0 ]; } && ! { [ "$major" -ge 16 ] && { [ "$minor" -ge 10 ] || [ "$major" -gt 16 ]; }; }; then + echo "Node version $node_version is not supported. Please install node version ^14.15.0 || >=16.10.0 to continue." + exit 1 + fi +} + ui_preconfigure() { if [ ! -d "web-ui/node_modules" ]; then pushd web-ui @@ -38,6 +61,7 @@ wait_server_ready() { return 1 } +check_yarn_and_node_version python_preconfigure source venv/bin/activate mlflow server --dev --app-name oidc-auth --host 0.0.0.0 --port 8080 & diff --git a/sonar-project.properties b/sonar-project.properties new file mode 100644 index 0000000..81b479f --- /dev/null +++ b/sonar-project.properties @@ -0,0 +1,7 @@ +sonar.projectKey=data-platform-hq_mlflow-oidc-auth +sonar.organization=data-platform-hq + +sonar.python.version=3.11 +sonar.python.coverage.reportPaths=coverage.xml +sonar.test.inclusions=**/test_*.py +sonar.coverage.exclusions==**/test_*.py,**/app.py,**/db/migrations/versions/**/*.*,**/sqlalchemy_store.py,**/views.py # TODO: review sqlalchemy_store.py and views.py diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..785448f --- /dev/null +++ b/tox.ini @@ -0,0 +1,17 @@ +[tox] +envlist = py311 +skipsdist = True + +[testenv] +deps = + pytest + coverage<=7.6 +commands = + pip install -e '.[test]' + coverage run -m pytest -s mlflow_oidc_auth/tests + coverage xml + +[coverage:run] +# relative_files = True +source = mlflow_oidc_auth/ +branch = True