Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(api) Add option to enforce use of scopes (all/read, all/write, etc.) [after #448] #441

Closed
Closed
Show file tree
Hide file tree
Changes from 27 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/ralph/api/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from ralph.conf import settings

# At startup, select the authentication mode that will be used
get_authenticated_user = (
get_oidc_user
if settings.RUNSERVER_AUTH_BACKEND == settings.AuthBackends.OIDC
else get_basic_user
)
if settings.RUNSERVER_AUTH_BACKEND == settings.AuthBackends.OIDC:
get_authenticated_user = get_oidc_user
else:
get_authenticated_user = get_basic_user
29 changes: 21 additions & 8 deletions src/ralph/api/auth/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import bcrypt
from cachetools import TTLCache, cached
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.security import HTTPBasic, HTTPBasicCredentials, SecurityScopes
from pydantic import BaseModel, root_validator
from starlette.authentication import AuthenticationError

from ralph.api.auth.user import AuthenticatedUser
from ralph.api.auth.user import AuthenticatedUser, UserScopes
from ralph.conf import settings

# Unused password used to avoid timing attacks, by comparing passwords supplied
Expand Down Expand Up @@ -102,30 +102,29 @@ def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials:
@cached(
TTLCache(maxsize=settings.AUTH_CACHE_MAX_SIZE, ttl=settings.AUTH_CACHE_TTL),
lock=Lock(),
key=lambda credentials: (
key=lambda credentials, security_scopes: (
credentials.username,
credentials.password,
security_scopes.scope_str,
)
if credentials is not None
else None,
)
def get_authenticated_user(
credentials: Union[HTTPBasicCredentials, None] = Depends(security),
security_scopes: SecurityScopes = SecurityScopes([]),
) -> AuthenticatedUser:
"""Checks valid auth parameters.

Get the basic auth parameters from the Authorization header, and checks them
against our own list of hashed credentials.

Args:
security_scopes: scopes requested for access
credentials (iterator): auth parameters from the Authorization header

Return:
AuthenticatedUser (AuthenticatedUser)

Raises:
HTTPException

"""
if not credentials:
logger.error("The basic authentication mode requires a Basic Auth header")
Expand Down Expand Up @@ -156,6 +155,7 @@ def get_authenticated_user(
status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)
) from exc

# Check that password was passed
if not hashed_password:
# We're doing a bogus password check anyway to avoid timing attacks on
# usernames
Expand All @@ -168,6 +168,7 @@ def get_authenticated_user(
headers={"WWW-Authenticate": "Basic"},
)

# Check password validity
if not bcrypt.checkpw(
credentials.password.encode(settings.LOCALE_ENCODING),
hashed_password.encode(settings.LOCALE_ENCODING),
Expand All @@ -182,4 +183,16 @@ def get_authenticated_user(
headers={"WWW-Authenticate": "Basic"},
)

return AuthenticatedUser(scopes=user.scopes, agent=user.agent)
user = AuthenticatedUser(scopes=UserScopes(user.scopes), agent=user.agent)

# Restrict access by scopes
if settings.LRS_RESTRICT_BY_SCOPES:
for requested_scope in security_scopes.scopes:
is_auth = user.scopes.is_authorized(requested_scope)
if not is_auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f'Access not authorized to scope: "{requested_scope}".',
headers={"WWW-Authenticate": "Basic"},
)
return user
26 changes: 20 additions & 6 deletions src/ralph/api/auth/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@

import logging
from functools import lru_cache
from typing import Optional, Union
from typing import Annotated, Optional

import requests
from fastapi import Depends, HTTPException, status
from fastapi.security import OpenIdConnect
from fastapi.security import OpenIdConnect, SecurityScopes
from jose import ExpiredSignatureError, JWTError, jwt
from jose.exceptions import JWTClaimsError
from pydantic import AnyUrl, BaseModel, Extra

from ralph.api.auth.user import AuthenticatedUser
from ralph.api.auth.user import AuthenticatedUser, UserScopes
from ralph.conf import settings

OPENID_CONFIGURATION_PATH = "/.well-known/openid-configuration"
Expand Down Expand Up @@ -93,7 +93,8 @@ def get_public_keys(jwks_uri: AnyUrl) -> dict:


def get_authenticated_user(
auth_header: Union[str, None] = Depends(oauth2_scheme)
auth_header: Annotated[Optional[str], Depends(oauth2_scheme)],
security_scopes: SecurityScopes = SecurityScopes([]),
) -> AuthenticatedUser:
"""Decode and validate OpenId Connect ID token against issuer in config.

Expand Down Expand Up @@ -143,7 +144,20 @@ def get_authenticated_user(

id_token = IDToken.parse_obj(decoded_token)

return AuthenticatedUser(
user = AuthenticatedUser(
agent={"openid": id_token.sub},
scopes=id_token.scope.split(" ") if id_token.scope else [],
scopes=UserScopes(id_token.scope.split(" ") if id_token.scope else []),
)

# Restrict access by scopes
if settings.LRS_RESTRICT_BY_SCOPES:
for requested_scope in security_scopes.scopes:
is_auth = user.scopes.is_authorized(requested_scope)
if not is_auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f'Access not authorized to scope: "{requested_scope}".',
headers={"WWW-Authenticate": "Basic"},
)

return user
42 changes: 40 additions & 2 deletions src/ralph/api/auth/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Authenticated user for the Ralph API."""

from typing import Dict, List, Literal
from functools import lru_cache
from typing import Dict, FrozenSet, Literal

from pydantic import BaseModel

Expand All @@ -18,6 +19,43 @@
]


class UserScopes(FrozenSet[Scope]):
"""Scopes available to users."""

@lru_cache()
def is_authorized(self, requested_scope: Scope):
"""Check if the requested scope can be accessed based on user scopes."""

expanded_scopes = {
"statements/read": {"statements/read/mine", "statements/read"},
"all/read": {
"statements/read/mine",
"statements/read",
"state/read",
"profile/read",
"all/read",
},
Comment on lines +30 to +37
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be possible to define them with a wildcard? E.g. statements/read/* and */read/*?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefered this explicit syntax as it seemed clearer when reading the code, but I can change this.

Anyone else has an opinion ?

"all": {
"statements/write",
"statements/read/mine",
"statements/read",
"state/read",
"state/write",
"define",
"profile/read",
"profile/write",
"all/read",
"all",
},
}

expanded_user_scopes = set()
for scope in self:
expanded_user_scopes.update(expanded_scopes.get(scope, {scope}))

return requested_scope in expanded_user_scopes


class AuthenticatedUser(BaseModel):
"""Pydantic model for user authentication.

Expand All @@ -27,4 +65,4 @@ class AuthenticatedUser(BaseModel):
"""

agent: Dict
scopes: List[Scope]
scopes: UserScopes
35 changes: 21 additions & 14 deletions src/ralph/api/routers/statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Query,
Request,
Response,
Security,
status,
)
from fastapi.dependencies.models import Dependant
Expand Down Expand Up @@ -130,10 +131,12 @@ def strict_query_params(request: Request):

@router.get("")
@router.get("/")
# pylint: disable=too-many-arguments, too-many-locals
async def get(
request: Request,
current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)],
current_user: Annotated[
AuthenticatedUser,
Security(get_authenticated_user, scopes=["statements/read/mine"]),
],
###
# Query string parameters defined by the LRS specification
###
Expand Down Expand Up @@ -163,15 +166,13 @@ async def get(
"of the Statement is an Activity with the specified id"
),
),
# pylint: disable=unused-argument
registration: Optional[UUID] = Query(
None,
description=(
"**Not implemented** "
"Filter, only return Statements matching the specified registration id"
),
),
# pylint: disable=unused-argument
related_activities: Optional[bool] = Query(
False,
description=(
Expand All @@ -182,7 +183,6 @@ async def get(
"instead of that parameter's normal behaviour"
),
),
# pylint: disable=unused-argument
related_agents: Optional[bool] = Query(
False,
description=(
Expand Down Expand Up @@ -214,7 +214,6 @@ async def get(
"0 indicates return the maximum the server will allow"
),
),
# pylint: disable=unused-argument, redefined-builtin
format: Optional[Literal["ids", "exact", "canonical"]] = Query(
"exact",
description=(
Expand All @@ -233,7 +232,6 @@ async def get(
'as in "exact" mode.'
),
),
# pylint: disable=unused-argument
attachments: Optional[bool] = Query(
False,
description=(
Expand Down Expand Up @@ -279,6 +277,7 @@ async def get(
LRS Specification:
https://github.com/adlnet/xAPI-Spec/blob/1.0.3/xAPI-Communication.md#213-get-statements
"""
# pylint: disable=unused-argument,redefined-builtin,too-many-arguments,too-many-locals
# Make sure the limit does not go above max from settings
limit = min(limit, settings.RUNSERVER_MAX_SEARCH_HITS_COUNT)

Expand Down Expand Up @@ -328,12 +327,14 @@ async def get(
)

if settings.LRS_RESTRICT_BY_AUTHORITY:
# If using scopes, only restrict results when appropriate
# If using scopes, restrict to "mine" when user does not have
# scopes wider than `statements/read/mine`
if settings.LRS_RESTRICT_BY_SCOPES:
raise NotImplementedError("Scopes are not yet implemented in Ralph.")

# Otherwise, enforce mine for all users
mine = True
if not current_user.scopes.is_authorized("statements/read"):
mine = True
else:
# If not using scopes, enforce "mine" for all users
mine = True

if mine:
query_params["authority"] = _parse_agent_parameters(current_user.agent)
Expand Down Expand Up @@ -390,7 +391,10 @@ async def get(
@router.put("", responses=POST_PUT_RESPONSES, status_code=status.HTTP_204_NO_CONTENT)
# pylint: disable=unused-argument
async def put(
current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)],
current_user: Annotated[
AuthenticatedUser,
Security(get_authenticated_user, scopes=["statements/write"]),
],
statement: LaxStatement,
background_tasks: BackgroundTasks,
statement_id: UUID = Query(alias="statementId"),
Expand Down Expand Up @@ -459,7 +463,10 @@ async def put(
@router.post("/", responses=POST_PUT_RESPONSES)
@router.post("", responses=POST_PUT_RESPONSES)
async def post(
current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)],
current_user: Annotated[
AuthenticatedUser,
Security(get_authenticated_user, scopes=["statements/write"]),
],
statements: Union[LaxStatement, List[LaxStatement]],
background_tasks: BackgroundTasks,
response: Response,
Expand Down
Loading