From d958311977bdbc1fb4d6f53af8c6a2e7f719eae1 Mon Sep 17 00:00:00 2001 From: Martijn Maas Date: Fri, 22 Dec 2023 13:28:36 +0100 Subject: [PATCH] Add simple security --- api/flask_app.py | 6 ++++ api/simple_security.py | 66 ++++++++++++++++++++++++++++++++++++++++ api/start_flask_local.sh | 4 ++- 3 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 api/simple_security.py diff --git a/api/flask_app.py b/api/flask_app.py index 45a562f..8f4e809 100644 --- a/api/flask_app.py +++ b/api/flask_app.py @@ -19,6 +19,7 @@ from run import Predictor from utils.image_utils import load_image_array_from_bytes, load_image_tensor_from_bytes from utils.logging_utils import get_logger_name +from api.simple_security import SimpleSecurity, session_key_required # Reading environment files try: @@ -45,6 +46,10 @@ app = Flask(__name__) +enable_security = os.getenv("SECURITY_ENABLED", False) +api_key_user_json_string = os.getenv("API_KEY_USER_JSON_STRING") +SimpleSecurity(app, enable_security, api_key_user_json_string) + predictor = None gen_page = None @@ -233,6 +238,7 @@ def check_exception_callback(future: Future): @app.route("/predict", methods=["POST"]) @exception_predict_counter.count_exceptions() +@session_key_required def predict() -> tuple[Response, int]: """ Run the prediction on a submitted image diff --git a/api/simple_security.py b/api/simple_security.py new file mode 100644 index 0000000..932b9f7 --- /dev/null +++ b/api/simple_security.py @@ -0,0 +1,66 @@ +import functools +import uuid + +import flask +from flask import request, Response, jsonify, Flask +import json + + +class SimpleSecurity: + def __init__(self, app: Flask, enabled: bool = False, key_user_json: str = None): + app.extensions["security"] = self + self.enabled = enabled + if enabled: + self.register_login_resource(app) + try: + self.api_key_user = json.loads(key_user_json) + self.session_key_user = {} + except Exception as e: + raise ValueError("When security is enabled, key_user_json should be a valid json string. ", e) + + def is_known_session_key(self, session_key: str): + return session_key in self.session_key_user.keys() + + def register_login_resource(self, app): + @app.route("/login", methods=["POST"]) + def login(): + if "Authorization" in request.headers.keys(): + api_key = request.headers["Authorization"] + session_key = self.login(api_key) + + if session_key is not None: + response = Response(status=204) + response.headers["X_AUTH_TOKEN"] = session_key + + return response + + return Response(status=401) + + def login(self, api_key: str) -> str | None: + + if self.enabled and api_key in self.api_key_user: + session_key = str(uuid.uuid4()) + self.session_key_user[session_key] = self.api_key_user[api_key] + return session_key + + return None + + +def session_key_required(func): + @functools.wraps(func) + def decorator(*args, **kwargs) -> Response: + security_ = flask.current_app.extensions["security"] + if security_.enabled: + if "Authorization" in request.headers.keys(): + session_key = request.headers["Authorization"] + if security_.is_known_session_key(session_key): + return func(*args, **kwargs) + + response = jsonify({'message': 'Expected a valid session key in the Authorization header'}) + response.status_code = 401 + return response + else: + print("security disabled") + return func(*args, **kwargs) + + return decorator diff --git a/api/start_flask_local.sh b/api/start_flask_local.sh index 78de204..e46c13c 100755 --- a/api/start_flask_local.sh +++ b/api/start_flask_local.sh @@ -7,6 +7,8 @@ if [[ $( builtin cd "$( dirname ${BASH_SOURCE[0]} )/.."; pwd ) != $( pwd ) ]]; t fi LAYPA_MAX_QUEUE_SIZE=128 \ -LAYPA_MODEL_BASE_PATH="/home/stefan/Documents/models/" \ +LAYPA_MODEL_BASE_PATH="/home/martijnm/workspace/images/laypa-models" \ LAYPA_OUTPUT_BASE_PATH="/tmp/" \ +SECURITY_ENABLED="True" \ +API_KEY_USER_JSON_STRING='{"1234": "test user"}' \ FLASK_DEBUG=true FLASK_APP=api.flask_app.py flask run \ No newline at end of file