Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC with response and API caching #372

Open
wants to merge 6 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions admin/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from werkzeug.utils import secure_filename

import api
from cache_config import clear_api_cache_keys, clear_view_cache_keys
from util import get_theme_directories, length_check

# Blueprint configuration
Expand Down Expand Up @@ -175,6 +176,8 @@ def save_settings(settings: Dict[str, Any], flash_msg: str) -> Response:
# Load the theme template if the current theme is changed
set_theme_loader(app, remove_cache=True)

clear_view_cache_keys(all_users=True)
clear_api_cache_keys("admin_save_settings")
flash(flash_msg, 'success')

return redirect(url_for("admin_bp.index"))
Expand Down
139 changes: 104 additions & 35 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
__license__ = 'GPLv3, see LICENSE'

import base64
import hashlib
import json
import re
import sys
Expand All @@ -15,6 +16,7 @@
from flask import current_app as app
from irods import message, rule

from cache_config import cache, clear_api_cache_keys, get_api_cache_timeout, make_key
from errors import InvalidAPIError, UnauthorizedAPIAccessError
from util import log_error

Expand All @@ -23,51 +25,116 @@

@api_bp.route('/<fn>', methods=['POST'])
def _call(fn: str) -> Response:
"""Handle API calls to specified function.

:param fn: The name of the API function to call

:returns: JSON response containing the result of the API call

:raises UnauthorizedAPIAccessError: If the user is not authenticated
:raises InvalidAPIError: If the function name is invalid
"""
if not authenticated():
raise UnauthorizedAPIAccessError

if not re.match("^([a-z_]+)$", fn):
if not re.match(r"^[a-z_]+$", fn):
raise InvalidAPIError

data: Dict[str, Any] = {}
if 'data' in request.form:
data = json.loads(request.form['data'])
data = json.loads(request.form.get('data', '{}'))
result = call(fn, data)
return jsonify(result), get_response_code(result)

result: Dict[str, Any] = call(fn, data)
code: int = 200

if result['status'] == 'error_internal':
code = 500
elif result['status'] != 'ok':
code = 400
def get_response_code(result: Dict[str, Any]) -> int:
"""Determine the HTTP response code based on the result status.

response = jsonify(result)
response.status_code = code
return response
:param result: The result dictionary from the API call

:returns: HTTP status code
"""
if result['status'] == 'error_internal':
return 500
return 400 if result['status'] != 'ok' else 200


def call(fn: str, data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Call the specified API function with the provided data.

:param fn: The name of the API function to call
:param data: Optional dictionary of data to pass to the function

:returns: The result of the API call as a dictionary
"""
if app.config.get('LOG_API_CALL_DURATION', False):
begintime = timer()

if data is None:
data = {}

params = json.dumps(data)
encoded_params = hashlib.shake_256(params.encode('utf-8')).hexdigest(20)

# Clear API cache keys if the API function called impacts keys.
clear_api_cache_keys(fn)

timeout = get_api_cache_timeout(fn)
cached_result = None
if timeout > 0:
cached_result = cache.get(make_key(f"{fn}-{encoded_params}"))

# Execute rule if there is no cached result.
if cached_result is None:
result = execute_rule(fn, params)

# Cache result if a timeout is specified for this API.
if timeout > 0:
cache.set(make_key(f"{fn}-{encoded_params}"), result, timeout=timeout)
else:
result = cached_result

if app.config.get('LOG_API_CALL_DURATION', False):
endtime = timer()
callduration = round((endtime - begintime) * 1000)
log_message = f"DEBUG: {callduration:4d}ms api_{fn} {params}"
if cached_result is not None:
log_message += " (from cache)"
print(log_message, file=sys.stderr)

return json.loads(result)


def execute_rule(fn: str, params: str) -> str:
"""Execute the specified iRODS rule with the given parameters.

:param fn: The name of the API function to execute
:param params: The parameters to pass to the rule

:returns: The output of the rule execution as a string.
"""
def bytesbuf_to_str(s: message.BinBytesBuf) -> str:
"""Convert a BinBytesBuf to a string, handling null termination."""
s = s.buf[:s.buflen]
i = s.find(b'\x00')
return s if i < 0 else s[:i]

def escape_quotes(s: str) -> str:
"""Escape quotes in a string for safe inclusion in rules."""
return s.replace('\\', '\\\\').replace('"', '\\"')

def break_strings(N: int, m: int) -> int:
"""Calculate the number of segments needed to break a string."""
return (N - 1) // m + 1

def nrep_string_expr(s: str, m: int = 64) -> str:
return '++\n'.join(f'"{escape_quotes(s[i * m:i * m + m])}"' for i in range(break_strings(len(s), m) + 1))
"""Break up the string literal to work around limits for both parameter strings
and literal string constants in the iRODS core code.

if app.config.get('LOG_API_CALL_DURATION', False):
begintime = timer()

if data is None:
data = {}
:param s: The string to be broken
:param m: The maximum length of each segment

params = json.dumps(data)
:returns: A string formatted for iRODS rule input
"""
return '++\n'.join(f'"{escape_quotes(s[i * m:i * m + m])}"' for i in range(break_strings(len(s), m) + 1))

# Compress params and encode as base64 to reduce size (max rule length in iRODS is 20KB)
compressed_params = zlib.compress(params.encode())
Expand All @@ -91,35 +158,36 @@ def nrep_string_expr(s: str, m: int = 64) -> str:
g.irods.cleanup()

x = x.execute(session_cleanup=False)
x = bytesbuf_to_str(x._values['MsParam_PI'][0]._values['inOutStruct']._values['stdoutBuf'])

result = x.decode()

if app.config.get('LOG_API_CALL_DURATION', False):
endtime = timer()
callduration = round((endtime - begintime) * 1000)
print(f"DEBUG: {callduration:4d}ms api_{fn} {params}", file=sys.stderr)

return json.loads(result)
return bytesbuf_to_str(x._values['MsParam_PI'][0]._values['inOutStruct']._values['stdoutBuf'])


def authenticated() -> bool:
"""Check if the user is authenticated.

:returns: True if the user is authenticated, False otherwise
"""
return g.get('user') is not None and g.get('irods') is not None


@api_bp.errorhandler(Exception)
def api_error_handler(error: Exception) -> Response:
"""Handle exceptions raised during API calls.

:param error: The exception that was raised

:returns: A JSON response containing the error details and HTTP status code
"""
log_error(f'API Error: {error}', True)
status = "internal_error"
status_info = "Something went wrong"
data: Dict[str, Any] = {}
code = 500
code = 500 # Default to internal server error.

if type(error) is InvalidAPIError:
# Determine specific error types and set appropriate response details.
if isinstance(error, InvalidAPIError):
code = 400
status_info = "Bad API request"

if type(error) is UnauthorizedAPIAccessError:
elif isinstance(error, UnauthorizedAPIAccessError):
code = 401
status_info = "Not authorized to use the API"

Expand All @@ -128,4 +196,5 @@ def api_error_handler(error: Exception) -> Response:
"status": status,
"status_info": status_info,
"data": data
}), code
}
), code
5 changes: 5 additions & 0 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from admin.admin import admin_bp, set_theme_loader
from api import api_bp
from cache_config import cache
from datarequest.datarequest import datarequest_bp
from deposit.deposit import deposit_bp
from fileviewer.fileviewer import fileviewer_bp
Expand All @@ -28,6 +29,7 @@
from util import get_validated_static_path, log_error
from vault.vault import vault_bp


app = Flask(__name__, static_folder='assets')
app.json.sort_keys = False

Expand Down Expand Up @@ -123,6 +125,9 @@ def load_admin_setting() -> Dict[str, Any]:
# Start Flask-Session
Session(app)

# Initialize the cache.
cache.init_app(app)

# Start monitoring thread for extracting tech support information
# Monitor signal file can be set to empty to completely disable monitor thread
monitor_enabled: bool = app.config.get("MONITOR_SIGNAL_FILE", "/var/www/yoda/show-tech.sig") != ""
Expand Down
Loading
Loading