From 21595f4e7824ba86529a91156d3da80e33aeac2b Mon Sep 17 00:00:00 2001 From: Alexander Doroshevich Date: Fri, 5 Apr 2024 19:20:10 -0700 Subject: [PATCH 01/13] feat: add validators and code refactoring --- mlflow_oidc_auth/app.py | 8 +- mlflow_oidc_auth/config.py | 1 + mlflow_oidc_auth/views.py | 351 +++++++++++++++++++++++++++---------- 3 files changed, 261 insertions(+), 99 deletions(-) diff --git a/mlflow_oidc_auth/app.py b/mlflow_oidc_auth/app.py index c4c7ab4..df409ee 100644 --- a/mlflow_oidc_auth/app.py +++ b/mlflow_oidc_auth/app.py @@ -49,10 +49,10 @@ app.add_url_rule(rule=routes.GET_EXPERIMENT_PERMISSION, methods=["GET"], view_func=views.get_experiment_permission) app.add_url_rule(rule=routes.UPDATE_EXPERIMENT_PERMISSION, methods=["POST"], view_func=views.update_experiment_permission) app.add_url_rule(rule=routes.DELETE_EXPERIMENT_PERMISSION, methods=["POST"], view_func=views.delete_experiment_permission) -app.add_url_rule(rule=routes.CREATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.create_model_permission) -app.add_url_rule(rule=routes.GET_REGISTERED_MODEL_PERMISSION, methods=["GET"], view_func=views.get_model_permission) -app.add_url_rule(rule=routes.UPDATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.update_model_permission) -app.add_url_rule(rule=routes.DELETE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.delete_model_permission) +app.add_url_rule(rule=routes.CREATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.create_registered_model_permission) +app.add_url_rule(rule=routes.GET_REGISTERED_MODEL_PERMISSION, methods=["GET"], view_func=views.get_registered_model_permission) +app.add_url_rule(rule=routes.UPDATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.update_registered_model_permission) +app.add_url_rule(rule=routes.DELETE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.delete_registered_model_permission) # Add new hooks app.before_request(views.before_request_hook) diff --git a/mlflow_oidc_auth/config.py b/mlflow_oidc_auth/config.py index de8657d..3c4502e 100644 --- a/mlflow_oidc_auth/config.py +++ b/mlflow_oidc_auth/config.py @@ -10,6 +10,7 @@ class AppConfig: + DEFAULT_MLFLOW_PERMISSION = os.environ.get("DEFAULT_MLFLOW_PERMISSION", "MANAGE") SECRET_KEY = os.environ.get("SECRET_KEY", secrets.token_hex(16)) SESSION_TYPE = "cachelib" LEVEL = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index 9e78d1e..bbe5969 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -5,7 +5,7 @@ import string from typing import Callable, Union -from sqlalchemy import text +from mlflow.utils.proto_json_utils import parse_dict from werkzeug.datastructures import Authorization from flask import ( @@ -27,10 +27,46 @@ INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST, ) + +from mlflow.protos.model_registry_pb2 import ( + CreateModelVersion, + CreateRegisteredModel, + DeleteModelVersion, + DeleteModelVersionTag, + DeleteRegisteredModel, + DeleteRegisteredModelAlias, + DeleteRegisteredModelTag, + GetLatestVersions, + GetModelVersion, + GetModelVersionByAlias, + GetModelVersionDownloadUri, + GetRegisteredModel, + RenameRegisteredModel, + SearchRegisteredModels, + SetModelVersionTag, + SetRegisteredModelAlias, + SetRegisteredModelTag, + TransitionModelVersionStage, + UpdateModelVersion, + UpdateRegisteredModel, +) + from mlflow.protos.service_pb2 import ( CreateExperiment, + CreateRun, + DeleteExperiment, + DeleteRun, + DeleteTag, + GetExperiment, + GetExperimentByName, + RestoreExperiment, + SearchExperiments, + SetExperimentTag, + SetTag, + UpdateExperiment, + UpdateRun, ) -from mlflow_oidc_auth.permissions import Permission, get_permission +from mlflow_oidc_auth.permissions import Permission, get_permission, MANAGE from mlflow.server.handlers import ( catch_mlflow_exception, get_endpoints, @@ -139,12 +175,97 @@ def _get_is_admin(): def _get_permission_from_experiment_id() -> Permission: experiment_id = _get_request_param("experiment_id") username = _get_username() - return _get_permission_from_store_or_default(lambda: store.get_experiment_permission(experiment_id, username).permission) + return _get_permission_from_store_or_default( + lambda: store.get_experiment_permission(experiment_id, username).permission) + + +def _get_permission_from_experiment_name() -> Permission: + experiment_name = _get_request_param("experiment_name") + store_exp = mlflow_client.get_experiment_by_name(experiment_name) + if store_exp is None: + raise MlflowException( + f"Could not find experiment with name {experiment_name}", + error_code=RESOURCE_DOES_NOT_EXIST, + ) + username = _get_username() + return _get_permission_from_store_or_default( + lambda: store.get_experiment_permission(store_exp.experiment_id, username).permission + ) + + +def _get_permission_from_registered_model_name() -> Permission: + name = _get_request_param("name") + username = _get_username() + return _get_permission_from_store_or_default( + lambda: store.get_registered_model_permission(name, username).permission + ) + + +def _set_can_manage_experiment_permission(resp: Response): + response_message = CreateExperiment.Response() + parse_dict(resp.json, response_message) + experiment_id = response_message.experiment_id + username = _get_username() + store.create_experiment_permission(experiment_id, username, MANAGE.name) + + +def _set_can_manage_registered_model_permission(resp: Response): + response_message = CreateRegisteredModel.Response() + parse_dict(resp.json, response_message) + name = response_message.registered_model.name + username = _get_username() + store.create_registered_model_permission(name, username, MANAGE.name) + + +def delete_can_manage_registered_model_permission(resp: Response): + """ + Delete registered model permission when the model is deleted. + + We need to do this because the primary key of the registered model is the name, + unlike the experiment where the primary key is experiment_id (UUID). Therefore, + we have to delete the permission record when the model is deleted otherwise it + conflicts with the new model registered with the same name. + """ + # Get model name from request context because it's not available in the response + name = request.get_json(force=True, silent=True)["name"] + username = _get_username() + store.delete_registered_model_permission(name, username) def _validate_can_manage_experiment(): - # return _get_permission_from_experiment_id().can_manage - return True + return _get_permission_from_experiment_id().can_manage + + +def _validate_can_manage_registered_model(): + return _get_permission_from_registered_model_name().can_manage + + +def _validate_can_read_experiment(): + return _get_permission_from_experiment_id().can_read + + +def _validate_can_read_experiment_by_name(): + return _get_permission_from_experiment_name().can_read + + +def _validate_can_update_experiment(): + return _get_permission_from_experiment_id().can_update + + +def _validate_can_delete_experiment(): + return _get_permission_from_experiment_id().can_delete + + +def _validate_can_read_registered_model(): + return _get_permission_from_registered_model_name().can_read + + +def _validate_can_update_registered_model(): + return _get_permission_from_registered_model_name().can_update + + +def _validate_can_delete_registered_model(): + return _get_permission_from_registered_model_name().can_delete def _get_before_request_handler(request_class): @@ -154,6 +275,31 @@ def _get_before_request_handler(request_class): BEFORE_REQUEST_HANDLERS = { # Routes for experiments CreateExperiment: _validate_can_manage_experiment, + GetExperiment: _validate_can_read_experiment, + GetExperimentByName: _validate_can_read_experiment_by_name, + DeleteExperiment: _validate_can_delete_experiment, + RestoreExperiment: _validate_can_delete_experiment, + UpdateExperiment: _validate_can_update_experiment, + SetExperimentTag: _validate_can_update_experiment, + # Routes for model registry + GetRegisteredModel: _validate_can_read_registered_model, + DeleteRegisteredModel: _validate_can_delete_registered_model, + UpdateRegisteredModel: _validate_can_update_registered_model, + RenameRegisteredModel: _validate_can_update_registered_model, + GetLatestVersions: _validate_can_read_registered_model, + CreateModelVersion: _validate_can_update_registered_model, + GetModelVersion: _validate_can_read_registered_model, + DeleteModelVersion: _validate_can_delete_registered_model, + UpdateModelVersion: _validate_can_update_registered_model, + TransitionModelVersionStage: _validate_can_update_registered_model, + GetModelVersionDownloadUri: _validate_can_read_registered_model, + SetRegisteredModelTag: _validate_can_update_registered_model, + DeleteRegisteredModelTag: _validate_can_update_registered_model, + SetModelVersionTag: _validate_can_update_registered_model, + DeleteModelVersionTag: _validate_can_delete_registered_model, + SetRegisteredModelAlias: _validate_can_update_registered_model, + DeleteRegisteredModelAlias: _validate_can_delete_registered_model, + GetModelVersionByAlias: _validate_can_read_registered_model, } BEFORE_REQUEST_VALIDATORS = { @@ -165,9 +311,34 @@ def _get_before_request_handler(request_class): BEFORE_REQUEST_VALIDATORS.update( { (routes.CREATE_EXPERIMENT_PERMISSION, "GET"): _validate_can_manage_experiment, + (routes.GET_EXPERIMENT_PERMISSION, "GET"): _validate_can_manage_experiment, + (routes.CREATE_EXPERIMENT_PERMISSION, "POST"): _validate_can_manage_experiment, + (routes.UPDATE_EXPERIMENT_PERMISSION, "PATCH"): _validate_can_manage_experiment, + (routes.DELETE_EXPERIMENT_PERMISSION, "DELETE"): _validate_can_manage_experiment, + (routes.GET_REGISTERED_MODEL_PERMISSION, "GET"): _validate_can_manage_registered_model, + (routes.CREATE_REGISTERED_MODEL_PERMISSION, "POST"): _validate_can_manage_registered_model, + (routes.UPDATE_REGISTERED_MODEL_PERMISSION, "PATCH"): _validate_can_manage_registered_model, + (routes.DELETE_REGISTERED_MODEL_PERMISSION, "DELETE"): _validate_can_manage_registered_model, } ) +AFTER_REQUEST_PATH_HANDLERS = { + CreateExperiment: _set_can_manage_experiment_permission, + CreateRegisteredModel: _set_can_manage_registered_model_permission, + DeleteRegisteredModel: delete_can_manage_registered_model_permission, +} + + +def get_after_request_handler(request_class): + return AFTER_REQUEST_PATH_HANDLERS.get(request_class) + + +AFTER_REQUEST_HANDLERS = { + (http_path, method): handler + for http_path, handler, methods in get_endpoints(get_after_request_handler) + for method in methods +} + def before_request_hook(): """Called before each request. If it did not return a response, @@ -206,18 +377,13 @@ def make_basic_auth_response() -> Response: return res +@catch_mlflow_exception def create_experiment_permission(): - request_data = request.get_json() - # Get the experiment - experiment = mlflow_client.get_experiment_by_name(request_data.get("experiment_name")) - - # # Update the experiment - store.create_experiment_permission( - experiment.experiment_id, - request_data.get("user_name"), - request_data.get("new_permission"), - ) - return "Experiment permission has been created." + experiment_id = _get_request_param("experiment_id") + username = _get_username() + permission = _get_request_param("permission") + ep = store.create_experiment_permission(experiment_id, username, permission) + return jsonify({"experiment_permission": ep.to_json()}) # Experiment views @@ -462,39 +628,36 @@ def get_user_models(username): def get_experiment_users(experiment_id): - # experiment_permissions is table name for experiments - # users is a table for users - with store.ManagedSessionMaker() as session: - query = text( - """ - SELECT users.username, experiment_permissions.permission - FROM users - JOIN experiment_permissions ON users.id = experiment_permissions.user_id - WHERE experiment_permissions.experiment_id = :experiment_id - """ - ) - results = session.execute(query, {"experiment_id": experiment_id}) - users_permissions = [{"username": row[0], "permission": row[1]} for row in results] + # Convert experiment_id to string for comparison + experiment_id = str(experiment_id) + + # Get the list of all users + list_users = store.list_users() - return jsonify(users_permissions) + # Filter users who are associated with the given experiment + usernames = [] + for user in list_users: + # Check if the user is associated with the experiment + user_experiments = [str(exp.experiment_id) for exp in user.experiment_permissions] + if experiment_id in user_experiments: + usernames.append(user.username) + + return jsonify({"usernames": usernames}) def get_model_users(model_name): - # registered_model_permissions is table name for models - # users is a table for users - with store.ManagedSessionMaker() as session: - query = text( - """ - SELECT users.username, registered_model_permissions.permission - FROM users - JOIN registered_model_permissions ON users.id = registered_model_permissions.user_id - WHERE registered_model_permissions.name = :model_name - """ - ) - results = session.execute(query, {"model_name": model_name}) - models_permissions = [{"username": row[0], "permission": row[1]} for row in results] + # Get the list of all users + list_users = store.list_users() + + # Filter users who are associated with the given model + usernames = [] + for user in list_users: + # Check if the user is associated with the model + user_models = [model.name for model in user.registered_model_permissions] + if model_name in user_models: + usernames.append(user.username) - return jsonify(models_permissions) + return jsonify({"usernames": usernames}) def _password_generation(): @@ -503,70 +666,68 @@ def _password_generation(): return new_password +@catch_mlflow_exception def update_experiment_permission(): - request_data = request.get_json() - # Get the experiment - experiment = mlflow_client.get_experiment_by_name(request_data.get("experiment_name")) - - # # Update the experiment - store.update_experiment_permission( - experiment.experiment_id, - request_data.get("user_name"), - request_data.get("new_permission"), - ) - return "Experiment permission has been changed." + experiment_id = _get_request_param("experiment_id") + username = _get_username() + permission = _get_request_param("permission") + store.update_experiment_permission(experiment_id, username, permission) + return make_response("Experiment permission has been updated") +@catch_mlflow_exception def delete_experiment_permission(): - request_data = request.get_json() - # Get the experiment - experiment = mlflow_client.get_experiment_by_name(request_data.get("experiment_name")) - - # # Update the experiment - store.delete_experiment_permission( - experiment.experiment_id, - request_data.get("user_name"), - ) - return "Experiment permission has been deleted." + experiment_id = _get_request_param("experiment_id") + username = _get_username() + store.delete_experiment_permission(experiment_id, username) + return make_response("Experiment permission has been deleted") -def create_model_permission(): - request_data = request.get_json() +@catch_mlflow_exception +def create_registered_model_permission(): + name = _get_request_param("name") + username = _get_username() + permission = _get_request_param("permission") + rmp = store.create_registered_model_permission(name, username, permission) + return make_response({"registered_model_permission": rmp.to_json()}) - store.create_registered_model_permission( - request_data.get("model_name"), - request_data.get("user_name"), - request_data.get("new_permission"), - ) - return "Model permission has been created." +@catch_mlflow_exception +def get_registered_model_permission(): + name = _get_request_param("name") + username = _get_username() + rmp = store.get_registered_model_permission(name, username) + return make_response({"registered_model_permission": rmp.to_json()}) -def get_model_permission(): - request_data = request.get_json() - permission = store.get_registered_model_permission( - request_data.get("model_name"), - request_data.get("user_name"), - ) - return jsonify({"model_permission": permission.to_json()}) +@catch_mlflow_exception +def update_registered_model_permission(): + name = _get_request_param("name") + username = _get_username() + permission = _get_request_param("permission") + store.update_registered_model_permission(name, username, permission) + return make_response("Model permission has been changed") -def update_model_permission(): - request_data = request.get_json() +@catch_mlflow_exception +def delete_registered_model_permission(): + name = _get_request_param("name") + username = _get_username() + store.delete_registered_model_permission(name, username) + return make_response("Model permission has been deleted") - store.update_registered_model_permission( - request_data.get("model_name"), - request_data.get("user_name"), - request_data.get("new_permission"), - ) - return "Model permission has been changed." +def set_can_manage_experiment_permission(resp: Response): + response_message = CreateExperiment.Response() + parse_dict(resp.json, response_message) + experiment_id = response_message.experiment_id + username = _get_username() + store.create_experiment_permission(experiment_id, username, MANAGE.name) -def delete_model_permission(): - request_data = request.get_json() - store.delete_registered_model_permission( - request_data.get("model_name"), - request_data.get("user_name"), - ) - return "Model permission has been deleted." +def set_can_manage_registered_model_permission(resp: Response): + response_message = CreateRegisteredModel.Response() + parse_dict(resp.json, response_message) + name = response_message.registered_model.name + username = _get_username() + store.create_registered_model_permission(name, username, MANAGE.name) From eceb5398620baa80ecd3a8aeaf11bac6df5f93df Mon Sep 17 00:00:00 2001 From: Alexander Doroshevich Date: Fri, 5 Apr 2024 22:01:37 -0700 Subject: [PATCH 02/13] fix: updated user receiving --- mlflow_oidc_auth/views.py | 64 +++++++++++++++++++-------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index bbe5969..15ef85f 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -174,9 +174,9 @@ def _get_is_admin(): def _get_permission_from_experiment_id() -> Permission: experiment_id = _get_request_param("experiment_id") - username = _get_username() + user = store.get_user(_get_username()) return _get_permission_from_store_or_default( - lambda: store.get_experiment_permission(experiment_id, username).permission) + lambda: store.get_experiment_permission(experiment_id, user.username).permission) def _get_permission_from_experiment_name() -> Permission: @@ -187,17 +187,17 @@ def _get_permission_from_experiment_name() -> Permission: f"Could not find experiment with name {experiment_name}", error_code=RESOURCE_DOES_NOT_EXIST, ) - username = _get_username() + user = store.get_user(_get_username()) return _get_permission_from_store_or_default( - lambda: store.get_experiment_permission(store_exp.experiment_id, username).permission + lambda: store.get_experiment_permission(store_exp.experiment_id, user.username).permission ) def _get_permission_from_registered_model_name() -> Permission: name = _get_request_param("name") - username = _get_username() + user = store.get_user(_get_username()) return _get_permission_from_store_or_default( - lambda: store.get_registered_model_permission(name, username).permission + lambda: store.get_registered_model_permission(name, user.username).permission ) @@ -205,16 +205,16 @@ def _set_can_manage_experiment_permission(resp: Response): response_message = CreateExperiment.Response() parse_dict(resp.json, response_message) experiment_id = response_message.experiment_id - username = _get_username() - store.create_experiment_permission(experiment_id, username, MANAGE.name) + user = store.get_user(_get_username()) + store.create_experiment_permission(experiment_id, user.username, MANAGE.name) def _set_can_manage_registered_model_permission(resp: Response): response_message = CreateRegisteredModel.Response() parse_dict(resp.json, response_message) name = response_message.registered_model.name - username = _get_username() - store.create_registered_model_permission(name, username, MANAGE.name) + user = store.get_user(_get_username()) + store.create_registered_model_permission(name, user.username, MANAGE.name) def delete_can_manage_registered_model_permission(resp: Response): @@ -228,8 +228,8 @@ def delete_can_manage_registered_model_permission(resp: Response): """ # Get model name from request context because it's not available in the response name = request.get_json(force=True, silent=True)["name"] - username = _get_username() - store.delete_registered_model_permission(name, username) + user = store.get_user(_get_username()) + store.delete_registered_model_permission(name, user.username) def _validate_can_manage_experiment(): @@ -380,9 +380,9 @@ def make_basic_auth_response() -> Response: @catch_mlflow_exception def create_experiment_permission(): experiment_id = _get_request_param("experiment_id") - username = _get_username() + user = store.get_user(_get_username()) permission = _get_request_param("permission") - ep = store.create_experiment_permission(experiment_id, username, permission) + ep = store.create_experiment_permission(experiment_id, user.username, permission) return jsonify({"experiment_permission": ep.to_json()}) @@ -390,8 +390,8 @@ def create_experiment_permission(): @catch_mlflow_exception def get_experiment_permission(): experiment_id = _get_request_param("experiment_id") - username = _get_username() - ep = store.get_experiment_permission(experiment_id, username) + user = store.get_user(_get_username()) + ep = store.get_experiment_permission(experiment_id, user.username) return make_response({"experiment_permission": ep.to_json()}) @@ -669,51 +669,51 @@ def _password_generation(): @catch_mlflow_exception def update_experiment_permission(): experiment_id = _get_request_param("experiment_id") - username = _get_username() + user = store.get_user(_get_username()) permission = _get_request_param("permission") - store.update_experiment_permission(experiment_id, username, permission) + store.update_experiment_permission(experiment_id, user.username, permission) return make_response("Experiment permission has been updated") @catch_mlflow_exception def delete_experiment_permission(): experiment_id = _get_request_param("experiment_id") - username = _get_username() - store.delete_experiment_permission(experiment_id, username) + user = store.get_user(_get_username()) + store.delete_experiment_permission(experiment_id, user.username) return make_response("Experiment permission has been deleted") @catch_mlflow_exception def create_registered_model_permission(): name = _get_request_param("name") - username = _get_username() + user = store.get_user(_get_username()) permission = _get_request_param("permission") - rmp = store.create_registered_model_permission(name, username, permission) + rmp = store.create_registered_model_permission(name, user.username, permission) return make_response({"registered_model_permission": rmp.to_json()}) @catch_mlflow_exception def get_registered_model_permission(): name = _get_request_param("name") - username = _get_username() - rmp = store.get_registered_model_permission(name, username) + user = store.get_user(_get_username()) + rmp = store.get_registered_model_permission(name, user.username) return make_response({"registered_model_permission": rmp.to_json()}) @catch_mlflow_exception def update_registered_model_permission(): name = _get_request_param("name") - username = _get_username() + user = store.get_user(_get_username()) permission = _get_request_param("permission") - store.update_registered_model_permission(name, username, permission) + store.update_registered_model_permission(name, user.username, permission) return make_response("Model permission has been changed") @catch_mlflow_exception def delete_registered_model_permission(): name = _get_request_param("name") - username = _get_username() - store.delete_registered_model_permission(name, username) + user = store.get_user(_get_username()) + store.delete_registered_model_permission(name, user.username) return make_response("Model permission has been deleted") @@ -721,13 +721,13 @@ def set_can_manage_experiment_permission(resp: Response): response_message = CreateExperiment.Response() parse_dict(resp.json, response_message) experiment_id = response_message.experiment_id - username = _get_username() - store.create_experiment_permission(experiment_id, username, MANAGE.name) + user = store.get_user(_get_username()) + store.create_experiment_permission(experiment_id, user.username, MANAGE.name) def set_can_manage_registered_model_permission(resp: Response): response_message = CreateRegisteredModel.Response() parse_dict(resp.json, response_message) name = response_message.registered_model.name - username = _get_username() - store.create_registered_model_permission(name, username, MANAGE.name) + user = store.get_user(_get_username()) + store.create_registered_model_permission(name, user.username, MANAGE.name) From bd82c726e65fa44c05e8e62e72c4dcaaab2ace39 Mon Sep 17 00:00:00 2001 From: Alexander Doroshevich Date: Fri, 5 Apr 2024 19:20:10 -0700 Subject: [PATCH 03/13] feat: rebase rc into validators --- mlflow_oidc_auth/app.py | 8 +- mlflow_oidc_auth/config.py | 1 + mlflow_oidc_auth/views.py | 265 +++++++++++++++++++++++++++++++------ 3 files changed, 226 insertions(+), 48 deletions(-) diff --git a/mlflow_oidc_auth/app.py b/mlflow_oidc_auth/app.py index c4c7ab4..df409ee 100644 --- a/mlflow_oidc_auth/app.py +++ b/mlflow_oidc_auth/app.py @@ -49,10 +49,10 @@ app.add_url_rule(rule=routes.GET_EXPERIMENT_PERMISSION, methods=["GET"], view_func=views.get_experiment_permission) app.add_url_rule(rule=routes.UPDATE_EXPERIMENT_PERMISSION, methods=["POST"], view_func=views.update_experiment_permission) app.add_url_rule(rule=routes.DELETE_EXPERIMENT_PERMISSION, methods=["POST"], view_func=views.delete_experiment_permission) -app.add_url_rule(rule=routes.CREATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.create_model_permission) -app.add_url_rule(rule=routes.GET_REGISTERED_MODEL_PERMISSION, methods=["GET"], view_func=views.get_model_permission) -app.add_url_rule(rule=routes.UPDATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.update_model_permission) -app.add_url_rule(rule=routes.DELETE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.delete_model_permission) +app.add_url_rule(rule=routes.CREATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.create_registered_model_permission) +app.add_url_rule(rule=routes.GET_REGISTERED_MODEL_PERMISSION, methods=["GET"], view_func=views.get_registered_model_permission) +app.add_url_rule(rule=routes.UPDATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.update_registered_model_permission) +app.add_url_rule(rule=routes.DELETE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.delete_registered_model_permission) # Add new hooks app.before_request(views.before_request_hook) diff --git a/mlflow_oidc_auth/config.py b/mlflow_oidc_auth/config.py index de8657d..3c4502e 100644 --- a/mlflow_oidc_auth/config.py +++ b/mlflow_oidc_auth/config.py @@ -10,6 +10,7 @@ class AppConfig: + DEFAULT_MLFLOW_PERMISSION = os.environ.get("DEFAULT_MLFLOW_PERMISSION", "MANAGE") SECRET_KEY = os.environ.get("SECRET_KEY", secrets.token_hex(16)) SESSION_TYPE = "cachelib" LEVEL = logging.DEBUG if os.environ.get("DEBUG") else logging.INFO diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index 9418b8f..e1e103b 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -5,7 +5,7 @@ import string from typing import Callable, Union -from sqlalchemy import text +from mlflow.utils.proto_json_utils import parse_dict from werkzeug.datastructures import Authorization from flask import ( @@ -27,10 +27,46 @@ INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST, ) + +from mlflow.protos.model_registry_pb2 import ( + CreateModelVersion, + CreateRegisteredModel, + DeleteModelVersion, + DeleteModelVersionTag, + DeleteRegisteredModel, + DeleteRegisteredModelAlias, + DeleteRegisteredModelTag, + GetLatestVersions, + GetModelVersion, + GetModelVersionByAlias, + GetModelVersionDownloadUri, + GetRegisteredModel, + RenameRegisteredModel, + SearchRegisteredModels, + SetModelVersionTag, + SetRegisteredModelAlias, + SetRegisteredModelTag, + TransitionModelVersionStage, + UpdateModelVersion, + UpdateRegisteredModel, +) + from mlflow.protos.service_pb2 import ( CreateExperiment, + CreateRun, + DeleteExperiment, + DeleteRun, + DeleteTag, + GetExperiment, + GetExperimentByName, + RestoreExperiment, + SearchExperiments, + SetExperimentTag, + SetTag, + UpdateExperiment, + UpdateRun, ) -from mlflow_oidc_auth.permissions import Permission, get_permission +from mlflow_oidc_auth.permissions import Permission, get_permission, MANAGE from mlflow.server.handlers import ( catch_mlflow_exception, get_endpoints, @@ -145,12 +181,97 @@ def _get_is_admin(): def _get_permission_from_experiment_id() -> Permission: experiment_id = _get_request_param("experiment_id") username = _get_username() - return _get_permission_from_store_or_default(lambda: store.get_experiment_permission(experiment_id, username).permission) + return _get_permission_from_store_or_default( + lambda: store.get_experiment_permission(experiment_id, username).permission) + + +def _get_permission_from_experiment_name() -> Permission: + experiment_name = _get_request_param("experiment_name") + store_exp = mlflow_client.get_experiment_by_name(experiment_name) + if store_exp is None: + raise MlflowException( + f"Could not find experiment with name {experiment_name}", + error_code=RESOURCE_DOES_NOT_EXIST, + ) + username = _get_username() + return _get_permission_from_store_or_default( + lambda: store.get_experiment_permission(store_exp.experiment_id, username).permission + ) + + +def _get_permission_from_registered_model_name() -> Permission: + name = _get_request_param("name") + username = _get_username() + return _get_permission_from_store_or_default( + lambda: store.get_registered_model_permission(name, username).permission + ) + + +def _set_can_manage_experiment_permission(resp: Response): + response_message = CreateExperiment.Response() + parse_dict(resp.json, response_message) + experiment_id = response_message.experiment_id + username = _get_username() + store.create_experiment_permission(experiment_id, username, MANAGE.name) + + +def _set_can_manage_registered_model_permission(resp: Response): + response_message = CreateRegisteredModel.Response() + parse_dict(resp.json, response_message) + name = response_message.registered_model.name + username = _get_username() + store.create_registered_model_permission(name, username, MANAGE.name) + + +def delete_can_manage_registered_model_permission(resp: Response): + """ + Delete registered model permission when the model is deleted. + + We need to do this because the primary key of the registered model is the name, + unlike the experiment where the primary key is experiment_id (UUID). Therefore, + we have to delete the permission record when the model is deleted otherwise it + conflicts with the new model registered with the same name. + """ + # Get model name from request context because it's not available in the response + name = request.get_json(force=True, silent=True)["name"] + username = _get_username() + store.delete_registered_model_permission(name, username) def _validate_can_manage_experiment(): - # return _get_permission_from_experiment_id().can_manage - return True + return _get_permission_from_experiment_id().can_manage + + +def _validate_can_manage_registered_model(): + return _get_permission_from_registered_model_name().can_manage + + +def _validate_can_read_experiment(): + return _get_permission_from_experiment_id().can_read + + +def _validate_can_read_experiment_by_name(): + return _get_permission_from_experiment_name().can_read + + +def _validate_can_update_experiment(): + return _get_permission_from_experiment_id().can_update + + +def _validate_can_delete_experiment(): + return _get_permission_from_experiment_id().can_delete + + +def _validate_can_read_registered_model(): + return _get_permission_from_registered_model_name().can_read + + +def _validate_can_update_registered_model(): + return _get_permission_from_registered_model_name().can_update + + +def _validate_can_delete_registered_model(): + return _get_permission_from_registered_model_name().can_delete def _get_before_request_handler(request_class): @@ -160,6 +281,31 @@ def _get_before_request_handler(request_class): BEFORE_REQUEST_HANDLERS = { # Routes for experiments CreateExperiment: _validate_can_manage_experiment, + GetExperiment: _validate_can_read_experiment, + GetExperimentByName: _validate_can_read_experiment_by_name, + DeleteExperiment: _validate_can_delete_experiment, + RestoreExperiment: _validate_can_delete_experiment, + UpdateExperiment: _validate_can_update_experiment, + SetExperimentTag: _validate_can_update_experiment, + # Routes for model registry + GetRegisteredModel: _validate_can_read_registered_model, + DeleteRegisteredModel: _validate_can_delete_registered_model, + UpdateRegisteredModel: _validate_can_update_registered_model, + RenameRegisteredModel: _validate_can_update_registered_model, + GetLatestVersions: _validate_can_read_registered_model, + CreateModelVersion: _validate_can_update_registered_model, + GetModelVersion: _validate_can_read_registered_model, + DeleteModelVersion: _validate_can_delete_registered_model, + UpdateModelVersion: _validate_can_update_registered_model, + TransitionModelVersionStage: _validate_can_update_registered_model, + GetModelVersionDownloadUri: _validate_can_read_registered_model, + SetRegisteredModelTag: _validate_can_update_registered_model, + DeleteRegisteredModelTag: _validate_can_update_registered_model, + SetModelVersionTag: _validate_can_update_registered_model, + DeleteModelVersionTag: _validate_can_delete_registered_model, + SetRegisteredModelAlias: _validate_can_update_registered_model, + DeleteRegisteredModelAlias: _validate_can_delete_registered_model, + GetModelVersionByAlias: _validate_can_read_registered_model, } BEFORE_REQUEST_VALIDATORS = { @@ -171,9 +317,34 @@ def _get_before_request_handler(request_class): BEFORE_REQUEST_VALIDATORS.update( { (routes.CREATE_EXPERIMENT_PERMISSION, "GET"): _validate_can_manage_experiment, + (routes.GET_EXPERIMENT_PERMISSION, "GET"): _validate_can_manage_experiment, + (routes.CREATE_EXPERIMENT_PERMISSION, "POST"): _validate_can_manage_experiment, + (routes.UPDATE_EXPERIMENT_PERMISSION, "PATCH"): _validate_can_manage_experiment, + (routes.DELETE_EXPERIMENT_PERMISSION, "DELETE"): _validate_can_manage_experiment, + (routes.GET_REGISTERED_MODEL_PERMISSION, "GET"): _validate_can_manage_registered_model, + (routes.CREATE_REGISTERED_MODEL_PERMISSION, "POST"): _validate_can_manage_registered_model, + (routes.UPDATE_REGISTERED_MODEL_PERMISSION, "PATCH"): _validate_can_manage_registered_model, + (routes.DELETE_REGISTERED_MODEL_PERMISSION, "DELETE"): _validate_can_manage_registered_model, } ) +AFTER_REQUEST_PATH_HANDLERS = { + CreateExperiment: _set_can_manage_experiment_permission, + CreateRegisteredModel: _set_can_manage_registered_model_permission, + DeleteRegisteredModel: delete_can_manage_registered_model_permission, +} + + +def get_after_request_handler(request_class): + return AFTER_REQUEST_PATH_HANDLERS.get(request_class) + + +AFTER_REQUEST_HANDLERS = { + (http_path, method): handler + for http_path, handler, methods in get_endpoints(get_after_request_handler) + for method in methods +} + def before_request_hook(): """Called before each request. If it did not return a response, @@ -476,39 +647,36 @@ def get_user_models(username): def get_experiment_users(experiment_id): - # experiment_permissions is table name for experiments - # users is a table for users - with store.ManagedSessionMaker() as session: - query = text( - """ - SELECT users.username, experiment_permissions.permission - FROM users - JOIN experiment_permissions ON users.id = experiment_permissions.user_id - WHERE experiment_permissions.experiment_id = :experiment_id - """ - ) - results = session.execute(query, {"experiment_id": experiment_id}) - users_permissions = [{"username": row[0], "permission": row[1]} for row in results] + # Convert experiment_id to string for comparison + experiment_id = str(experiment_id) - return jsonify(users_permissions) + # Get the list of all users + list_users = store.list_users() + + # Filter users who are associated with the given experiment + usernames = [] + for user in list_users: + # Check if the user is associated with the experiment + user_experiments = [str(exp.experiment_id) for exp in user.experiment_permissions] + if experiment_id in user_experiments: + usernames.append(user.username) + + return jsonify({"usernames": usernames}) def get_model_users(model_name): - # registered_model_permissions is table name for models - # users is a table for users - with store.ManagedSessionMaker() as session: - query = text( - """ - SELECT users.username, registered_model_permissions.permission - FROM users - JOIN registered_model_permissions ON users.id = registered_model_permissions.user_id - WHERE registered_model_permissions.name = :model_name - """ - ) - results = session.execute(query, {"model_name": model_name}) - models_permissions = [{"username": row[0], "permission": row[1]} for row in results] + # Get the list of all users + list_users = store.list_users() + + # Filter users who are associated with the given model + usernames = [] + for user in list_users: + # Check if the user is associated with the model + user_models = [model.name for model in user.registered_model_permissions] + if model_name in user_models: + usernames.append(user.username) - return jsonify(models_permissions) + return jsonify({"usernames": usernames}) def _password_generation(): @@ -535,9 +703,14 @@ def delete_experiment_permission(): ) return jsonify({"message": "Experiment permission has been deleted."}) +@catch_mlflow_exception +def create_registered_model_permission(): + name = _get_request_param("name") + username = _get_username() + permission = _get_request_param("permission") + rmp = store.create_registered_model_permission(name, username, permission) + return make_response({"registered_model_permission": rmp.to_json()}) -def create_model_permission(): - request_data = request.get_json() store.create_registered_model_permission( request_data.get("model_name"), @@ -547,18 +720,22 @@ def create_model_permission(): return jsonify({"message": "Model permission has been created."}) -def get_model_permission(): - request_data = request.get_json() +@catch_mlflow_exception +def update_registered_model_permission(): + name = _get_request_param("name") + username = _get_username() + permission = _get_request_param("permission") + store.update_registered_model_permission(name, username, permission) + return make_response("Model permission has been changed") - permission = store.get_registered_model_permission( - request_data.get("model_name"), - request_data.get("user_name"), - ) - return jsonify({"model_permission": permission.to_json()}) +@catch_mlflow_exception +def delete_registered_model_permission(): + name = _get_request_param("name") + username = _get_username() + store.delete_registered_model_permission(name, username) + return make_response("Model permission has been deleted") -def update_model_permission(): - request_data = request.get_json() store.update_registered_model_permission( request_data.get("model_name"), From ca4d6854bd0e0a40b41bc76adf83da1775a1a36a Mon Sep 17 00:00:00 2001 From: Alexander Doroshevich Date: Fri, 5 Apr 2024 22:01:37 -0700 Subject: [PATCH 04/13] fix: updated user receiving --- mlflow_oidc_auth/views.py | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index e1e103b..45d3602 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -712,12 +712,12 @@ def create_registered_model_permission(): return make_response({"registered_model_permission": rmp.to_json()}) - store.create_registered_model_permission( - request_data.get("model_name"), - request_data.get("user_name"), - request_data.get("new_permission"), - ) - return jsonify({"message": "Model permission has been created."}) +@catch_mlflow_exception +def get_registered_model_permission(): + name = _get_request_param("name") + username = _get_username() + rmp = store.get_registered_model_permission(name, username) + return make_response({"registered_model_permission": rmp.to_json()}) @catch_mlflow_exception @@ -737,19 +737,17 @@ def delete_registered_model_permission(): return make_response("Model permission has been deleted") - store.update_registered_model_permission( - request_data.get("model_name"), - request_data.get("user_name"), - request_data.get("new_permission"), - ) - return jsonify({"message": "Model permission has been changed."}) - +def set_can_manage_experiment_permission(resp: Response): + response_message = CreateExperiment.Response() + parse_dict(resp.json, response_message) + experiment_id = response_message.experiment_id + username = _get_username() + store.create_experiment_permission(experiment_id, username, MANAGE.name) -def delete_model_permission(): - request_data = request.get_json() - store.delete_registered_model_permission( - request_data.get("model_name"), - request_data.get("user_name"), - ) - return jsonify({"message": "Model permission has been deleted."}) +def set_can_manage_registered_model_permission(resp: Response): + response_message = CreateRegisteredModel.Response() + parse_dict(resp.json, response_message) + name = response_message.registered_model.name + username = _get_username() + store.create_registered_model_permission(name, username, MANAGE.name) From b5125e2d2bd47d7c46db8587b1f6a79ecdf903ec Mon Sep 17 00:00:00 2001 From: Alexander Doroshevich Date: Mon, 8 Apr 2024 16:19:49 -0700 Subject: [PATCH 05/13] fix: fixed user receiving --- mlflow_oidc_auth/views.py | 85 +++++++++++++++++++++++---------------- 1 file changed, 50 insertions(+), 35 deletions(-) diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index 45d3602..ed7bc1e 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -89,12 +89,13 @@ _logger = logging.getLogger(__name__) -def _get_experiment_id(request_data: dict) -> str: - experiment_id = request_data.get("experiment_id") - if "experiment_id" not in request_data: - experiment_id = mlflow_client.get_experiment_by_name(request_data.get("experiment_name")).experiment_id +def _get_experiment_id() -> str: + experiment_id = _get_request_param("experiment_id") + if not experiment_id: + experiment_id = mlflow_client.get_experiment_by_name(_get_request_param("experiment_name")).experiment_id return experiment_id + def _get_request_param(param: str) -> str: if request.method == "GET": args = request.args @@ -180,7 +181,7 @@ def _get_is_admin(): def _get_permission_from_experiment_id() -> Permission: experiment_id = _get_request_param("experiment_id") - username = _get_username() + username = _get_request_param("user_name") return _get_permission_from_store_or_default( lambda: store.get_experiment_permission(experiment_id, username).permission) @@ -193,17 +194,17 @@ def _get_permission_from_experiment_name() -> Permission: f"Could not find experiment with name {experiment_name}", error_code=RESOURCE_DOES_NOT_EXIST, ) - username = _get_username() + username = _get_request_param("user_name") return _get_permission_from_store_or_default( lambda: store.get_experiment_permission(store_exp.experiment_id, username).permission ) def _get_permission_from_registered_model_name() -> Permission: - name = _get_request_param("name") - username = _get_username() + model_name = _get_request_param("model_name") + username = _get_request_param("user_name") return _get_permission_from_store_or_default( - lambda: store.get_registered_model_permission(name, username).permission + lambda: store.get_registered_model_permission(model_name, username).permission ) @@ -211,7 +212,7 @@ def _set_can_manage_experiment_permission(resp: Response): response_message = CreateExperiment.Response() parse_dict(resp.json, response_message) experiment_id = response_message.experiment_id - username = _get_username() + username = _get_request_param("user_name") store.create_experiment_permission(experiment_id, username, MANAGE.name) @@ -219,7 +220,7 @@ def _set_can_manage_registered_model_permission(resp: Response): response_message = CreateRegisteredModel.Response() parse_dict(resp.json, response_message) name = response_message.registered_model.name - username = _get_username() + username = _get_request_param("user_name") store.create_registered_model_permission(name, username, MANAGE.name) @@ -233,9 +234,9 @@ def delete_can_manage_registered_model_permission(resp: Response): conflicts with the new model registered with the same name. """ # Get model name from request context because it's not available in the response - name = request.get_json(force=True, silent=True)["name"] - username = _get_username() - store.delete_registered_model_permission(name, username) + model_name = _get_request_param("model_name") + username = _get_request_param("user_name") + store.delete_registered_model_permission(model_name, username) def _validate_can_manage_experiment(): @@ -383,12 +384,12 @@ def make_basic_auth_response() -> Response: return res +@catch_mlflow_exception def create_experiment_permission(): - request_data = request.get_json() store.create_experiment_permission( - _get_experiment_id(request_data), - request_data.get("user_name"), - request_data.get("new_permission"), + _get_experiment_id(), + _get_request_param("user_name"), + _get_request_param("permission"), ) return jsonify({"message": "Experiment permission has been created."}) @@ -397,7 +398,7 @@ def create_experiment_permission(): @catch_mlflow_exception def get_experiment_permission(): experiment_id = _get_request_param("experiment_id") - username = _get_username() + username = _get_request_param("user_name") ep = store.get_experiment_permission(experiment_id, username) return make_response({"experiment_permission": ep.to_json()}) @@ -470,7 +471,8 @@ def callback(): if not any(group["displayName"] == AppConfig.get_property("OIDC_GROUP_NAME") for group in group_data["value"]): return "User not in group", 401 # set is_admin if user is in admin group - if any(group["displayName"] == AppConfig.get_property("OIDC_ADMIN_GROUP_NAME") for group in group_data["value"]): + if any(group["displayName"] == AppConfig.get_property("OIDC_ADMIN_GROUP_NAME") for group in + group_data["value"]): _set_is_admin(True) else: _set_is_admin(False) @@ -532,12 +534,14 @@ def create_user(): ) +@catch_mlflow_exception def create_access_token(): new_token = _password_generation() store.update_user(_get_username(), new_token) return jsonify({"token": new_token}) +@catch_mlflow_exception def get_current_user(): user = store.get_user(_get_username()) user_json = user.to_json() @@ -552,26 +556,29 @@ def get_current_user(): return jsonify(user_json) +@catch_mlflow_exception def update_username_password(): new_password = _password_generation() store.update_user(_get_username(), new_password) return jsonify({"token": new_password}) +@catch_mlflow_exception def update_user_admin(): is_admin = _get_request_param("is_admin") store.update_user(_get_username(), is_admin) return jsonify({"is_admin": is_admin}) +@catch_mlflow_exception def delete_user(): - store.delete_user(_get_username()) + store.delete_user(_get_request_param("user_name")) return jsonify({"message": f"User {_get_username()} has been deleted"}) @catch_mlflow_exception def get_user(): - username = _get_request_param("username") + username = _get_request_param("user_name") user = store.get_user(username) return jsonify({"user": user.to_json()}) @@ -584,6 +591,7 @@ def permissions(): return redirect(url_for("list_users")) +@catch_mlflow_exception def get_users(): # check is admin # if not _get_is_admin(): @@ -592,6 +600,7 @@ def get_users(): return jsonify({"users": users}) +@catch_mlflow_exception def get_experiments(): list_experiments = mlflow_client.search_experiments() experiments = [ @@ -605,6 +614,7 @@ def get_experiments(): return jsonify(experiments) +@catch_mlflow_exception def get_models(): registered_models = mlflow_client.search_registered_models() models = [ @@ -620,6 +630,7 @@ def get_models(): return jsonify(models) +@catch_mlflow_exception def get_user_experiments(username): # get list of experiments for the user list_experiments = store.list_experiment_permissions(username) @@ -636,6 +647,7 @@ def get_user_experiments(username): return jsonify({"experiments": experiments_list}) +@catch_mlflow_exception def get_user_models(username): # get list of models for current user registered_models = store.list_registered_model_permissions(username) @@ -646,6 +658,7 @@ def get_user_models(username): return jsonify({"models": models}) +@catch_mlflow_exception def get_experiment_users(experiment_id): # Convert experiment_id to string for comparison experiment_id = str(experiment_id) @@ -664,6 +677,7 @@ def get_experiment_users(experiment_id): return jsonify({"usernames": usernames}) +@catch_mlflow_exception def get_model_users(model_name): # Get the list of all users list_users = store.list_users() @@ -685,28 +699,29 @@ def _password_generation(): return new_password +@catch_mlflow_exception def update_experiment_permission(): - request_data = request.get_json() store.update_experiment_permission( - _get_experiment_id(request_data), - request_data.get("user_name"), - request_data.get("new_permission"), + _get_experiment_id(), + _get_request_param("user_name"), + _get_request_param("new_permission"), ) return jsonify({"message": "Experiment permission has been changed."}) +@catch_mlflow_exception def delete_experiment_permission(): - request_data = request.get_json() store.delete_experiment_permission( - _get_experiment_id(request_data), - request_data.get("user_name"), + _get_experiment_id(), + _get_request_param("user_name"), ) return jsonify({"message": "Experiment permission has been deleted."}) + @catch_mlflow_exception def create_registered_model_permission(): name = _get_request_param("name") - username = _get_username() + username = _get_request_param("user_name") permission = _get_request_param("permission") rmp = store.create_registered_model_permission(name, username, permission) return make_response({"registered_model_permission": rmp.to_json()}) @@ -715,7 +730,7 @@ def create_registered_model_permission(): @catch_mlflow_exception def get_registered_model_permission(): name = _get_request_param("name") - username = _get_username() + username = _get_request_param("user_name") rmp = store.get_registered_model_permission(name, username) return make_response({"registered_model_permission": rmp.to_json()}) @@ -723,7 +738,7 @@ def get_registered_model_permission(): @catch_mlflow_exception def update_registered_model_permission(): name = _get_request_param("name") - username = _get_username() + username = _get_request_param("user_name") permission = _get_request_param("permission") store.update_registered_model_permission(name, username, permission) return make_response("Model permission has been changed") @@ -732,7 +747,7 @@ def update_registered_model_permission(): @catch_mlflow_exception def delete_registered_model_permission(): name = _get_request_param("name") - username = _get_username() + username = _get_request_param("user_name") store.delete_registered_model_permission(name, username) return make_response("Model permission has been deleted") @@ -741,7 +756,7 @@ def set_can_manage_experiment_permission(resp: Response): response_message = CreateExperiment.Response() parse_dict(resp.json, response_message) experiment_id = response_message.experiment_id - username = _get_username() + username = _get_request_param("user_name") store.create_experiment_permission(experiment_id, username, MANAGE.name) @@ -749,5 +764,5 @@ def set_can_manage_registered_model_permission(resp: Response): response_message = CreateRegisteredModel.Response() parse_dict(resp.json, response_message) name = response_message.registered_model.name - username = _get_username() + username = _get_request_param("user_name") store.create_registered_model_permission(name, username, MANAGE.name) From 49f77e8992746e6042b4f1d7ac405dc699170905 Mon Sep 17 00:00:00 2001 From: Alexander Kharkevich Date: Fri, 12 Apr 2024 19:47:31 -0400 Subject: [PATCH 06/13] fix: remove unused code --- mlflow_oidc_auth/views.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index 2976190..bcbf1ed 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -197,8 +197,7 @@ def _get_is_admin(): def _get_permission_from_experiment_id() -> Permission: experiment_id = _get_request_param("experiment_id") username = _get_request_param("user_name") - return _get_permission_from_store_or_default( - lambda: store.get_experiment_permission(experiment_id, username).permission) + return _get_permission_from_store_or_default(lambda: store.get_experiment_permission(experiment_id, username).permission) def _get_permission_from_experiment_name() -> Permission: @@ -218,9 +217,7 @@ def _get_permission_from_experiment_name() -> Permission: def _get_permission_from_registered_model_name() -> Permission: model_name = _get_request_param("model_name") username = _get_request_param("user_name") - return _get_permission_from_store_or_default( - lambda: store.get_registered_model_permission(model_name, username).permission - ) + return _get_permission_from_store_or_default(lambda: store.get_registered_model_permission(model_name, username).permission) def _set_can_manage_experiment_permission(resp: Response): @@ -418,12 +415,6 @@ def get_experiment_permission(): return make_response({"experiment_permission": ep.to_json()}) -# TODO -# @catch_mlflow_exception -# def search_experiment(): -# return render_template("home.html", username=_get_username()) - - def login(): state = secrets.token_urlsafe(16) session["oauth_state"] = state @@ -525,11 +516,6 @@ def oidc_ui(filename=None): return send_from_directory(ui_directory, filename) -# # TODO -# def search_model(): -# return render_template("home.html", username=_get_username()) - - def create_user(): try: user = store.get_user(_get_username()) From e1152a755a1bc1b109344f96c25259b427c56650 Mon Sep 17 00:00:00 2001 From: Alexander Kharkevich Date: Fri, 12 Apr 2024 19:49:43 -0400 Subject: [PATCH 07/13] ci: update validator --- .github/workflows/commit-message-check.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/commit-message-check.yml b/.github/workflows/commit-message-check.yml index bb1aa7e..a95f16c 100644 --- a/.github/workflows/commit-message-check.yml +++ b/.github/workflows/commit-message-check.yml @@ -15,7 +15,7 @@ jobs: - name: Check Commit Type uses: gsactions/commit-message-checker@v2 with: - pattern: '^(|feat|fix|chore|docs|style|ci|refactor|perf|test)(\([\w-]+\))?:\s.+$|^(Merge\sbranch)' + pattern: '^(|feat|fix|chore|docs|style|ci|refactor|perf|test)(\([\w-]+\))?:\s.+$|^(Merge\sbranch)|^(Merge\sremote)' error: 'Your first line has to contain a commit type like "feat|fix|chore|docs|style|ci|refactor|perf|test".' excludeDescription: "true" excludeTitle: "true" From 4943484ed5e504ccce89d7a5958ddf137f179bbf Mon Sep 17 00:00:00 2001 From: Alexander Kharkevich Date: Sun, 14 Apr 2024 00:17:17 -0400 Subject: [PATCH 08/13] feat: update user permission validation --- README.md | 3 - mlflow_oidc_auth/config.py | 4 +- mlflow_oidc_auth/views.py | 419 ++++++++++++------ ...experiment-permission-details.component.ts | 2 +- .../model-permission-details.component.ts | 6 +- .../user-permission-details.component.ts | 8 +- .../interfaces/permission-data.interface.ts | 6 +- .../services/data/permission-data.service.ts | 6 +- .../services/permission-modal.service.ts | 6 +- 9 files changed, 312 insertions(+), 148 deletions(-) diff --git a/README.md b/README.md index a49297f..25fccc7 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,6 @@ The plugin required the following environment variables but also supported `.env | OAUTHLIB_INSECURE_TRANSPORT | Development only. Allow to use insecure endpoints for OIDC | | LOG_LEVEL | Application log level | | OIDC_USERS_DB_URI | Database connection string | -| MLFLOW_TRACKING_USERNAME | Credentials for internal communications via API | -| MLFLOW_TRACKING_PASSWORD | Credentials for internal communications via API | -| MLFLOW_TRACKING_URI | URI for internal communications via API | # Configuration examples diff --git a/mlflow_oidc_auth/config.py b/mlflow_oidc_auth/config.py index 2e89a51..c765dd6 100644 --- a/mlflow_oidc_auth/config.py +++ b/mlflow_oidc_auth/config.py @@ -36,9 +36,7 @@ class AppConfig: OIDC_REDIRECT_URI = os.environ.get("OIDC_REDIRECT_URI", None) OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID", None) OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET", None) - MLFLOW_TRACKING_URI = os.environ.get("MLFLOW_TRACKING_URI", "http://localhost:8080") - MLFLOW_TRACKING_USERNAME = os.environ.get("MLFLOW_TRACKING_USERNAME", secrets.token_urlsafe(32)) - MLFLOW_TRACKING_PASSWORD = os.environ.get("MLFLOW_TRACKING_PASSWORD", secrets.token_urlsafe(72)) + @staticmethod def get_property(property_name): diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index bcbf1ed..c33cba7 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -1,13 +1,17 @@ import logging import os +import re import requests import secrets import string from typing import Callable, Union -from mlflow.utils.proto_json_utils import parse_dict +from mlflow.store.entities import PagedList +from mlflow.entities.model_registry import RegisteredModel +from mlflow.entities import Experiment +from mlflow.utils.proto_json_utils import message_to_json, parse_dict from werkzeug.datastructures import Authorization - +from mlflow.utils.search_utils import SearchUtils from flask import ( make_response, request, @@ -59,7 +63,15 @@ DeleteTag, GetExperiment, GetExperimentByName, + GetMetricHistory, + GetRun, + ListArtifacts, + LogBatch, + LogMetric, + LogModel, + LogParam, RestoreExperiment, + RestoreRun, SearchExperiments, SetExperimentTag, SetTag, @@ -68,6 +80,9 @@ ) from mlflow_oidc_auth.permissions import Permission, get_permission, MANAGE from mlflow.server.handlers import ( + _get_model_registry_store, + _get_request_message, + _get_tracking_store, catch_mlflow_exception, get_endpoints, ) @@ -83,17 +98,20 @@ # Create the OAuth2 client auth_client = WebApplicationClient(AppConfig.get_property("OIDC_CLIENT_ID")) -mlflow_client = MlflowClient(tracking_uri=AppConfig.get_property("MLFLOW_TRACKING_URI")) store = SqlAlchemyStore() store.init_db((AppConfig.get_property("OIDC_USERS_DB_URI"))) _logger = logging.getLogger(__name__) -def _get_experiment_id() -> str: - experiment_id = _get_request_param("experiment_id") - if not experiment_id: - experiment_id = mlflow_client.get_experiment_by_name(_get_request_param("experiment_name")).experiment_id - return experiment_id +def _get_experiment_id(request_data: dict) -> str: + if "experiment_id" in request_data: + return request_data.get("experiment_id") + elif "experiment_name" in request_data: + return _get_tracking_store().get_experiment_by_name(request_data.get("experiment_name")).experiment_id + raise MlflowException( + "Either 'experiment_id' or 'experiment_name' must be provided in the request data.", + INVALID_PARAMETER_VALUE, + ) def _get_request_param(param: str) -> str: @@ -145,25 +163,34 @@ def _get_permission_from_store_or_default(store_permission_func: Callable[[], st return get_permission(perm) +# def authenticate_request() -> Union[Authorization, Response]: +# """Use configured authorization function to get request authorization.""" +# auth_func = get_auth_func(auth_config.authorization_function) +# return auth_func() + +# @functools.lru_cache(maxsize=None) +# def get_auth_func(authorization_function: str) -> Callable[[], Union[Authorization, Response]]: +# """ +# Import and return the specified authorization function. + +# Args: +# authorization_function: A string of the form "module.submodule:auth_func" +# """ +# mod_name, fn_name = authorization_function.split(":", 1) +# module = importlib.import_module(mod_name) +# return getattr(module, fn_name) + + def authenticate_request_basic_auth() -> Union[Authorization, Response]: username = request.authorization.username password = request.authorization.password _logger.debug("Authenticating user %s", username) - # check for internal call, if credentials are correct, return True - if username == AppConfig.get_property("MLFLOW_TRACKING_USERNAME") and password == AppConfig.get_property( - "MLFLOW_TRACKING_PASSWORD" - ): - _set_username(username) - _set_is_admin(True) - _logger.debug("User %s authenticated", username) - return True if store.authenticate_user(username.lower(), password): _set_username(username) _logger.debug("User %s authenticated", username) return True else: _logger.debug("User %s not authenticated", username) - # let user attempt login again return False @@ -195,47 +222,163 @@ def _get_is_admin(): def _get_permission_from_experiment_id() -> Permission: - experiment_id = _get_request_param("experiment_id") - username = _get_request_param("user_name") + experiment_id = _get_experiment_id(request.get_json()) + username = _get_username() return _get_permission_from_store_or_default(lambda: store.get_experiment_permission(experiment_id, username).permission) +_EXPERIMENT_ID_PATTERN = re.compile(r"^(\d+)/") + + +def _get_experiment_id_from_view_args(): + if artifact_path := request.view_args.get("artifact_path"): + if m := _EXPERIMENT_ID_PATTERN.match(artifact_path): + return m.group(1) + return None + + +def _get_permission_from_experiment_id_artifact_proxy() -> Permission: + if experiment_id := _get_experiment_id_from_view_args(): + username = _get_username() + return _get_permission_from_store_or_default( + lambda: store.get_experiment_permission(experiment_id, username).permission + ) + return get_permission(AppConfig.get_property("DEFAULT_MLFLOW_PERMISSION")) + + def _get_permission_from_experiment_name() -> Permission: experiment_name = _get_request_param("experiment_name") - store_exp = mlflow_client.get_experiment_by_name(experiment_name) + store_exp = _get_tracking_store().get_experiment_by_name(experiment_name) if store_exp is None: raise MlflowException( f"Could not find experiment with name {experiment_name}", error_code=RESOURCE_DOES_NOT_EXIST, ) - username = _get_request_param("user_name") + username = _get_username() return _get_permission_from_store_or_default( lambda: store.get_experiment_permission(store_exp.experiment_id, username).permission ) +def _get_permission_from_run_id() -> Permission: + # run permissions inherit from parent resource (experiment) + # so we just get the experiment permission + run_id = _get_request_param("run_id") + run = _get_tracking_store().get_run(run_id) + experiment_id = run.info.experiment_id + username = _get_username() + return _get_permission_from_store_or_default(lambda: store.get_experiment_permission(experiment_id, username).permission) + + def _get_permission_from_registered_model_name() -> Permission: - model_name = _get_request_param("model_name") - username = _get_request_param("user_name") + model_name = _get_request_param("name") + username = _get_username() return _get_permission_from_store_or_default(lambda: store.get_registered_model_permission(model_name, username).permission) -def _set_can_manage_experiment_permission(resp: Response): +def set_can_manage_experiment_permission(resp: Response): response_message = CreateExperiment.Response() parse_dict(resp.json, response_message) experiment_id = response_message.experiment_id - username = _get_request_param("user_name") + username = _get_username() store.create_experiment_permission(experiment_id, username, MANAGE.name) -def _set_can_manage_registered_model_permission(resp: Response): +def set_can_manage_registered_model_permission(resp: Response): response_message = CreateRegisteredModel.Response() parse_dict(resp.json, response_message) name = response_message.registered_model.name - username = _get_request_param("user_name") + username = _get_username() store.create_registered_model_permission(name, username, MANAGE.name) +def filter_search_experiments(resp: Response): + if _get_is_admin(): + return + + response_message = SearchExperiments.Response() + parse_dict(resp.json, response_message) + + # fetch permissions + username = _get_username() + perms = store.list_experiment_permissions(username) + can_read = {p.experiment_id: get_permission(p.permission).can_read for p in perms} + default_can_read = get_permission(AppConfig.get_property("DEFAULT_MLFLOW_PERMISSION")).can_read + + # filter out unreadable + for e in list(response_message.experiments): + if not can_read.get(e.experiment_id, default_can_read): + response_message.experiments.remove(e) + + # re-fetch to fill max results + request_message = _get_request_message(SearchExperiments()) + while len(response_message.experiments) < request_message.max_results and response_message.next_page_token != "": + refetched: PagedList[Experiment] = _get_tracking_store().search_experiments( + view_type=request_message.view_type, + max_results=request_message.max_results, + order_by=request_message.order_by, + filter_string=request_message.filter, + page_token=response_message.next_page_token, + ) + refetched = refetched[: request_message.max_results - len(response_message.experiments)] + if len(refetched) == 0: + response_message.next_page_token = "" + break + + refetched_readable_proto = [e.to_proto() for e in refetched if can_read.get(e.experiment_id, default_can_read)] + response_message.experiments.extend(refetched_readable_proto) + + # recalculate next page token + start_offset = SearchUtils.parse_start_offset_from_page_token(response_message.next_page_token) + final_offset = start_offset + len(refetched) + response_message.next_page_token = SearchUtils.create_page_token(final_offset) + + resp.data = message_to_json(response_message) + + +def filter_search_registered_models(resp: Response): + if _get_is_admin(): + return + + response_message = SearchRegisteredModels.Response() + parse_dict(resp.json, response_message) + + # fetch permissions + username = _get_username() + perms = store.list_registered_model_permissions(username) + can_read = {p.name: get_permission(p.permission).can_read for p in perms} + default_can_read = get_permission(AppConfig.get_property("DEFAULT_MLFLOW_PERMISSION")).can_read + + # filter out unreadable + for rm in list(response_message.registered_models): + if not can_read.get(rm.name, default_can_read): + response_message.registered_models.remove(rm) + + # re-fetch to fill max results + request_message = _get_request_message(SearchRegisteredModels()) + while len(response_message.registered_models) < request_message.max_results and response_message.next_page_token != "": + refetched: PagedList[RegisteredModel] = _get_model_registry_store().search_registered_models( + filter_string=request_message.filter, + max_results=request_message.max_results, + order_by=request_message.order_by, + page_token=response_message.next_page_token, + ) + refetched = refetched[: request_message.max_results - len(response_message.registered_models)] + if len(refetched) == 0: + response_message.next_page_token = "" + break + + refetched_readable_proto = [rm.to_proto() for rm in refetched if can_read.get(rm.name, default_can_read)] + response_message.registered_models.extend(refetched_readable_proto) + + # recalculate next page token + start_offset = SearchUtils.parse_start_offset_from_page_token(response_message.next_page_token) + final_offset = start_offset + len(refetched) + response_message.next_page_token = SearchUtils.create_page_token(final_offset) + + resp.data = message_to_json(response_message) + + def delete_can_manage_registered_model_permission(resp: Response): """ Delete registered model permission when the model is deleted. @@ -246,105 +389,155 @@ def delete_can_manage_registered_model_permission(resp: Response): conflicts with the new model registered with the same name. """ # Get model name from request context because it's not available in the response - model_name = _get_request_param("model_name") + model_name = _get_request_param("name") username = _get_request_param("user_name") store.delete_registered_model_permission(model_name, username) -def _validate_can_manage_experiment(): +def validate_can_read_experiment(): + return _get_permission_from_experiment_id().can_read + + +def validate_can_read_experiment_by_name(): + return _get_permission_from_experiment_name().can_read + + +def validate_can_update_experiment(): + return _get_permission_from_experiment_id().can_update + + +def validate_can_delete_experiment(): + return _get_permission_from_experiment_id().can_delete + + +def validate_can_manage_experiment(): return _get_permission_from_experiment_id().can_manage -def _validate_can_manage_registered_model(): - return _get_permission_from_registered_model_name().can_manage +def validate_can_read_experiment_artifact_proxy(): + return _get_permission_from_experiment_id_artifact_proxy().can_read -def _validate_can_read_experiment(): - return _get_permission_from_experiment_id().can_read +def validate_can_update_experiment_artifact_proxy(): + return _get_permission_from_experiment_id_artifact_proxy().can_update -def _validate_can_read_experiment_by_name(): - return _get_permission_from_experiment_name().can_read +def validate_can_delete_experiment_artifact_proxy(): + return _get_permission_from_experiment_id_artifact_proxy().can_manage -def _validate_can_update_experiment(): - return _get_permission_from_experiment_id().can_update +def validate_can_read_run(): + return _get_permission_from_run_id().can_read -def _validate_can_delete_experiment(): - return _get_permission_from_experiment_id().can_delete +def validate_can_update_run(): + return _get_permission_from_run_id().can_update -def _validate_can_read_registered_model(): +def validate_can_delete_run(): + return _get_permission_from_run_id().can_delete + + +def validate_can_manage_run(): + return _get_permission_from_run_id().can_manage + + +def validate_can_read_registered_model(): return _get_permission_from_registered_model_name().can_read -def _validate_can_update_registered_model(): +def validate_can_update_registered_model(): return _get_permission_from_registered_model_name().can_update -def _validate_can_delete_registered_model(): +def validate_can_delete_registered_model(): return _get_permission_from_registered_model_name().can_delete -def _get_before_request_handler(request_class): +def validate_can_manage_registered_model(): + return _get_permission_from_registered_model_name().can_manage + + +def get_before_request_handler(request_class): return BEFORE_REQUEST_HANDLERS.get(request_class) BEFORE_REQUEST_HANDLERS = { # Routes for experiments - CreateExperiment: _validate_can_manage_experiment, - GetExperiment: _validate_can_read_experiment, - GetExperimentByName: _validate_can_read_experiment_by_name, - DeleteExperiment: _validate_can_delete_experiment, - RestoreExperiment: _validate_can_delete_experiment, - UpdateExperiment: _validate_can_update_experiment, - SetExperimentTag: _validate_can_update_experiment, - # Routes for model registry - GetRegisteredModel: _validate_can_read_registered_model, - DeleteRegisteredModel: _validate_can_delete_registered_model, - UpdateRegisteredModel: _validate_can_update_registered_model, - RenameRegisteredModel: _validate_can_update_registered_model, - GetLatestVersions: _validate_can_read_registered_model, - CreateModelVersion: _validate_can_update_registered_model, - GetModelVersion: _validate_can_read_registered_model, - DeleteModelVersion: _validate_can_delete_registered_model, - UpdateModelVersion: _validate_can_update_registered_model, - TransitionModelVersionStage: _validate_can_update_registered_model, - GetModelVersionDownloadUri: _validate_can_read_registered_model, - SetRegisteredModelTag: _validate_can_update_registered_model, - DeleteRegisteredModelTag: _validate_can_update_registered_model, - SetModelVersionTag: _validate_can_update_registered_model, - DeleteModelVersionTag: _validate_can_delete_registered_model, - SetRegisteredModelAlias: _validate_can_update_registered_model, - DeleteRegisteredModelAlias: _validate_can_delete_registered_model, - GetModelVersionByAlias: _validate_can_read_registered_model, + ## CreateExperiment: _validate_can_manage_experiment, + GetExperiment: validate_can_read_experiment, + GetExperimentByName: validate_can_read_experiment_by_name, + DeleteExperiment: validate_can_delete_experiment, + RestoreExperiment: validate_can_delete_experiment, + UpdateExperiment: validate_can_update_experiment, + SetExperimentTag: validate_can_update_experiment, + # # Routes for runs + CreateRun: validate_can_update_experiment, + GetRun: validate_can_read_run, + DeleteRun: validate_can_delete_run, + RestoreRun: validate_can_delete_run, + UpdateRun: validate_can_update_run, + LogMetric: validate_can_update_run, + LogBatch: validate_can_update_run, + LogModel: validate_can_update_run, + SetTag: validate_can_update_run, + DeleteTag: validate_can_update_run, + LogParam: validate_can_update_run, + GetMetricHistory: validate_can_read_run, + ListArtifacts: validate_can_read_run, + # # Routes for model registry + GetRegisteredModel: validate_can_read_registered_model, + DeleteRegisteredModel: validate_can_delete_registered_model, + UpdateRegisteredModel: validate_can_update_registered_model, + RenameRegisteredModel: validate_can_update_registered_model, + GetLatestVersions: validate_can_read_registered_model, + CreateModelVersion: validate_can_update_registered_model, + GetModelVersion: validate_can_read_registered_model, + DeleteModelVersion: validate_can_delete_registered_model, + UpdateModelVersion: validate_can_update_registered_model, + TransitionModelVersionStage: validate_can_update_registered_model, + GetModelVersionDownloadUri: validate_can_read_registered_model, + SetRegisteredModelTag: validate_can_update_registered_model, + DeleteRegisteredModelTag: validate_can_update_registered_model, + SetModelVersionTag: validate_can_update_registered_model, + DeleteModelVersionTag: validate_can_delete_registered_model, + SetRegisteredModelAlias: validate_can_update_registered_model, + DeleteRegisteredModelAlias: validate_can_delete_registered_model, + GetModelVersionByAlias: validate_can_read_registered_model, } BEFORE_REQUEST_VALIDATORS = { (http_path, method): handler - for http_path, handler, methods in get_endpoints(_get_before_request_handler) + for http_path, handler, methods in get_endpoints(get_before_request_handler) for method in methods } BEFORE_REQUEST_VALIDATORS.update( { - (routes.CREATE_EXPERIMENT_PERMISSION, "GET"): _validate_can_manage_experiment, - (routes.GET_EXPERIMENT_PERMISSION, "GET"): _validate_can_manage_experiment, - (routes.CREATE_EXPERIMENT_PERMISSION, "POST"): _validate_can_manage_experiment, - (routes.UPDATE_EXPERIMENT_PERMISSION, "PATCH"): _validate_can_manage_experiment, - (routes.DELETE_EXPERIMENT_PERMISSION, "DELETE"): _validate_can_manage_experiment, - (routes.GET_REGISTERED_MODEL_PERMISSION, "GET"): _validate_can_manage_registered_model, - (routes.CREATE_REGISTERED_MODEL_PERMISSION, "POST"): _validate_can_manage_registered_model, - (routes.UPDATE_REGISTERED_MODEL_PERMISSION, "PATCH"): _validate_can_manage_registered_model, - (routes.DELETE_REGISTERED_MODEL_PERMISSION, "DELETE"): _validate_can_manage_registered_model, + (routes.GET_EXPERIMENT_PERMISSION, "GET"): validate_can_manage_experiment, + (routes.CREATE_EXPERIMENT_PERMISSION, "GET"): validate_can_manage_experiment, + (routes.CREATE_EXPERIMENT_PERMISSION, "POST"): validate_can_manage_experiment, + (routes.UPDATE_EXPERIMENT_PERMISSION, "PATCH"): validate_can_manage_experiment, + (routes.DELETE_EXPERIMENT_PERMISSION, "DELETE"): validate_can_manage_experiment, + (routes.GET_REGISTERED_MODEL_PERMISSION, "GET"): validate_can_manage_registered_model, + (routes.CREATE_REGISTERED_MODEL_PERMISSION, "POST"): validate_can_manage_registered_model, + (routes.UPDATE_REGISTERED_MODEL_PERMISSION, "PATCH"): validate_can_manage_registered_model, + (routes.DELETE_REGISTERED_MODEL_PERMISSION, "DELETE"): validate_can_manage_registered_model, + # (SIGNUP, "GET"): validate_can_create_user, + # (GET_USER, "GET"): validate_can_read_user, + # (CREATE_USER, "POST"): validate_can_create_user, + # (UPDATE_USER_PASSWORD, "PATCH"): validate_can_update_user_password, + # (UPDATE_USER_ADMIN, "PATCH"): validate_can_update_user_admin, + # (DELETE_USER, "DELETE"): validate_can_delete_user, } ) AFTER_REQUEST_PATH_HANDLERS = { - CreateExperiment: _set_can_manage_experiment_permission, - CreateRegisteredModel: _set_can_manage_registered_model_permission, - DeleteRegisteredModel: delete_can_manage_registered_model_permission, + CreateExperiment: set_can_manage_experiment_permission, + CreateRegisteredModel: set_can_manage_registered_model_permission, + # ???? DeleteRegisteredModel: delete_can_manage_registered_model_permission, + SearchExperiments: filter_search_experiments, + SearchRegisteredModels: filter_search_registered_models, } @@ -399,7 +592,7 @@ def make_basic_auth_response() -> Response: @catch_mlflow_exception def create_experiment_permission(): store.create_experiment_permission( - _get_experiment_id(), + _get_experiment_id(request.get_json()), _get_request_param("user_name"), _get_request_param("permission"), ) @@ -473,7 +666,6 @@ def callback(): }, ) group_data = group_response.json() - print(AppConfig.get_property("OIDC_GROUP_NAME")) if not any(group["displayName"] == AppConfig.get_property("OIDC_GROUP_NAME") for group in group_data["value"]): return "User not in group", 401 # set is_admin if user is in admin group @@ -482,7 +674,6 @@ def callback(): else: _set_is_admin(False) elif AppConfig.get_property("OIDC_PROVIDER_TYPE") == "oidc": - print(user_data.get("groups", [])) if AppConfig.get_property("OIDC_GROUP_NAME") not in user_data.get("groups", []): return "User not in group", 401 # set is_admin if user is in admin group @@ -508,7 +699,6 @@ def oidc_static(filename): def oidc_ui(filename=None): # Specify the directory where your static files are located ui_directory = os.path.join(os.path.dirname(__file__), "ui") - print(filename) if not filename: filename = "index.html" elif not os.path.exists(os.path.join(ui_directory, filename)): @@ -547,7 +737,7 @@ def get_current_user(): user_json = user.to_json() user_json["experiment_permissions"] = [ { - "name": mlflow_client.get_experiment(permission.experiment_id).name, + "name": _get_tracking_store().get_experiment(permission.experiment_id).name, "id": permission.experiment_id, "permission": permission.permission, } @@ -598,7 +788,7 @@ def get_users(): @catch_mlflow_exception def get_experiments(): - list_experiments = mlflow_client.search_experiments() + list_experiments = _get_tracking_store().search_experiments() experiments = [ { "name": experiment.name, @@ -612,7 +802,8 @@ def get_experiments(): @catch_mlflow_exception def get_models(): - registered_models = mlflow_client.search_registered_models() + # TODO: Implement pagination + registered_models = _get_model_registry_store().search_registered_models(max_results=1000) models = [ { "name": model.name, @@ -631,7 +822,7 @@ def get_user_experiments(username): list_experiments = store.list_experiment_permissions(username) experiments_list = [] for experiments in list_experiments: - experiment = mlflow_client.get_experiment(experiments.experiment_id) + experiment = _get_tracking_store().get_experiment(experiments.experiment_id) experiments_list.append( { "name": experiment.name, @@ -654,38 +845,32 @@ def get_user_models(username): @catch_mlflow_exception -def get_experiment_users(experiment_id): - # Convert experiment_id to string for comparison +def get_experiment_users(experiment_id: str): experiment_id = str(experiment_id) - # Get the list of all users list_users = store.list_users() - # Filter users who are associated with the given experiment - usernames = [] + users = [] for user in list_users: # Check if the user is associated with the experiment - user_experiments = [str(exp.experiment_id) for exp in user.experiment_permissions] - if experiment_id in user_experiments: - usernames.append(user.username) - - return jsonify({"usernames": usernames}) + user_experiments_details = {str(exp.experiment_id): exp.permission for exp in user.experiment_permissions} + if experiment_id in user_experiments_details: + users.append({"username": user.username, "permission": user_experiments_details[experiment_id]}) + return jsonify(users) @catch_mlflow_exception def get_model_users(model_name): # Get the list of all users list_users = store.list_users() - # Filter users who are associated with the given model - usernames = [] + users = [] for user in list_users: # Check if the user is associated with the model - user_models = [model.name for model in user.registered_model_permissions] + user_models = {model.name: model.permission for model in user.registered_model_permissions} if model_name in user_models: - usernames.append(user.username) - - return jsonify({"usernames": usernames}) + users.append({"username": user.username, "permission": user_models[model_name]}) + return jsonify(users) def _password_generation(): @@ -697,9 +882,9 @@ def _password_generation(): @catch_mlflow_exception def update_experiment_permission(): store.update_experiment_permission( - _get_experiment_id(), + _get_experiment_id(request.get_json()), _get_request_param("user_name"), - _get_request_param("new_permission"), + _get_request_param("permission"), ) return jsonify({"message": "Experiment permission has been changed."}) @@ -707,7 +892,7 @@ def update_experiment_permission(): @catch_mlflow_exception def delete_experiment_permission(): store.delete_experiment_permission( - _get_experiment_id(), + _get_experiment_id(request.get_json()), _get_request_param("user_name"), ) return jsonify({"message": "Experiment permission has been deleted."}) @@ -745,19 +930,3 @@ def delete_registered_model_permission(): username = _get_request_param("user_name") store.delete_registered_model_permission(name, username) return make_response("Model permission has been deleted") - - -def set_can_manage_experiment_permission(resp: Response): - response_message = CreateExperiment.Response() - parse_dict(resp.json, response_message) - experiment_id = response_message.experiment_id - username = _get_request_param("user_name") - store.create_experiment_permission(experiment_id, username, MANAGE.name) - - -def set_can_manage_registered_model_permission(resp: Response): - response_message = CreateRegisteredModel.Response() - parse_dict(resp.json, response_message) - name = response_message.registered_model.name - username = _get_request_param("user_name") - store.create_registered_model_permission(name, username, MANAGE.name) diff --git a/web-ui/src/app/features/admin-page/components/details/experiment-permission-details/experiment-permission-details.component.ts b/web-ui/src/app/features/admin-page/components/details/experiment-permission-details/experiment-permission-details.component.ts index 33c626c..92ae6b6 100644 --- a/web-ui/src/app/features/admin-page/components/details/experiment-permission-details/experiment-permission-details.component.ts +++ b/web-ui/src/app/features/admin-page/components/details/experiment-permission-details/experiment-permission-details.component.ts @@ -89,7 +89,7 @@ export class ExperimentPermissionDetailsComponent implements OnInit { filter(Boolean), switchMap(({ user, permission }) => this.permissionDataService.createExperimentPermission({ experiment_id: this.experimentId, - new_permission: permission, + permission: permission, user_name: user, })), switchMap(() => this.loadUsersForExperiment(this.experimentId)), diff --git a/web-ui/src/app/features/admin-page/components/details/model-permission-details/model-permission-details.component.ts b/web-ui/src/app/features/admin-page/components/details/model-permission-details/model-permission-details.component.ts index 932cbe8..986f63c 100644 --- a/web-ui/src/app/features/admin-page/components/details/model-permission-details/model-permission-details.component.ts +++ b/web-ui/src/app/features/admin-page/components/details/model-permission-details/model-permission-details.component.ts @@ -43,7 +43,7 @@ export class ModelPermissionDetailsComponent implements OnInit { } revokePermissionForUser(item: any) { - this.permissionDataService.deleteModelPermission({ model_name: this.modelId, user_name: item.username }) + this.permissionDataService.deleteModelPermission({ name: this.modelId, user_name: item.username }) .pipe( tap(() => this.snackService.openSnackBar('Permission revoked successfully')), switchMap(() => this.loadUsersForModel(this.modelId)), @@ -86,8 +86,8 @@ export class ModelPermissionDetailsComponent implements OnInit { .afterClosed()), filter(Boolean), switchMap(({ user, permission }) => this.permissionDataService.createModelPermission({ - model_name: this.modelId, - new_permission: permission, + name: this.modelId, + permission: permission, user_name: user, })), switchMap(() => this.loadUsersForModel(this.modelId)), diff --git a/web-ui/src/app/features/admin-page/components/details/user-permission-details/user-permission-details.component.ts b/web-ui/src/app/features/admin-page/components/details/user-permission-details/user-permission-details.component.ts index 5ef8557..b8674b9 100644 --- a/web-ui/src/app/features/admin-page/components/details/user-permission-details/user-permission-details.component.ts +++ b/web-ui/src/app/features/admin-page/components/details/user-permission-details/user-permission-details.component.ts @@ -70,8 +70,8 @@ export class UserPermissionDetailsComponent implements OnInit { filter(Boolean), switchMap(({ entity, permission }) => this.permissionDataService.createModelPermission({ user_name: this.userId, - model_name: entity, - new_permission: permission, + name: entity, + permission: permission, })), tap(() => this.snackBarService.openSnackBar('Permission granted successfully')), switchMap(() => this.modelDataService.getModelsForUser(this.userId)), @@ -95,7 +95,7 @@ export class UserPermissionDetailsComponent implements OnInit { return this.permissionDataService.createExperimentPermission({ user_name: this.userId, experiment_name: entity, - new_permission: permission, + permission: permission, }) }), tap(() => this.snackBarService.openSnackBar('Permission granted successfully')), @@ -127,7 +127,7 @@ export class UserPermissionDetailsComponent implements OnInit { } revokeModelPermissionForUser({name}: any) { - this.permissionDataService.deleteModelPermission({model_name: name, user_name: this.userId}) + this.permissionDataService.deleteModelPermission({name: name, user_name: this.userId}) .pipe( tap(() => this.snackBarService.openSnackBar('Permission revoked successfully')), switchMap(() => this.modelDataService.getModelsForUser(this.userId)), diff --git a/web-ui/src/app/shared/interfaces/permission-data.interface.ts b/web-ui/src/app/shared/interfaces/permission-data.interface.ts index 9a331f3..a0d1791 100644 --- a/web-ui/src/app/shared/interfaces/permission-data.interface.ts +++ b/web-ui/src/app/shared/interfaces/permission-data.interface.ts @@ -2,11 +2,11 @@ export interface CreateExperimentPermissionRequestBodyModel { experiment_name?: string; experiment_id?: string; user_name: string; - new_permission: string; + permission: string; } export interface CreateModelPermissionRequestBodyModel { - model_name: string; + name: string; user_name: string; - new_permission: string; + permission: string; } diff --git a/web-ui/src/app/shared/services/data/permission-data.service.ts b/web-ui/src/app/shared/services/data/permission-data.service.ts index 9aac38c..ae4d7d4 100644 --- a/web-ui/src/app/shared/services/data/permission-data.service.ts +++ b/web-ui/src/app/shared/services/data/permission-data.service.ts @@ -21,7 +21,7 @@ export class PermissionDataService { return this.http.post(API_URL.CREATE_EXPERIMENT_PERMISSION, body, { responseType: 'text' }); } - updateExperimentPermission(body: { user_name: string, experiment_id: string, new_permission: string }) { + updateExperimentPermission(body: { user_name: string, experiment_id: string, permission: string }) { return this.http.post(API_URL.UPDATE_EXPERIMENT_PERMISSION, body, { responseType: 'text' }); } @@ -33,11 +33,11 @@ export class PermissionDataService { return this.http.post(API_URL.CREATE_MODEL_PERMISSION, body); } - updateModelPermission(body: { user_name: string, model_name: string, new_permission: string }) { + updateModelPermission(body: { user_name: string, name: string, permission: string }) { return this.http.post(API_URL.UPDATE_MODEL_PERMISSION, body, { responseType: 'text' }); } - deleteModelPermission(body: { model_name: string, user_name: string }) { + deleteModelPermission(body: { name: string, user_name: string }) { return this.http.post(API_URL.DELETE_MODEL_PERMISSION, body, { responseType: 'text' }); } diff --git a/web-ui/src/app/shared/services/permission-modal.service.ts b/web-ui/src/app/shared/services/permission-modal.service.ts index 42fc6c1..b6eeb26 100644 --- a/web-ui/src/app/shared/services/permission-modal.service.ts +++ b/web-ui/src/app/shared/services/permission-modal.service.ts @@ -35,8 +35,8 @@ export class PermissionModalService { .pipe( filter(Boolean), switchMap(({ permission }) => this.permissionDataService.updateModelPermission({ - model_name: modelName, - new_permission: permission, + name: modelName, + permission: permission, user_name: userName, })), ) @@ -57,7 +57,7 @@ export class PermissionModalService { filter(Boolean), switchMap(({ permission }) => this.permissionDataService.updateExperimentPermission({ experiment_id: experimentName, - new_permission: permission, + permission: permission, user_name: userName, })), ) From 4b75e0f12c693395ca42436bc302cf1a1746ca91 Mon Sep 17 00:00:00 2001 From: Alexander Kharkevich Date: Sun, 14 Apr 2024 15:32:01 -0400 Subject: [PATCH 09/13] fix: add after request filter --- mlflow_oidc_auth/app.py | 1 + mlflow_oidc_auth/views.py | 34 ++++++++++++++++++++-- web-ui/src/app/core/configs/permissions.ts | 5 ++++ 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/mlflow_oidc_auth/app.py b/mlflow_oidc_auth/app.py index df409ee..5c28956 100644 --- a/mlflow_oidc_auth/app.py +++ b/mlflow_oidc_auth/app.py @@ -56,6 +56,7 @@ # Add new hooks app.before_request(views.before_request_hook) +app.after_request(views.after_request_hook) # Set up session Session(app) diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index c33cba7..b5aac5d 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -4,8 +4,7 @@ import requests import secrets import string -from typing import Callable, Union - +from typing import Any, Callable, Dict, Optional, Union from mlflow.store.entities import PagedList from mlflow.entities.model_registry import RegisteredModel from mlflow.entities import Experiment @@ -87,7 +86,7 @@ get_endpoints, ) -from mlflow.tracking import MlflowClient +from mlflow.utils.rest_utils import _REST_API_PATH_PREFIX from oauthlib.oauth2 import WebApplicationClient from mlflow_oidc_auth import routes @@ -552,6 +551,31 @@ def get_after_request_handler(request_class): } +@catch_mlflow_exception +def after_request_hook(resp: Response): + if 400 <= resp.status_code < 600: + return resp + + if handler := AFTER_REQUEST_HANDLERS.get((request.path, request.method)): + handler(resp) + return resp + + +def _is_proxy_artifact_path(path: str) -> bool: + return path.startswith(f"{_REST_API_PATH_PREFIX}/mlflow-artifacts/artifacts/") + + +def _get_proxy_artifact_validator(method: str, view_args: Optional[Dict[str, Any]]) -> Optional[Callable[[], bool]]: + if view_args is None: + return validate_can_read_experiment_artifact_proxy # List + + return { + "GET": validate_can_read_experiment_artifact_proxy, # Download + "PUT": validate_can_update_experiment_artifact_proxy, # Upload + "DELETE": validate_can_delete_experiment_artifact_proxy, # Delete + }.get(method) + + def before_request_hook(): """Called before each request. If it did not return a response, the view function for the matched route is called and returns a response""" @@ -572,6 +596,10 @@ def before_request_hook(): if validator := BEFORE_REQUEST_VALIDATORS.get((request.path, request.method)): if not validator(): return make_forbidden_response() + elif _is_proxy_artifact_path(request.path): + if validator := _get_proxy_artifact_validator(request.method, request.view_args): + if not validator(): + return make_forbidden_response() def make_forbidden_response() -> Response: diff --git a/web-ui/src/app/core/configs/permissions.ts b/web-ui/src/app/core/configs/permissions.ts index 284188c..72c409b 100644 --- a/web-ui/src/app/core/configs/permissions.ts +++ b/web-ui/src/app/core/configs/permissions.ts @@ -2,6 +2,7 @@ export enum PermissionEnum { EDIT = 'EDIT', READ = 'READ', MANAGE = 'MANAGE', + NO_PERMISSIONS = 'NO_PERMISSIONS' } export const PERMISSIONS = [ @@ -16,5 +17,9 @@ export const PERMISSIONS = [ { value: PermissionEnum.MANAGE, title: 'Manage' + }, + { + value: PermissionEnum.NO_PERMISSIONS, + title: 'No permissions' } ] From e6ad6f52abfa695eeea11e81ebc3c1c2d0a2f4d1 Mon Sep 17 00:00:00 2001 From: Alexander Kharkevich Date: Sun, 14 Apr 2024 23:21:05 -0400 Subject: [PATCH 10/13] fix: fix _get_experiment_id --- mlflow_oidc_auth/views.py | 38 +++++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index b5aac5d..b67a1bb 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -102,11 +102,20 @@ _logger = logging.getLogger(__name__) -def _get_experiment_id(request_data: dict) -> str: - if "experiment_id" in request_data: - return request_data.get("experiment_id") - elif "experiment_name" in request_data: - return _get_tracking_store().get_experiment_by_name(request_data.get("experiment_name")).experiment_id +def _get_experiment_id() -> str: + if request.method == "GET": + args = request.args + elif request.method in ("POST", "PATCH", "DELETE"): + args = request.json + else: + raise MlflowException( + f"Unsupported HTTP method '{request.method}'", + BAD_REQUEST, + ) + if "experiment_id" in args: + return args["experiment_id"] + elif "experiment_name" in args: + return _get_tracking_store().get_experiment_by_name(args["experiment_name"]).experiment_id raise MlflowException( "Either 'experiment_id' or 'experiment_name' must be provided in the request data.", INVALID_PARAMETER_VALUE, @@ -221,7 +230,7 @@ def _get_is_admin(): def _get_permission_from_experiment_id() -> Permission: - experiment_id = _get_experiment_id(request.get_json()) + experiment_id = _get_experiment_id() username = _get_username() return _get_permission_from_store_or_default(lambda: store.get_experiment_permission(experiment_id, username).permission) @@ -620,7 +629,7 @@ def make_basic_auth_response() -> Response: @catch_mlflow_exception def create_experiment_permission(): store.create_experiment_permission( - _get_experiment_id(request.get_json()), + _get_experiment_id(), _get_request_param("user_name"), _get_request_param("permission"), ) @@ -771,6 +780,17 @@ def get_current_user(): } for permission in user.experiment_permissions ] + if not _get_is_admin(): + user_json["experiment_permissions"] = [ + permission + for permission in user_json["experiment_permissions"] + if permission["permission"] != "NO_PERMISSIONS" + ] + user_json["registered_model_permissions"] = [ + registered_model_permission + for registered_model_permission in user_json["registered_model_permissions"] + if registered_model_permission["permission"] != "NO_PERMISSIONS" + ] return jsonify(user_json) @@ -910,7 +930,7 @@ def _password_generation(): @catch_mlflow_exception def update_experiment_permission(): store.update_experiment_permission( - _get_experiment_id(request.get_json()), + _get_experiment_id(), _get_request_param("user_name"), _get_request_param("permission"), ) @@ -920,7 +940,7 @@ def update_experiment_permission(): @catch_mlflow_exception def delete_experiment_permission(): store.delete_experiment_permission( - _get_experiment_id(request.get_json()), + _get_experiment_id(), _get_request_param("user_name"), ) return jsonify({"message": "Experiment permission has been deleted."}) From 88fce2c70af2d28828110f9beab0edfa490c185d Mon Sep 17 00:00:00 2001 From: Alexander Kharkevich Date: Mon, 15 Apr 2024 00:14:07 -0400 Subject: [PATCH 11/13] fix: fix requests method --- mlflow_oidc_auth/app.py | 16 ++++++++-------- mlflow_oidc_auth/routes.py | 2 +- web-ui/src/app/core/configs/api-urls.ts | 4 ++-- .../services/data/permission-data.service.ts | 10 +++------- 4 files changed, 14 insertions(+), 18 deletions(-) diff --git a/mlflow_oidc_auth/app.py b/mlflow_oidc_auth/app.py index 5c28956..577caed 100644 --- a/mlflow_oidc_auth/app.py +++ b/mlflow_oidc_auth/app.py @@ -25,7 +25,7 @@ app.add_url_rule(rule=routes.UI_ROOT, methods=["GET"], view_func=views.oidc_ui) # User token -app.add_url_rule(rule=routes.CREATE_ACCESS_TOKEN, methods=["GET"], view_func=views.create_access_token) +app.add_url_rule(rule=routes.GET_ACCESS_TOKEN, methods=["GET"], view_func=views.create_access_token) app.add_url_rule(rule=routes.GET_CURRENT_USER, methods=["GET"], view_func=views.get_current_user) # UI routes support @@ -40,19 +40,19 @@ # User management app.add_url_rule(rule=routes.CREATE_USER, methods=["POST"], view_func=views.create_user) app.add_url_rule(rule=routes.GET_USER, methods=["GET"], view_func=views.get_user) -app.add_url_rule(rule=routes.UPDATE_USER_PASSWORD, methods=["GET"], view_func=views.update_username_password) -app.add_url_rule(rule=routes.UPDATE_USER_ADMIN, methods=["GET"], view_func=views.update_user_admin) -app.add_url_rule(rule=routes.DELETE_USER, methods=["GET"], view_func=views.delete_user) +app.add_url_rule(rule=routes.UPDATE_USER_PASSWORD, methods=["PATCH"], view_func=views.update_username_password) +app.add_url_rule(rule=routes.UPDATE_USER_ADMIN, methods=["PATCH"], view_func=views.update_user_admin) +app.add_url_rule(rule=routes.DELETE_USER, methods=["DELETE"], view_func=views.delete_user) # permission management app.add_url_rule(rule=routes.CREATE_EXPERIMENT_PERMISSION, methods=["POST"], view_func=views.create_experiment_permission) app.add_url_rule(rule=routes.GET_EXPERIMENT_PERMISSION, methods=["GET"], view_func=views.get_experiment_permission) -app.add_url_rule(rule=routes.UPDATE_EXPERIMENT_PERMISSION, methods=["POST"], view_func=views.update_experiment_permission) -app.add_url_rule(rule=routes.DELETE_EXPERIMENT_PERMISSION, methods=["POST"], view_func=views.delete_experiment_permission) +app.add_url_rule(rule=routes.UPDATE_EXPERIMENT_PERMISSION, methods=["PATCH"], view_func=views.update_experiment_permission) +app.add_url_rule(rule=routes.DELETE_EXPERIMENT_PERMISSION, methods=["DELETE"], view_func=views.delete_experiment_permission) app.add_url_rule(rule=routes.CREATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.create_registered_model_permission) app.add_url_rule(rule=routes.GET_REGISTERED_MODEL_PERMISSION, methods=["GET"], view_func=views.get_registered_model_permission) -app.add_url_rule(rule=routes.UPDATE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.update_registered_model_permission) -app.add_url_rule(rule=routes.DELETE_REGISTERED_MODEL_PERMISSION, methods=["POST"], view_func=views.delete_registered_model_permission) +app.add_url_rule(rule=routes.UPDATE_REGISTERED_MODEL_PERMISSION, methods=["PATCH"], view_func=views.update_registered_model_permission) +app.add_url_rule(rule=routes.DELETE_REGISTERED_MODEL_PERMISSION, methods=["DELETE"], view_func=views.delete_registered_model_permission) # Add new hooks app.before_request(views.before_request_hook) diff --git a/mlflow_oidc_auth/routes.py b/mlflow_oidc_auth/routes.py index 6878802..e57fa11 100644 --- a/mlflow_oidc_auth/routes.py +++ b/mlflow_oidc_auth/routes.py @@ -10,7 +10,7 @@ UI_ROOT = "/oidc/ui/" # create access token for current user -CREATE_ACCESS_TOKEN = _get_rest_path("/mlflow/users/access-token") +GET_ACCESS_TOKEN = _get_rest_path("/mlflow/users/access-token") # get infrmation about current user GET_CURRENT_USER = _get_rest_path("/mlflow/users/current") # list of experiments, models, users diff --git a/web-ui/src/app/core/configs/api-urls.ts b/web-ui/src/app/core/configs/api-urls.ts index 3095378..34f9c33 100644 --- a/web-ui/src/app/core/configs/api-urls.ts +++ b/web-ui/src/app/core/configs/api-urls.ts @@ -6,10 +6,10 @@ export const API_URL = { MODELS_FOR_USER: '/api/2.0/mlflow/users/${userName}/registered-models', USERS_FOR_MODEL: '/api/2.0/mlflow/registered-models/${modelName}/users', CREATE_EXPERIMENT_PERMISSION: '/api/2.0/mlflow/experiments/permissions/create', - CREATE_MODEL_PERMISSION: '/api/2.0/mlflow/registered-models/permissions/create', UPDATE_EXPERIMENT_PERMISSION: '/api/2.0/mlflow/experiments/permissions/update', - UPDATE_MODEL_PERMISSION: '/api/2.0/mlflow/registered-models/permissions/update', DELETE_EXPERIMENT_PERMISSION: '/api/2.0/mlflow/experiments/permissions/delete', + CREATE_MODEL_PERMISSION: '/api/2.0/mlflow/registered-models/permissions/create', + UPDATE_MODEL_PERMISSION: '/api/2.0/mlflow/registered-models/permissions/update', DELETE_MODEL_PERMISSION: '/api/2.0/mlflow/registered-models/permissions/delete', GET_ALL_USERS: '/api/2.0/mlflow/users', diff --git a/web-ui/src/app/shared/services/data/permission-data.service.ts b/web-ui/src/app/shared/services/data/permission-data.service.ts index ae4d7d4..0a07ab5 100644 --- a/web-ui/src/app/shared/services/data/permission-data.service.ts +++ b/web-ui/src/app/shared/services/data/permission-data.service.ts @@ -21,12 +21,8 @@ export class PermissionDataService { return this.http.post(API_URL.CREATE_EXPERIMENT_PERMISSION, body, { responseType: 'text' }); } - updateExperimentPermission(body: { user_name: string, experiment_id: string, permission: string }) { - return this.http.post(API_URL.UPDATE_EXPERIMENT_PERMISSION, body, { responseType: 'text' }); - } - deleteExperimentPermission(body: { experiment_id: string, user_name: string }) { - return this.http.post(API_URL.DELETE_EXPERIMENT_PERMISSION, body, { responseType: 'text' }); + return this.http.delete(API_URL.DELETE_EXPERIMENT_PERMISSION, { body }); } createModelPermission(body: CreateModelPermissionRequestBodyModel) { @@ -34,11 +30,11 @@ export class PermissionDataService { } updateModelPermission(body: { user_name: string, name: string, permission: string }) { - return this.http.post(API_URL.UPDATE_MODEL_PERMISSION, body, { responseType: 'text' }); + return this.http.patch(API_URL.UPDATE_MODEL_PERMISSION, body, { responseType: 'text' }); } deleteModelPermission(body: { name: string, user_name: string }) { - return this.http.post(API_URL.DELETE_MODEL_PERMISSION, body, { responseType: 'text' }); + return this.http.delete(API_URL.DELETE_MODEL_PERMISSION, { body }); } } From 327b201349dab8f4de21b783277cdc8c15b0470b Mon Sep 17 00:00:00 2001 From: Alexander Kharkevich Date: Mon, 15 Apr 2024 00:33:34 -0400 Subject: [PATCH 12/13] chore: update output --- mlflow_oidc_auth/views.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index b67a1bb..8b71b3c 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -612,7 +612,7 @@ def before_request_hook(): def make_forbidden_response() -> Response: - res = make_response("Permission denied") + res = make_response(jsonify({"message": "Permission denied"})) res.status_code = 403 return res @@ -782,9 +782,7 @@ def get_current_user(): ] if not _get_is_admin(): user_json["experiment_permissions"] = [ - permission - for permission in user_json["experiment_permissions"] - if permission["permission"] != "NO_PERMISSIONS" + permission for permission in user_json["experiment_permissions"] if permission["permission"] != "NO_PERMISSIONS" ] user_json["registered_model_permissions"] = [ registered_model_permission @@ -969,7 +967,7 @@ def update_registered_model_permission(): username = _get_request_param("user_name") permission = _get_request_param("permission") store.update_registered_model_permission(name, username, permission) - return make_response("Model permission has been changed") + return make_response(jsonify({"message" : "Model permission has been changed"})) @catch_mlflow_exception @@ -977,4 +975,4 @@ def delete_registered_model_permission(): name = _get_request_param("name") username = _get_request_param("user_name") store.delete_registered_model_permission(name, username) - return make_response("Model permission has been deleted") + return make_response(jsonify({"message" : "Model permission has been deleted"})) From 1f1245b66cd40c12a22511d157f9f31d977e19f0 Mon Sep 17 00:00:00 2001 From: Alexander Kharkevich Date: Tue, 16 Apr 2024 00:24:19 -0400 Subject: [PATCH 13/13] feat: finalize validators --- mlflow_oidc_auth/templates/_header.html | 4 - mlflow_oidc_auth/views.py | 127 +++++++++--------- pyproject.toml | 2 +- web-ui/src/app/app-routing.module.ts | 2 +- .../components/header/header.component.html | 2 +- .../services/data/permission-data.service.ts | 4 + 6 files changed, 72 insertions(+), 69 deletions(-) delete mode 100644 mlflow_oidc_auth/templates/_header.html diff --git a/mlflow_oidc_auth/templates/_header.html b/mlflow_oidc_auth/templates/_header.html deleted file mode 100644 index 2230f12..0000000 --- a/mlflow_oidc_auth/templates/_header.html +++ /dev/null @@ -1,4 +0,0 @@ - diff --git a/mlflow_oidc_auth/views.py b/mlflow_oidc_auth/views.py index 8b71b3c..064478c 100644 --- a/mlflow_oidc_auth/views.py +++ b/mlflow_oidc_auth/views.py @@ -77,7 +77,7 @@ UpdateExperiment, UpdateRun, ) -from mlflow_oidc_auth.permissions import Permission, get_permission, MANAGE +from mlflow_oidc_auth.permissions import Permission, get_permission, MANAGE, NO_PERMISSIONS from mlflow.server.handlers import ( _get_model_registry_store, _get_request_message, @@ -171,24 +171,6 @@ def _get_permission_from_store_or_default(store_permission_func: Callable[[], st return get_permission(perm) -# def authenticate_request() -> Union[Authorization, Response]: -# """Use configured authorization function to get request authorization.""" -# auth_func = get_auth_func(auth_config.authorization_function) -# return auth_func() - -# @functools.lru_cache(maxsize=None) -# def get_auth_func(authorization_function: str) -> Callable[[], Union[Authorization, Response]]: -# """ -# Import and return the specified authorization function. - -# Args: -# authorization_function: A string of the form "module.submodule:auth_func" -# """ -# mod_name, fn_name = authorization_function.split(":", 1) -# module = importlib.import_module(mod_name) -# return getattr(module, fn_name) - - def authenticate_request_basic_auth() -> Union[Authorization, Response]: username = request.authorization.username password = request.authorization.password @@ -211,22 +193,14 @@ def _set_username(username): return -def _get_display_name(): - return session.get("display_name") - - -def _set_display_name(display_name): - session["display_name"] = display_name - return - - -def _set_is_admin(_is_admin: bool = False): - session["is_admin"] = _is_admin - return - +def _get_is_admin() -> bool: + return bool(store.get_user(_get_username()).is_admin) -def _get_is_admin(): - return session.get("is_admin") +def username_is_sender(): + """Validate if the request username is the sender""" + username = _get_request_param("username") + sender = _get_username() + return username == sender def _get_permission_from_experiment_id() -> Permission: @@ -466,6 +440,29 @@ def validate_can_manage_registered_model(): return _get_permission_from_registered_model_name().can_manage +def validate_can_read_user(): + return username_is_sender() + + +def validate_can_create_user(): + # only admins can create user, but admins won't reach this validator + return False + + +def validate_can_update_user_password(): + return username_is_sender() + + +def validate_can_update_user_admin(): + # only admins can update, but admins won't reach this validator + return False + + +def validate_can_delete_user(): + # only admins can delete, but admins won't reach this validator + return False + + def get_before_request_handler(request_class): return BEFORE_REQUEST_HANDLERS.get(request_class) @@ -532,18 +529,18 @@ def get_before_request_handler(request_class): (routes.UPDATE_REGISTERED_MODEL_PERMISSION, "PATCH"): validate_can_manage_registered_model, (routes.DELETE_REGISTERED_MODEL_PERMISSION, "DELETE"): validate_can_manage_registered_model, # (SIGNUP, "GET"): validate_can_create_user, - # (GET_USER, "GET"): validate_can_read_user, - # (CREATE_USER, "POST"): validate_can_create_user, - # (UPDATE_USER_PASSWORD, "PATCH"): validate_can_update_user_password, - # (UPDATE_USER_ADMIN, "PATCH"): validate_can_update_user_admin, - # (DELETE_USER, "DELETE"): validate_can_delete_user, + # (routes.GET_USER, "GET"): validate_can_read_user, + (routes.CREATE_USER, "POST"): validate_can_create_user, + # (routes.UPDATE_USER_PASSWORD, "PATCH"): validate_can_update_user_password, + (routes.UPDATE_USER_ADMIN, "PATCH"): validate_can_update_user_admin, + (routes.DELETE_USER, "DELETE"): validate_can_delete_user, } ) AFTER_REQUEST_PATH_HANDLERS = { CreateExperiment: set_can_manage_experiment_permission, CreateRegisteredModel: set_can_manage_registered_model_permission, - # ???? DeleteRegisteredModel: delete_can_manage_registered_model_permission, + DeleteRegisteredModel: delete_can_manage_registered_model_permission, SearchExperiments: filter_search_experiments, SearchRegisteredModels: filter_search_registered_models, } @@ -601,6 +598,9 @@ def before_request_hook(): username=None, provide_display_name=AppConfig.get_property("OIDC_PROVIDER_DISPLAY_NAME"), ) + # admins don't need to be authorized + if _get_is_admin(): + return # authorization if validator := BEFORE_REQUEST_VALIDATORS.get((request.path, request.method)): if not validator(): @@ -688,8 +688,11 @@ def callback(): # Process the user data user_data = user_response.json() - email = user_data.get("email", "Unknown") - _set_display_name(user_data.get("name", "Unknown")) + email = user_data.get("email", None) + if email is None: + return "No email provided", 401 + display_name = user_data.get("name", "Unknown") + is_admin = False # check if user is in the group if AppConfig.get_property("OIDC_PROVIDER_TYPE") == "microsoft": @@ -703,26 +706,27 @@ def callback(): }, ) group_data = group_response.json() - if not any(group["displayName"] == AppConfig.get_property("OIDC_GROUP_NAME") for group in group_data["value"]): + if not any( + group["displayName"] == AppConfig.get_property("OIDC_GROUP_NAME") + or group["displayName"] == AppConfig.get_property("OIDC_ADMIN_GROUP_NAME") + for group in group_data["value"] + ): return "User not in group", 401 # set is_admin if user is in admin group if any(group["displayName"] == AppConfig.get_property("OIDC_ADMIN_GROUP_NAME") for group in group_data["value"]): - _set_is_admin(True) - else: - _set_is_admin(False) + is_admin = True elif AppConfig.get_property("OIDC_PROVIDER_TYPE") == "oidc": - if AppConfig.get_property("OIDC_GROUP_NAME") not in user_data.get("groups", []): + if (AppConfig.get_property("OIDC_GROUP_NAME") not in user_data.get("groups", [])) or ( + AppConfig.get_property("OIDC_ADMIN_GROUP_NAME") not in user_data.get("groups", []) + ): return "User not in group", 401 # set is_admin if user is in admin group if AppConfig.get_property("OIDC_ADMIN_GROUP_NAME") in user_data.get("groups", []): - _set_is_admin(True) - else: - _set_is_admin(False) + is_admin = True - # Store the user data in the session. - _set_username(email.lower()) # Create user due to auth - create_user() + create_user(username=email.lower(), display_name=display_name, is_admin=is_admin) + _set_username(email.lower()) return redirect(url_for("oidc_ui")) @@ -743,18 +747,17 @@ def oidc_ui(filename=None): return send_from_directory(ui_directory, filename) -def create_user(): +def create_user(username: str, display_name: str, is_admin: bool = False): try: - user = store.get_user(_get_username()) + user = store.get_user(username) + store.update_user(username, is_admin=is_admin) return ( jsonify({"message": f"User {user.username} (ID: {user.id}) already exists"}), 200, ) except MlflowException: password = _password_generation() - user = store.create_user( - username=_get_username(), password=password, display_name=_get_display_name(), is_admin=_get_is_admin() - ) + user = store.create_user(username=username, password=password, display_name=display_name, is_admin=is_admin) return ( jsonify({"message": f"User {user.username} (ID: {user.id}) successfully created"}), 201, @@ -782,12 +785,12 @@ def get_current_user(): ] if not _get_is_admin(): user_json["experiment_permissions"] = [ - permission for permission in user_json["experiment_permissions"] if permission["permission"] != "NO_PERMISSIONS" + permission for permission in user_json["experiment_permissions"] if permission["permission"] != NO_PERMISSIONS.name ] user_json["registered_model_permissions"] = [ registered_model_permission for registered_model_permission in user_json["registered_model_permissions"] - if registered_model_permission["permission"] != "NO_PERMISSIONS" + if registered_model_permission["permission"] != get_permission(NO_PERMISSIONS.name) ] return jsonify(user_json) @@ -967,7 +970,7 @@ def update_registered_model_permission(): username = _get_request_param("user_name") permission = _get_request_param("permission") store.update_registered_model_permission(name, username, permission) - return make_response(jsonify({"message" : "Model permission has been changed"})) + return make_response(jsonify({"message": "Model permission has been changed"})) @catch_mlflow_exception @@ -975,4 +978,4 @@ def delete_registered_model_permission(): name = _get_request_param("name") username = _get_request_param("user_name") store.delete_registered_model_permission(name, username) - return make_response(jsonify({"message" : "Model permission has been deleted"})) + return make_response(jsonify({"message": "Model permission has been deleted"})) diff --git a/pyproject.toml b/pyproject.toml index 55fef82..402d1ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ description = "OIDC auth plugin for MLflow" readme = "README.md" keywords = ["mlflow", "oauth2", "oidc"] classifiers = [ - "Development Status :: 3 - Alpha", + "Development Status :: 4 - Beta", "Intended Audience :: Developers", "Intended Audience :: End Users/Desktop", "Intended Audience :: Science/Research", diff --git a/web-ui/src/app/app-routing.module.ts b/web-ui/src/app/app-routing.module.ts index 66a14f5..7d9e634 100644 --- a/web-ui/src/app/app-routing.module.ts +++ b/web-ui/src/app/app-routing.module.ts @@ -4,7 +4,7 @@ import { RouterModule, Routes } from '@angular/router'; const routes: Routes = [ { path: 'home', loadChildren: () => import('./features/home-page/home-page.module').then(m => m.HomePageModule) }, { - path: 'admin-panel', + path: 'manage', loadChildren: () => import('./features/admin-page/admin-page.module').then(m => m.AdminPageModule), }, { path: '**', pathMatch: 'full', redirectTo: 'home' }, diff --git a/web-ui/src/app/shared/components/header/header.component.html b/web-ui/src/app/shared/components/header/header.component.html index 5671c15..3c9010f 100644 --- a/web-ui/src/app/shared/components/header/header.component.html +++ b/web-ui/src/app/shared/components/header/header.component.html @@ -4,7 +4,7 @@ Hello, {{name}}