diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
new file mode 100644
index 0000000..b3cf1a0
--- /dev/null
+++ b/.github/workflows/unit-tests.yml
@@ -0,0 +1,42 @@
+name: Unit tests
+on:
+ pull_request:
+ types:
+ - opened
+ - edited
+ - reopened
+ - synchronize
+jobs:
+ python-tests:
+ name: Run python tests
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7
+ with:
+ fetch-depth: 0
+ - name: Set up Python
+ uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # v5.1.1
+ with:
+ python-version: 3.11
+ - name: Install build dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install tox
+ tox -e py
+ - name: Build coverage file
+ run: |
+ tox -e py
+ pytest --cache-clear --junitxml=coverage.xml --cov-report=term-missing:skip-covered --cov=mlflow_oidc_auth > pytest-coverage.txt
+ - name: Override Coverage Source Path for Sonar
+ run: sed -i "s@@@g" /home/runner/work/mlflow-oidc-auth/mlflow-oidc-auth/coverage.xml
+ - name: debug cov
+ run: |
+ pwd
+ ls -alh
+ head -n50 coverage.xml
+ - name: SonarCloud Scan
+ uses: SonarSource/sonarcloud-github-action@e44258b109568baa0df60ed515909fc6c72cba92 # v2.3.0
+ env:
+ SONAR_TOKEN: ${{ secrets.SONAR_TOKEN }}
+ # todo: remove GH App
diff --git a/.gitignore b/.gitignore
index d214a4d..96ff8bc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -176,3 +176,4 @@ flask_session/
node_modules/
.angular/
mlflow_oidc_auth/ui
+pytest-coverage.txt
diff --git a/mlflow_oidc_auth/auth.py b/mlflow_oidc_auth/auth.py
index 1f1376f..2563226 100644
--- a/mlflow_oidc_auth/auth.py
+++ b/mlflow_oidc_auth/auth.py
@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Union, Optional
import requests
from authlib.integrations.flask_client import OAuth
@@ -6,21 +6,33 @@
from flask import Response, request
from werkzeug.datastructures import Authorization
-from mlflow_oidc_auth.app import app
from mlflow_oidc_auth.config import config
from mlflow_oidc_auth.store import store
-oauth = OAuth(app)
-oauth.register(
- name="oidc",
- client_id=config.OIDC_CLIENT_ID,
- client_secret=config.OIDC_CLIENT_SECRET,
- server_metadata_url=config.OIDC_DISCOVERY_URL,
- client_kwargs={"scope": config.OIDC_SCOPE},
-)
+
+_oauth_instance: Optional[OAuth] = None
+
+
+def get_oauth_instance(app) -> OAuth:
+ # returns a singleton instance of OAuth
+ # to avoid circular imports
+ global _oauth_instance
+
+ if _oauth_instance is None:
+ _oauth_instance = OAuth(app)
+ _oauth_instance.register(
+ name="oidc",
+ client_id=config.OIDC_CLIENT_ID,
+ client_secret=config.OIDC_CLIENT_SECRET,
+ server_metadata_url=config.OIDC_DISCOVERY_URL,
+ client_kwargs={"scope": config.OIDC_SCOPE},
+ )
+ return _oauth_instance
+
def _get_oidc_jwks():
- from mlflow_oidc_auth.app import cache
+ from mlflow_oidc_auth.app import cache, app
+
jwks = cache.get("jwks")
if jwks:
app.logger.debug("JWKS cache hit")
@@ -41,6 +53,8 @@ def validate_token(token):
def authenticate_request_basic_auth() -> Union[Authorization, Response]:
+ from mlflow_oidc_auth.app import app
+
username = request.authorization.username
password = request.authorization.password
app.logger.debug("Authenticating user %s", username)
@@ -53,6 +67,8 @@ def authenticate_request_basic_auth() -> Union[Authorization, Response]:
def authenticate_request_bearer_token() -> Union[Authorization, Response]:
+ from mlflow_oidc_auth.app import app
+
token = request.authorization.token
try:
user = validate_token(token)
diff --git a/mlflow_oidc_auth/sqlalchemy_store.py b/mlflow_oidc_auth/sqlalchemy_store.py
index 5f26067..ec690db 100644
--- a/mlflow_oidc_auth/sqlalchemy_store.py
+++ b/mlflow_oidc_auth/sqlalchemy_store.py
@@ -113,6 +113,7 @@ def delete_user(self, username: str):
with self.ManagedSessionMaker() as session:
user = self._get_user(session, username)
session.delete(user)
+ session.flush()
def create_experiment_permission(self, experiment_id: str, username: str, permission: str) -> ExperimentPermission:
_validate_permission(permission)
@@ -341,6 +342,7 @@ def delete_registered_model_permission(self, name: str, username: str):
with self.ManagedSessionMaker() as session:
perm = self._get_registered_model_permission(session, name, username)
session.delete(perm)
+ session.flush()
def list_experiment_permissions_for_experiment(self, experiment_id: str):
with self.ManagedSessionMaker() as session:
diff --git a/mlflow_oidc_auth/tests/__init__.py b/mlflow_oidc_auth/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mlflow_oidc_auth/tests/test_auth.py b/mlflow_oidc_auth/tests/test_auth.py
new file mode 100644
index 0000000..9e534d9
--- /dev/null
+++ b/mlflow_oidc_auth/tests/test_auth.py
@@ -0,0 +1,88 @@
+from unittest.mock import patch, MagicMock
+from mlflow_oidc_auth.auth import (
+ get_oauth_instance,
+ _get_oidc_jwks,
+ validate_token,
+ authenticate_request_basic_auth,
+ authenticate_request_bearer_token,
+)
+
+
+class TestAuth:
+ @patch("mlflow_oidc_auth.auth.OAuth")
+ @patch("mlflow_oidc_auth.auth.config")
+ def test_get_oauth_instance(self, mock_config, mock_oauth):
+ mock_app = MagicMock()
+ mock_oauth_instance = MagicMock()
+ mock_oauth.return_value = mock_oauth_instance
+
+ mock_config.OIDC_CLIENT_ID = "mock_client_id"
+ mock_config.OIDC_CLIENT_SECRET = "mock_client_secret"
+ mock_config.OIDC_DISCOVERY_URL = "mock_discovery_url"
+ mock_config.OIDC_SCOPE = "mock_scope"
+
+ result = get_oauth_instance(mock_app)
+
+ mock_oauth.assert_called_once_with(mock_app)
+ mock_oauth_instance.register.assert_called_once_with(
+ name="oidc",
+ client_id="mock_client_id",
+ client_secret="mock_client_secret",
+ server_metadata_url="mock_discovery_url",
+ client_kwargs={"scope": "mock_scope"},
+ )
+ assert result == mock_oauth_instance
+
+ @patch("mlflow_oidc_auth.auth._get_oidc_jwks")
+ @patch("mlflow_oidc_auth.auth.jwt.decode")
+ def test_validate_token(self, mock_jwt_decode, mock_get_oidc_jwks):
+ mock_jwks = {"keys": "mock_keys"}
+ mock_get_oidc_jwks.return_value = mock_jwks
+ mock_payload = MagicMock()
+ mock_jwt_decode.return_value = mock_payload
+
+ token = "mock_token"
+ result = validate_token(token)
+
+ mock_get_oidc_jwks.assert_called_once()
+ mock_jwt_decode.assert_called_once_with(token, mock_jwks)
+ mock_payload.validate.assert_called_once()
+ assert result == mock_payload
+
+ @patch("mlflow_oidc_auth.auth.store")
+ def test_authenticate_request_basic_auth_uses_authenticate_user(self, mock_store):
+ mock_request = MagicMock()
+ mock_request.authorization.username = "mock_username"
+ mock_request.authorization.password = "mock_password"
+ mock_store.authenticate_user.return_value = True
+
+ with patch("mlflow_oidc_auth.auth.request", mock_request):
+ # for some reason decorator doesn't mock flask
+ result = authenticate_request_basic_auth()
+
+ mock_store.authenticate_user.assert_called_once_with("mock_username", "mock_password")
+ assert result == True
+
+ @patch("mlflow_oidc_auth.auth.validate_token")
+ def test_authenticate_request_bearer_token_uses_validate_token(self, mock_validate_token):
+ mock_request = MagicMock()
+ mock_request.authorization.token = "mock_token"
+ mock_validate_token.return_value = MagicMock()
+ with patch("mlflow_oidc_auth.auth.request", mock_request):
+ # for some reason decorator doesn't mock flask
+ result = authenticate_request_bearer_token()
+
+ mock_validate_token.assert_called_once_with("mock_token")
+ assert result == True
+
+ @patch("mlflow_oidc_auth.auth.validate_token")
+ def test_authenticate_request_bearer_token_exception_returns_false(self, mock_validate_token):
+ mock_request = MagicMock()
+ mock_request.authorization.token = "mock_token"
+ mock_validate_token.side_effect = Exception()
+ with patch("mlflow_oidc_auth.auth.request", mock_request):
+ # for some reason decorator doesn't mock flask
+ result = authenticate_request_bearer_token()
+
+ mock_validate_token.assert_called_once_with("mock_token")
+ assert result == False
diff --git a/mlflow_oidc_auth/tests/test_client.py b/mlflow_oidc_auth/tests/test_client.py
new file mode 100644
index 0000000..5da985a
--- /dev/null
+++ b/mlflow_oidc_auth/tests/test_client.py
@@ -0,0 +1,272 @@
+import pytest
+from unittest.mock import patch
+from mlflow_oidc_auth.client import AuthServiceClient
+from mlflow_oidc_auth.routes import (
+ CREATE_EXPERIMENT_PERMISSION,
+ CREATE_REGISTERED_MODEL_PERMISSION,
+ CREATE_USER,
+ DELETE_EXPERIMENT_PERMISSION,
+ DELETE_REGISTERED_MODEL_PERMISSION,
+ DELETE_USER,
+ GET_EXPERIMENT_PERMISSION,
+ GET_REGISTERED_MODEL_PERMISSION,
+ GET_USER,
+ UPDATE_EXPERIMENT_PERMISSION,
+ UPDATE_REGISTERED_MODEL_PERMISSION,
+ UPDATE_USER_ADMIN,
+ UPDATE_USER_PASSWORD,
+)
+
+
+@pytest.fixture
+def client():
+ return AuthServiceClient(tracking_uri="http://test")
+
+
+class TestClientExceptions:
+ def test_create_user_no_password(self, client):
+ with pytest.raises(ValueError):
+ client.create_user("test_user", None)
+
+ def test_create_user_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.create_user(None, "password")
+
+ def test_get_user_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.get_user(None)
+
+ def test_update_user_password_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.update_user_password(None, "password")
+
+ def test_update_user_password_with_no_password(self, client):
+ with pytest.raises(ValueError):
+ client.update_user_password("test_user", None)
+
+ def test_update_user_admin_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.update_user_admin(None, True)
+
+ def test_update_user_admin_with_no_boolean(self, client):
+ # TODO: probably, use mypy to check typing
+ with pytest.raises(ValueError):
+ client.update_user_admin("test_user", "this is not boolean value")
+
+ def test_delete_user_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.delete_user(None)
+
+ def test_create_experiment_permission_with_no_experiment_id(self, client):
+ with pytest.raises(ValueError):
+ client.create_experiment_permission(None, "test_user", "READ")
+
+ def test_create_experiment_permission_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.create_experiment_permission("1", None, "READ")
+
+ def test_create_experiment_permission_with_invalid_permission(self, client):
+ with pytest.raises(ValueError):
+ client.create_experiment_permission("1", "test_user", "THIS_PERMISSION_DOES_NOT_EXIST")
+
+ def test_get_experiment_permission_with_no_experiment_id(self, client):
+ with pytest.raises(ValueError):
+ client.get_experiment_permission(None, "test_user")
+
+ def test_get_experiment_permission_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.get_experiment_permission("1", None)
+
+ def test_update_experiment_permission_with_no_experiment_id(self, client):
+ with pytest.raises(ValueError):
+ client.update_experiment_permission(None, "test_user", "READ")
+
+ def test_update_experiment_permission_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.update_experiment_permission("1", None, "READ")
+
+ def test_update_experiment_permission_with_invalid_permission(self, client):
+ with pytest.raises(ValueError):
+ client.update_experiment_permission("1", "test_user", "THIS_PERMISSION_DOES_NOT_EXIST")
+
+ def test_delete_experiment_permission_with_no_experiment_id(self, client):
+ with pytest.raises(ValueError):
+ client.delete_experiment_permission(None, "test_user")
+
+ def test_delete_experiment_permission_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.delete_experiment_permission("1", None)
+
+ def test_create_registered_model_permission_with_no_name(self, client):
+ with pytest.raises(ValueError):
+ client.create_registered_model_permission(None, "test_user", "READ")
+
+ def test_create_registered_model_permission_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.create_registered_model_permission("model", None, "READ")
+
+ def test_create_registered_model_permission_with_invalid_permission(self, client):
+ with pytest.raises(ValueError):
+ client.create_registered_model_permission("model", "test_user", "THIS_PERMISSION_DOES_NOT_EXIST")
+
+ def test_get_registered_model_permission_with_no_name(self, client):
+ with pytest.raises(ValueError):
+ client.get_registered_model_permission(None, "test_user")
+
+ def test_get_registered_model_permission_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.get_registered_model_permission("model", None)
+
+ def test_update_registered_model_permission_with_no_name(self, client):
+ with pytest.raises(ValueError):
+ client.update_registered_model_permission(None, "test_user", "READ")
+
+ def test_update_registered_model_permission_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.update_registered_model_permission("model", None, "READ")
+
+ def test_update_registered_model_permission_with_invalid_permission(self, client):
+ with pytest.raises(ValueError):
+ client.update_registered_model_permission("model", "test_user", "THIS_PERMISSION_DOES_NOT_EXIST")
+
+ def test_delete_registered_model_permission_with_no_name(self, client):
+ with pytest.raises(ValueError):
+ client.delete_registered_model_permission(None, "test_user")
+
+ def test_delete_registered_model_permission_with_no_username(self, client):
+ with pytest.raises(ValueError):
+ client.delete_registered_model_permission("model", None)
+
+
+@pytest.fixture(
+ scope="function",
+ autouse=True,
+)
+def mock_request():
+ with patch("mlflow_oidc_auth.client.AuthServiceClient._request") as mock:
+ yield mock
+
+
+class TestClient:
+ def test_create_user(self, mock_request, client):
+ mock_request.return_value = {
+ "user": {
+ "id": 123,
+ "username": "test_user",
+ "display_name": "Test User",
+ "is_admin": False,
+ "experiment_permissions": [],
+ "registered_model_permissions": [],
+ "groups": [],
+ }
+ }
+ user = client.create_user("test_user", "password")
+ assert user.username == "test_user"
+ mock_request.assert_called_once_with(CREATE_USER, "POST", json={"username": "test_user", "password": "password"})
+
+ def test_get_user(self, mock_request, client):
+ mock_request.return_value = {
+ "user": {
+ "id": 123,
+ "username": "test_user",
+ "display_name": "Test User",
+ "is_admin": False,
+ "experiment_permissions": [],
+ "registered_model_permissions": [],
+ "groups": [],
+ }
+ }
+ user = client.get_user("test_user")
+ assert user.username == "test_user"
+ mock_request.assert_called_once_with(GET_USER, "GET", params={"username": "test_user"})
+
+ def test_update_user_password(self, mock_request, client):
+ client.update_user_password("test_user", "new_password")
+ mock_request.assert_called_once_with(
+ UPDATE_USER_PASSWORD, "PATCH", json={"username": "test_user", "password": "new_password"}
+ )
+
+ def test_update_user_admin(self, mock_request, client):
+ mock_request.return_value = {
+ "user": {
+ "id": 123,
+ "username": "test_user",
+ "display_name": "Test User",
+ "is_admin": True,
+ "experiment_permissions": [],
+ "registered_model_permissions": [],
+ "groups": [],
+ }
+ }
+ client.update_user_admin("test_user", True)
+ mock_request.assert_called_once_with(UPDATE_USER_ADMIN, "PATCH", json={"username": "test_user", "is_admin": True})
+
+ def test_delete_user(self, mock_request, client):
+ client.delete_user("test_user")
+ mock_request.assert_called_once_with(DELETE_USER, "DELETE", json={"username": "test_user"})
+
+ def test_create_experiment_permission(self, mock_request, client):
+ mock_request.return_value = {"experiment_permission": {"experiment_id": 1, "permission": "READ", "user_id": 123}}
+ permission = client.create_experiment_permission("1", "test_user", "READ")
+ assert permission.permission == "READ"
+ assert permission.experiment_id == 1
+ assert permission.user_id == 123
+ mock_request.assert_called_once_with(
+ CREATE_EXPERIMENT_PERMISSION, "POST", json={"experiment_id": "1", "username": "test_user", "permission": "READ"}
+ )
+
+ def test_get_experiment_permission(self, mock_request, client):
+ mock_request.return_value = {"experiment_permission": {"experiment_id": 1, "permission": "READ", "user_id": 123}}
+ permission = client.get_experiment_permission("1", "test_user")
+ assert permission.permission == "READ"
+ assert permission.experiment_id == 1
+ assert permission.user_id == 123
+ mock_request.assert_called_once_with(
+ GET_EXPERIMENT_PERMISSION, "GET", params={"experiment_id": "1", "username": "test_user"}
+ )
+
+ def test_update_experiment_permission(self, mock_request, client):
+ mock_request.return_value = {"experiment_permission": {"experiment_id": 1, "permission": "READ", "user_id": 123}}
+ client.update_experiment_permission("1", "test_user", "READ")
+ mock_request.assert_called_once_with(
+ UPDATE_EXPERIMENT_PERMISSION, "PATCH", json={"experiment_id": "1", "username": "test_user", "permission": "READ"}
+ )
+
+ def test_delete_experiment_permission(self, mock_request, client):
+ client.delete_experiment_permission("1", "test_user")
+ mock_request.assert_called_once_with(
+ DELETE_EXPERIMENT_PERMISSION, "DELETE", json={"experiment_id": "1", "username": "test_user"}
+ )
+
+ def test_create_registered_model_permission(self, mock_request, client):
+ mock_request.return_value = {"registered_model_permission": {"name": "model", "permission": "READ", "user_id": 123}}
+ permission = client.create_registered_model_permission("model", "test_user", "READ")
+ assert permission.permission == "READ"
+ assert permission.name == "model"
+ assert permission.user_id == 123
+ mock_request.assert_called_once_with(
+ CREATE_REGISTERED_MODEL_PERMISSION, "POST", json={"name": "model", "username": "test_user", "permission": "READ"}
+ )
+
+ def test_get_registered_model_permission(self, mock_request, client):
+ mock_request.return_value = {"registered_model_permission": {"name": "model", "permission": "READ", "user_id": 123}}
+ permission = client.get_registered_model_permission("model", "test_user")
+ assert permission.permission == "READ"
+ assert permission.name == "model"
+ assert permission.user_id == 123
+ mock_request.assert_called_once_with(
+ GET_REGISTERED_MODEL_PERMISSION, "GET", params={"name": "model", "username": "test_user"}
+ )
+
+ def test_update_registered_model_permission(self, mock_request, client):
+ mock_request.return_value = {"registered_model_permission": {"name": "model", "permission": "READ", "user_id": 123}}
+ client.update_registered_model_permission("model", "test_user", "READ")
+ mock_request.assert_called_once_with(
+ UPDATE_REGISTERED_MODEL_PERMISSION, "PATCH", json={"name": "model", "username": "test_user", "permission": "READ"}
+ )
+
+ def test_delete_registered_model_permission(self, mock_request, client):
+ client.delete_registered_model_permission("model", "test_user")
+ mock_request.assert_called_once_with(
+ DELETE_REGISTERED_MODEL_PERMISSION, "DELETE", json={"name": "model", "username": "test_user"}
+ )
diff --git a/mlflow_oidc_auth/tests/test_db_utils.py b/mlflow_oidc_auth/tests/test_db_utils.py
new file mode 100644
index 0000000..d0e84b7
--- /dev/null
+++ b/mlflow_oidc_auth/tests/test_db_utils.py
@@ -0,0 +1,15 @@
+from unittest.mock import patch
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+from mlflow_oidc_auth.db.utils import migrate
+
+
+class TestMigrate:
+ @patch("mlflow_oidc_auth.db.utils.upgrade")
+ def test_migrate(self, mock_upgrade):
+ engine = create_engine("sqlite:///:memory:")
+ with sessionmaker(bind=engine)():
+ migrate(engine, "head")
+
+ mock_upgrade.assert_called_once()
diff --git a/mlflow_oidc_auth/tests/test_entities.py b/mlflow_oidc_auth/tests/test_entities.py
new file mode 100644
index 0000000..55a4bd1
--- /dev/null
+++ b/mlflow_oidc_auth/tests/test_entities.py
@@ -0,0 +1,56 @@
+import unittest
+from mlflow_oidc_auth.entities import User, ExperimentPermission, RegisteredModelPermission, Group
+
+
+class TestUser(unittest.TestCase):
+ def test_user_to_json(self):
+ user = User(
+ id_="123",
+ username="test_user",
+ password_hash="password",
+ is_admin=True,
+ display_name="Test User",
+ experiment_permissions=[ExperimentPermission("exp1", "read")],
+ registered_model_permissions=[RegisteredModelPermission("model1", "write")],
+ groups=[Group("group1", "Group 1")],
+ )
+
+ expected_json = {
+ "id": "123",
+ "username": "test_user",
+ "is_admin": True,
+ "display_name": "Test User",
+ "experiment_permissions": [{"experiment_id": "exp1", "permission": "read", "user_id": None, "group_id": None}],
+ "registered_model_permissions": [{"name": "model1", "permission": "write", "user_id": None, "group_id": None}],
+ "groups": [{"id": "group1", "group_name": "Group 1"}],
+ }
+
+ self.assertEqual(user.to_json(), expected_json)
+
+ def test_user_from_json(self):
+ json_data = {
+ "id": "123",
+ "username": "test_user",
+ "is_admin": True,
+ "display_name": "Test User",
+ "experiment_permissions": [{"experiment_id": "exp1", "permission": "read", "user_id": None, "group_id": None}],
+ "registered_model_permissions": [{"name": "model1", "permission": "write", "user_id": None, "group_id": None}],
+ "groups": [{"id": "group1", "group_name": "Group 1"}],
+ }
+
+ user = User.from_json(json_data)
+
+ self.assertEqual(user.id, "123")
+ self.assertEqual(user.username, "test_user")
+ self.assertEqual(user.password_hash, "REDACTED")
+ self.assertTrue(user.is_admin)
+ self.assertEqual(user.display_name, "Test User")
+ self.assertEqual(len(user.experiment_permissions), 1)
+ self.assertEqual(user.experiment_permissions[0].experiment_id, "exp1")
+ self.assertEqual(user.experiment_permissions[0].permission, "read")
+ self.assertEqual(len(user.registered_model_permissions), 1)
+ self.assertEqual(user.registered_model_permissions[0].name, "model1")
+ self.assertEqual(user.registered_model_permissions[0].permission, "write")
+ self.assertEqual(len(user.groups), 1)
+ self.assertEqual(user.groups[0].id, "group1")
+ self.assertEqual(user.groups[0].group_name, "Group 1")
diff --git a/mlflow_oidc_auth/tests/test_plugins_group_detecntion_msft_entra_id.py b/mlflow_oidc_auth/tests/test_plugins_group_detecntion_msft_entra_id.py
new file mode 100644
index 0000000..fbb40ba
--- /dev/null
+++ b/mlflow_oidc_auth/tests/test_plugins_group_detecntion_msft_entra_id.py
@@ -0,0 +1,31 @@
+import unittest
+from unittest.mock import Mock, patch
+from mlflow_oidc_auth.plugins.group_detection_microsoft_entra_id import get_user_groups
+
+
+class TestGetUserGroups(unittest.TestCase):
+ @patch("mlflow_oidc_auth.plugins.group_detection_microsoft_entra_id.requests.get")
+ def test_get_user_groups(self, mock_get):
+ mock_response = Mock()
+ mock_response.json.return_value = {
+ "value": [
+ {"displayName": "Group 1"},
+ {"displayName": "Group 2"},
+ {"displayName": "Group 3"},
+ ]
+ }
+ mock_get.return_value = mock_response
+
+ access_token = "D34DB33F"
+ groups = get_user_groups(access_token)
+
+ mock_get.assert_called_once_with(
+ "https://graph.microsoft.com/v1.0/me/memberOf",
+ headers={
+ "Authorization": f"Bearer {access_token}",
+ "Content-Type": "application/json",
+ },
+ )
+
+ expected_groups = ["Group 1", "Group 2", "Group 3"]
+ self.assertEqual(groups, expected_groups)
diff --git a/mlflow_oidc_auth/tests/test_routes.py b/mlflow_oidc_auth/tests/test_routes.py
new file mode 100644
index 0000000..df8e45c
--- /dev/null
+++ b/mlflow_oidc_auth/tests/test_routes.py
@@ -0,0 +1,55 @@
+from unittest import mock
+from mlflow_oidc_auth import routes
+
+"""
+`routes` contains mutiple routes definitions.
+This test is to ensure that all expected routes are present.
+"""
+
+
+class TestRoutes:
+ def test_routes_presented(self):
+ assert all(
+ route is not None
+ for route in [
+ routes.HOME,
+ routes.LOGIN,
+ routes.LOGOUT,
+ routes.CALLBACK,
+ routes.STATIC,
+ routes.UI,
+ routes.UI_ROOT,
+ routes.GET_ACCESS_TOKEN,
+ routes.GET_CURRENT_USER,
+ routes.GET_EXPERIMENTS,
+ routes.GET_MODELS,
+ routes.GET_USERS,
+ routes.GET_USER_EXPERIMENTS,
+ routes.GET_USER_MODELS,
+ routes.GET_EXPERIMENT_USERS,
+ routes.GET_MODEL_USERS,
+ routes.CREATE_USER,
+ routes.GET_USER,
+ routes.UPDATE_USER_PASSWORD,
+ routes.UPDATE_USER_ADMIN,
+ routes.DELETE_USER,
+ routes.CREATE_EXPERIMENT_PERMISSION,
+ routes.GET_EXPERIMENT_PERMISSION,
+ routes.UPDATE_EXPERIMENT_PERMISSION,
+ routes.DELETE_EXPERIMENT_PERMISSION,
+ routes.CREATE_REGISTERED_MODEL_PERMISSION,
+ routes.GET_REGISTERED_MODEL_PERMISSION,
+ routes.UPDATE_REGISTERED_MODEL_PERMISSION,
+ routes.DELETE_REGISTERED_MODEL_PERMISSION,
+ routes.GET_GROUPS,
+ routes.GET_GROUP_USERS,
+ routes.GET_GROUP_EXPERIMENTS_PERMISSION,
+ routes.CREATE_GROUP_EXPERIMENT_PERMISSION,
+ routes.DELETE_GROUP_EXPERIMENT_PERMISSION,
+ routes.UPDATE_GROUP_EXPERIMENT_PERMISSION,
+ routes.GET_GROUP_MODELS_PERMISSION,
+ routes.CREATE_GROUP_MODEL_PERMISSION,
+ routes.DELETE_GROUP_MODEL_PERMISSION,
+ routes.UPDATE_GROUP_MODEL_PERMISSION,
+ ]
+ )
diff --git a/mlflow_oidc_auth/tests/test_sqlalchemy_store.py b/mlflow_oidc_auth/tests/test_sqlalchemy_store.py
new file mode 100644
index 0000000..4fff432
--- /dev/null
+++ b/mlflow_oidc_auth/tests/test_sqlalchemy_store.py
@@ -0,0 +1,195 @@
+import pytest
+from unittest.mock import patch, MagicMock, PropertyMock
+from mlflow.exceptions import MlflowException
+from mlflow_oidc_auth.sqlalchemy_store import SqlAlchemyStore
+from sqlalchemy.exc import IntegrityError, NoResultFound
+from mlflow_oidc_auth.db.models import SqlRegisteredModelPermission
+
+
+@pytest.fixture
+@patch("mlflow_oidc_auth.sqlalchemy_store.dbutils.migrate_if_needed")
+def store(_mock_migrate_if_needed):
+ store = SqlAlchemyStore()
+ store.init_db("sqlite:///:memory:")
+ return store
+
+
+class TestSqlAlchemyStore:
+ @patch("mlflow_oidc_auth.sqlalchemy_store.SqlAlchemyStore._get_user")
+ @patch("mlflow_oidc_auth.sqlalchemy_store.check_password_hash", return_value=True)
+ def test_authenticate_user(self, _mock_check_password_hash, mock_get_user, store):
+ mock_get_user.return_value = MagicMock(password_hash="hashed_password")
+ auth_result = store.authenticate_user("test_user", "password")
+ mock_get_user.assert_called_once()
+ assert mock_get_user.call_args[0][1] == "test_user"
+ assert auth_result is True
+
+ @patch("mlflow_oidc_auth.sqlalchemy_store.SqlAlchemyStore._get_user")
+ @patch("mlflow_oidc_auth.sqlalchemy_store.check_password_hash", return_value=False)
+ def test_authenticate_user_failure(self, _mock_check_password_hash, mock_get_user, store):
+ mock_get_user.return_value = MagicMock(password_hash="hashed_password")
+ auth_result = store.authenticate_user("test_user", "password")
+ mock_get_user.assert_called_once()
+ assert mock_get_user.call_args[0][1] == "test_user"
+ assert auth_result is False
+
+ @patch("mlflow_oidc_auth.sqlalchemy_store.generate_password_hash", return_value="hashed_password")
+ def test_create_user(self, _generate_password_hash, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ mock_session.flush = MagicMock()
+ mock_session.add = MagicMock()
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+
+ user = store.create_user("test_user", "password", "Test User")
+
+ # display_name="Test User", is_admin=False)
+ mock_session.add.assert_called_once()
+ assert mock_session.add.call_args[0][0].username == "test_user"
+ assert mock_session.add.call_args[0][0].password_hash == "hashed_password"
+ assert mock_session.add.call_args[0][0].display_name == "Test User"
+ assert mock_session.add.call_args[0][0].is_admin is False
+
+ mock_session.flush.assert_called_once()
+ assert user.username == "test_user"
+ assert user.display_name == "Test User"
+ assert user.is_admin is False
+
+ @patch("mlflow_oidc_auth.sqlalchemy_store.generate_password_hash", return_value="hashed_password")
+ def test_create_admin_user(self, _generate_password_hash, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ mock_session.flush = MagicMock()
+ mock_session.add = MagicMock()
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+
+ admin_user = store.create_user("admin_user", "password", "Admin User", is_admin=True)
+
+ assert mock_session.add.call_args[0][0].username == "admin_user"
+ assert mock_session.add.call_args[0][0].password_hash == "hashed_password"
+ assert mock_session.add.call_args[0][0].display_name == "Admin User"
+ assert mock_session.add.call_args[0][0].is_admin is True
+ assert admin_user.is_admin is True
+
+ def test_create_user_existing(self, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ mock_session.flush = MagicMock()
+ mock_session.add = MagicMock(side_effect=IntegrityError("", {}, Exception))
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+
+ with pytest.raises(MlflowException):
+ store.create_user("test_user", "password", "Test User")
+
+ def test_get_user_not_found(self, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ mock_session.query = MagicMock(side_effect=NoResultFound("", {}, Exception))
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+
+ with pytest.raises(MlflowException):
+ store.get_user("non_existent_user")
+
+ @patch("mlflow_oidc_auth.sqlalchemy_store.generate_password_hash", return_value="hashed_password")
+ def test_update_user(self, _generate_password_hash, store):
+ retrieved_user = MagicMock()
+ retrieved_user.is_admin = PropertyMock()
+ retrieved_user.password_hash = PropertyMock()
+ store._get_user = MagicMock(return_value=retrieved_user)
+ store.update_user("test_user", password="new_password", is_admin=True)
+ assert retrieved_user.is_admin == True
+ assert retrieved_user.password_hash == "hashed_password"
+
+ def test_delete_user(self, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ mock_session.flush = MagicMock()
+ mock_session.delete = MagicMock()
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+ store._get_user = MagicMock(return_value=MagicMock())
+
+ store.delete_user("test_user")
+ mock_session.delete.assert_called_once()
+ mock_session.flush.assert_called_once()
+
+ def test_create_experiment_permission_validates_permission(self, store):
+ with pytest.raises(MlflowException):
+ store.create_experiment_permission("1", "test_user", "INVALID_PERMISSION")
+
+ @patch("mlflow_oidc_auth.sqlalchemy_store.SqlAlchemyStore._get_user")
+ def test_create_experiment_permission(self, mock_get_user, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ mock_session.flush = MagicMock()
+ mock_session.add = MagicMock
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+
+ mock_user = MagicMock()
+ mock_user.id = 1
+ store._get_user.return_value = mock_user # = MagicMock(return_value=mock_user)
+ mock_get_user.return_value = mock_user
+
+ permission = store.create_experiment_permission("1", "test_user", "READ")
+ assert permission.experiment_id == "1"
+ assert permission.permission == "READ"
+ assert permission.user_id == 1
+
+ def test_create_registered_model_permission_validates_permission(self, store):
+ with pytest.raises(MlflowException):
+ store.create_registered_model_permission("model", "test_user", "INVALID_PERMISSION")
+
+ def test_create_registered_model_permission_fails_on_duplicate(self, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ mock_session.flush = MagicMock()
+ mock_session.add = MagicMock(side_effect=IntegrityError("", {}, Exception))
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+ with pytest.raises(MlflowException):
+ store.create_registered_model_permission("model", "test_user", "READ")
+
+ def test_update_registered_model_permission_validates_permission(self, store):
+ with pytest.raises(MlflowException):
+ store.update_registered_model_permission("model", "test_user", "INVALID_PERMISSION")
+
+ def test_update_registered_model_permission(self, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+
+ store._get_registered_model_permission = MagicMock(
+ return_value=SqlRegisteredModelPermission(name="model", user_id=1, permission=PropertyMock(return_value="READ"))
+ )
+
+ permission = store.update_registered_model_permission("model", "test_user", "EDIT")
+ assert permission.name == "model"
+ assert permission.permission == "EDIT"
+ assert permission.user_id == 1
+
+ def test_delete_registered_model_permission(self, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ mock_session.flush = MagicMock()
+ mock_session.delete = MagicMock()
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+ store._get_registered_model_permission = MagicMock(
+ return_value=SqlRegisteredModelPermission(name="model", user_id=1, permission="READ")
+ )
+
+ store.delete_registered_model_permission("model", "test_user")
+ mock_session.delete.assert_called_once()
+ mock_session.flush.assert_called_once()
+
+ def test_populate_groups_is_idempotent(self, store):
+ store.ManagedSessionMaker = MagicMock()
+ mock_session = MagicMock()
+ mock_session.add = MagicMock()
+ mock_session.query.return_value.filter.return_value.first.return_value = None
+ store.ManagedSessionMaker.return_value.__enter__.return_value = mock_session
+
+ store.populate_groups(["Group 1"])
+ mock_session.add.assert_called()
+
+ mock_session.add.reset_mock()
+ mock_session.query.return_value.filter.return_value.first.return_value = "Group 1"
+ store.populate_groups(["Group 1"])
+ assert mock_session.add.call_count == 0
diff --git a/mlflow_oidc_auth/views/authentication.py b/mlflow_oidc_auth/views/authentication.py
index 08ab3c0..70563fa 100644
--- a/mlflow_oidc_auth/views/authentication.py
+++ b/mlflow_oidc_auth/views/authentication.py
@@ -3,7 +3,7 @@
from flask import redirect, session, url_for
import mlflow_oidc_auth.utils as utils
-from mlflow_oidc_auth.auth import oauth
+from mlflow_oidc_auth.auth import get_oauth_instance
from mlflow_oidc_auth.app import app
from mlflow_oidc_auth.config import config
from mlflow_oidc_auth.user import create_user, populate_groups, update_user
@@ -12,7 +12,7 @@
def login():
state = secrets.token_urlsafe(16)
session["oauth_state"] = state
- return oauth.oidc.authorize_redirect(config.OIDC_REDIRECT_URI, state=state)
+ return get_oauth_instance(app).oidc.authorize_redirect(config.OIDC_REDIRECT_URI, state=state)
def logout():
@@ -26,7 +26,7 @@ def callback():
if "oauth_state" not in session or utils.get_request_param("state") != session["oauth_state"]:
return "Invalid state parameter", 401
- token = oauth.oidc.authorize_access_token()
+ token = get_oauth_instance(app).oidc.authorize_access_token()
app.logger.debug(f"Token: {token}")
session["user"] = token["userinfo"]
@@ -40,9 +40,7 @@ def callback():
if config.OIDC_GROUP_DETECTION_PLUGIN:
import importlib
- user_groups = importlib.import_module(config.OIDC_GROUP_DETECTION_PLUGIN).get_user_groups(
- token["access_token"]
- )
+ user_groups = importlib.import_module(config.OIDC_GROUP_DETECTION_PLUGIN).get_user_groups(token["access_token"])
else:
user_groups = token["userinfo"][config.OIDC_GROUPS_ATTRIBUTE]
diff --git a/pyproject.toml b/pyproject.toml
index 76e21df..82916dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,26 +23,20 @@ classifiers = [
requires-python = ">=3.8"
dependencies = [
"cachelib<1",
- "mlflow-skinny<3,>=2.11.1",
+ "mlflow<3,>=2.11.1",
+ "oauthlib<4",
"python-dotenv<2",
"requests<3,>=2.31.0",
"sqlalchemy<3,>=1.4.0",
"Flask<4",
"Flask-Session>=0.7.0",
- "gunicorn<24; platform_system != 'Windows'",
- "alembic<2,!=1.10.0",
- "Authlib<2",
- "Flask-Caching<3"
+ "authlib>=1.3.2",
+ "flask-caching>=2.3.0"
]
-[project.optional-dependencies]
-full = ["mlflow<3,>=2.11.1"]
-caching-redis = ["redis[hiredis]<6"]
-dev = ["black<25", "pre-commit<5"]
-
[[project.maintainers]]
-name = "Alexander Kharkevich"
-email = "alexander_kharkevich@outlook.com"
+name = "Data Platform folks"
+email = "noreply@example.com"
[project.license]
file = "LICENSE"
@@ -74,7 +68,18 @@ include = ["mlflow_oidc_auth", "mlflow_oidc_auth.*"]
exclude = ["tests", "tests.*"]
[tool.setuptools.dynamic]
-version = { attr = "mlflow_oidc_auth.version" }
+version = {attr = "mlflow_oidc_auth.version"}
[tool.black]
line-length = 128
+
+[project.optional-dependencies]
+dev = [
+ "black==24.8.0",
+ "pytest==8.3.2",
+ "pre-commit==3.5.0",
+]
+test = [
+ "pytest==8.3.2",
+ "pytest-cov==5.0.0",
+]
diff --git a/scripts/run-dev-server.sh b/scripts/run-dev-server.sh
index 143028b..58671f5 100755
--- a/scripts/run-dev-server.sh
+++ b/scripts/run-dev-server.sh
@@ -17,6 +17,29 @@ python_preconfigure() {
fi
}
+check_yarn_and_node_version() {
+ if ! command -v node &> /dev/null; then
+ echo "node is not installed. Please install node to continue."
+ exit 1
+ fi
+
+ if ! command -v yarn &> /dev/null; then
+ echo "yarn is not installed. Please install yarn to continue."
+ exit 1
+ fi
+
+ node_version=$(node --version)
+
+ major=$(echo $node_version | cut -d. -f1 | tr -d 'v')
+ minor=$(echo $node_version | cut -d. -f2)
+ patch=$(echo $node_version | cut -d. -f3)
+
+ if ! { [ "$major" -eq 14 ] && [ "$minor" -eq 15 ] && [ "$patch" -eq 0 ]; } && ! { [ "$major" -ge 16 ] && { [ "$minor" -ge 10 ] || [ "$major" -gt 16 ]; }; }; then
+ echo "Node version $node_version is not supported. Please install node version ^14.15.0 || >=16.10.0 to continue."
+ exit 1
+ fi
+}
+
ui_preconfigure() {
if [ ! -d "web-ui/node_modules" ]; then
pushd web-ui
@@ -38,6 +61,7 @@ wait_server_ready() {
return 1
}
+check_yarn_and_node_version
python_preconfigure
source venv/bin/activate
mlflow server --dev --app-name oidc-auth --host 0.0.0.0 --port 8080 &
diff --git a/sonar-project.properties b/sonar-project.properties
new file mode 100644
index 0000000..81b479f
--- /dev/null
+++ b/sonar-project.properties
@@ -0,0 +1,7 @@
+sonar.projectKey=data-platform-hq_mlflow-oidc-auth
+sonar.organization=data-platform-hq
+
+sonar.python.version=3.11
+sonar.python.coverage.reportPaths=coverage.xml
+sonar.test.inclusions=**/test_*.py
+sonar.coverage.exclusions==**/test_*.py,**/app.py,**/db/migrations/versions/**/*.*,**/sqlalchemy_store.py,**/views.py # TODO: review sqlalchemy_store.py and views.py
diff --git a/tox.ini b/tox.ini
new file mode 100644
index 0000000..785448f
--- /dev/null
+++ b/tox.ini
@@ -0,0 +1,17 @@
+[tox]
+envlist = py311
+skipsdist = True
+
+[testenv]
+deps =
+ pytest
+ coverage<=7.6
+commands =
+ pip install -e '.[test]'
+ coverage run -m pytest -s mlflow_oidc_auth/tests
+ coverage xml
+
+[coverage:run]
+# relative_files = True
+source = mlflow_oidc_auth/
+branch = True