Skip to content

Commit

Permalink
add option to enforce use of scopes
Browse files Browse the repository at this point in the history
  • Loading branch information
Leobouloc committed Oct 17, 2023
1 parent 95c6ede commit 401db9e
Show file tree
Hide file tree
Showing 17 changed files with 704 additions and 258 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ have an authority field matching that of the user
with camelCase alias, in `LRSStatementsQuery`
- API: Add `RALPH_LRS_RESTRICT_BY_AUTHORITY` option making `?mine=True`
implicit
- API: Add `RALPH_LRS_RESTRICT_BY_SCOPE` option enabling endpoint access
control by user scopes

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ match = ^(?!(setup)\.(py)$).*\.(py)$
[isort]
known_ralph=ralph
sections=FUTURE,STDLIB,THIRDPARTY,RALPH,FIRSTPARTY,LOCALFOLDER
skip_glob=venv
skip_glob=venv,*/.conda/*
profile=black

[tool:pytest]
Expand Down
13 changes: 6 additions & 7 deletions src/ralph/api/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
"""Main module for Ralph's LRS API authentication."""

from ralph.api.auth.basic import get_authenticated_user as get_basic_user
from ralph.api.auth.oidc import get_authenticated_user as get_oidc_user
from ralph.api.auth.basic import get_basic_auth_user
from ralph.api.auth.oidc import get_oidc_user
from ralph.conf import settings

# At startup, select the authentication mode that will be used
get_authenticated_user = (
get_oidc_user
if settings.RUNSERVER_AUTH_BACKEND == settings.AuthBackends.OIDC
else get_basic_user
)
if settings.RUNSERVER_AUTH_BACKEND == settings.AuthBackends.OIDC:
get_authenticated_user = get_oidc_user
else:
get_authenticated_user = get_basic_auth_user
31 changes: 22 additions & 9 deletions src/ralph/api/auth/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
import bcrypt
from cachetools import TTLCache, cached
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from fastapi.security import HTTPBasic, HTTPBasicCredentials, SecurityScopes
from pydantic import BaseModel, root_validator
from starlette.authentication import AuthenticationError

from ralph.api.auth.user import AuthenticatedUser
from ralph.api.auth.user import AuthenticatedUser, UserScopes
from ralph.conf import settings

# Unused password used to avoid timing attacks, by comparing passwords supplied
Expand Down Expand Up @@ -102,15 +102,17 @@ def get_stored_credentials(auth_file: Path) -> ServerUsersCredentials:
@cached(
TTLCache(maxsize=settings.AUTH_CACHE_MAX_SIZE, ttl=settings.AUTH_CACHE_TTL),
lock=Lock(),
key=lambda credentials: (
key=lambda credentials, security_scopes: (
credentials.username,
credentials.password,
security_scopes.scope_str,
)
if credentials is not None
else None,
)
def get_authenticated_user(
def get_basic_auth_user(
credentials: Union[HTTPBasicCredentials, None] = Depends(security),
security_scopes: SecurityScopes = SecurityScopes([]),
) -> AuthenticatedUser:
"""Checks valid auth parameters.
Expand All @@ -119,13 +121,10 @@ def get_authenticated_user(
Args:
credentials (iterator): auth parameters from the Authorization header
Return:
AuthenticatedUser (AuthenticatedUser)
security_scopes: scopes requested for access
Raises:
HTTPException
"""
if not credentials:
logger.error("The basic authentication mode requires a Basic Auth header")
Expand Down Expand Up @@ -156,6 +155,7 @@ def get_authenticated_user(
status_code=status.HTTP_403_FORBIDDEN, detail=str(exc)
) from exc

# Check that password was passed
if not hashed_password:
# We're doing a bogus password check anyway to avoid timing attacks on
# usernames
Expand All @@ -168,6 +168,7 @@ def get_authenticated_user(
headers={"WWW-Authenticate": "Basic"},
)

# Check password validity
if not bcrypt.checkpw(
credentials.password.encode(settings.LOCALE_ENCODING),
hashed_password.encode(settings.LOCALE_ENCODING),
Expand All @@ -182,4 +183,16 @@ def get_authenticated_user(
headers={"WWW-Authenticate": "Basic"},
)

return AuthenticatedUser(scopes=user.scopes, agent=user.agent)
user = AuthenticatedUser(scopes=UserScopes(user.scopes), agent=user.agent)

# Restrict access by scopes
if settings.LRS_RESTRICT_BY_SCOPES:
for requested_scope in security_scopes.scopes:
is_auth = user.scopes.is_authorized(requested_scope)
if not is_auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f'Access not authorized to scope: "{requested_scope}".',
headers={"WWW-Authenticate": "Basic"},
)
return user
31 changes: 23 additions & 8 deletions src/ralph/api/auth/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@

import logging
from functools import lru_cache
from typing import Optional, Union
from typing import Optional

import requests
from fastapi import Depends, HTTPException, status
from fastapi.security import OpenIdConnect
from fastapi.security import OpenIdConnect, SecurityScopes
from jose import ExpiredSignatureError, JWTError, jwt
from jose.exceptions import JWTClaimsError
from pydantic import AnyUrl, BaseModel, Extra
from typing_extensions import Annotated

from ralph.api.auth.user import AuthenticatedUser
from ralph.api.auth.user import AuthenticatedUser, UserScopes
from ralph.conf import settings

OPENID_CONFIGURATION_PATH = "/.well-known/openid-configuration"
Expand Down Expand Up @@ -92,8 +93,9 @@ def get_public_keys(jwks_uri: AnyUrl) -> dict:
) from exc


def get_authenticated_user(
auth_header: Union[str, None] = Depends(oauth2_scheme)
def get_oidc_user(
auth_header: Annotated[Optional[str], Depends(oauth2_scheme)],
security_scopes: SecurityScopes = SecurityScopes([]),
) -> AuthenticatedUser:
"""Decode and validate OpenId Connect ID token against issuer in config.
Expand Down Expand Up @@ -143,7 +145,20 @@ def get_authenticated_user(

id_token = IDToken.parse_obj(decoded_token)

return AuthenticatedUser(
agent={"openid": id_token.sub},
scopes=id_token.scope.split(" ") if id_token.scope else [],
user = AuthenticatedUser(
agent={"openid": f"{id_token.iss}/{id_token.sub}"},
scopes=UserScopes(id_token.scope.split(" ") if id_token.scope else []),
)

# Restrict access by scopes
if settings.LRS_RESTRICT_BY_SCOPES:
for requested_scope in security_scopes.scopes:
is_auth = user.scopes.is_authorized(requested_scope)
if not is_auth:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f'Access not authorized to scope: "{requested_scope}".',
headers={"WWW-Authenticate": "Basic"},
)

return user
41 changes: 39 additions & 2 deletions src/ralph/api/auth/user.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Authenticated user for the Ralph API."""

from typing import Dict, List, Literal
from functools import lru_cache
from typing import Dict, FrozenSet, Literal

from pydantic import BaseModel

Expand All @@ -18,6 +19,42 @@
]


class UserScopes(FrozenSet[Scope]):
"""Scopes available to users."""

@lru_cache()
def is_authorized(self, requested_scope: Scope):
"""Check if the requested scope can be accessed based on user scopes."""
expanded_scopes = {
"statements/read": {"statements/read/mine", "statements/read"},
"all/read": {
"statements/read/mine",
"statements/read",
"state/read",
"profile/read",
"all/read",
},
"all": {
"statements/write",
"statements/read/mine",
"statements/read",
"state/read",
"state/write",
"define",
"profile/read",
"profile/write",
"all/read",
"all",
},
}

expanded_user_scopes = set()
for scope in self:
expanded_user_scopes.update(expanded_scopes.get(scope, {scope}))

return requested_scope in expanded_user_scopes


class AuthenticatedUser(BaseModel):
"""Pydantic model for user authentication.
Expand All @@ -27,4 +64,4 @@ class AuthenticatedUser(BaseModel):
"""

agent: Dict
scopes: List[Scope]
scopes: UserScopes
39 changes: 24 additions & 15 deletions src/ralph/api/routers/statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Query,
Request,
Response,
Security,
status,
)
from fastapi.dependencies.models import Dependant
Expand Down Expand Up @@ -101,6 +102,7 @@ def _enrich_statement_with_authority(statement: dict, current_user: Authenticate
def _parse_agent_parameters(agent_obj: dict):
"""Parse a dict and return an AgentParameters object to use in queries."""
# Transform agent to `dict` as FastAPI cannot parse JSON (seen as string)

agent = parse_obj_as(BaseXapiAgent, agent_obj)

agent_query_params = {}
Expand Down Expand Up @@ -137,10 +139,12 @@ def strict_query_params(request: Request):

@router.get("")
@router.get("/")
# pylint: disable=too-many-arguments, too-many-locals
async def get(
request: Request,
current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)],
current_user: Annotated[
AuthenticatedUser,
Security(get_authenticated_user, scopes=["statements/read/mine"]),
],
###
# Query string parameters defined by the LRS specification
###
Expand Down Expand Up @@ -170,15 +174,13 @@ async def get(
"of the Statement is an Activity with the specified id"
),
),
# pylint: disable=unused-argument
registration: Optional[UUID] = Query(
None,
description=(
"**Not implemented** "
"Filter, only return Statements matching the specified registration id"
),
),
# pylint: disable=unused-argument
related_activities: Optional[bool] = Query(
False,
description=(
Expand All @@ -189,7 +191,6 @@ async def get(
"instead of that parameter's normal behaviour"
),
),
# pylint: disable=unused-argument
related_agents: Optional[bool] = Query(
False,
description=(
Expand Down Expand Up @@ -221,7 +222,6 @@ async def get(
"0 indicates return the maximum the server will allow"
),
),
# pylint: disable=unused-argument, redefined-builtin
format: Optional[Literal["ids", "exact", "canonical"]] = Query(
"exact",
description=(
Expand All @@ -240,7 +240,6 @@ async def get(
'as in "exact" mode.'
),
),
# pylint: disable=unused-argument
attachments: Optional[bool] = Query(
False,
description=(
Expand Down Expand Up @@ -286,6 +285,9 @@ async def get(
LRS Specification:
https://github.com/adlnet/xAPI-Spec/blob/1.0.3/xAPI-Communication.md#213-get-statements
"""
# pylint: disable=unused-argument,redefined-builtin,too-many-arguments
# pylint: disable=too-many-locals

# Make sure the limit does not go above max from settings
limit = min(limit, settings.RUNSERVER_MAX_SEARCH_HITS_COUNT)

Expand Down Expand Up @@ -334,14 +336,15 @@ async def get(
json.loads(query_params["agent"])
)

if settings.LRS_RESTRICT_BY_AUTHORITY:
# If using scopes, only restrict results when appropriate
if settings.LRS_RESTRICT_BY_SCOPES:
raise NotImplementedError("Scopes are not yet implemented in Ralph.")

# Otherwise, enforce mine for all users
# mine: If using scopes, only restrict users with limited scopes
if settings.LRS_RESTRICT_BY_SCOPES:
if not current_user.scopes.is_authorized("statements/read"):
mine = True
# mine: If using only authority, always restrict (otherwise, use the default value)
elif settings.LRS_RESTRICT_BY_AUTHORITY:
mine = True

# Filter by authority if using `mine`
if mine:
query_params["authority"] = _parse_agent_parameters(current_user.agent)

Expand Down Expand Up @@ -399,7 +402,10 @@ async def get(
@router.put("", responses=POST_PUT_RESPONSES, status_code=status.HTTP_204_NO_CONTENT)
# pylint: disable=unused-argument, too-many-branches
async def put(
current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)],
current_user: Annotated[
AuthenticatedUser,
Security(get_authenticated_user, scopes=["statements/write"]),
],
statement: LaxStatement,
background_tasks: BackgroundTasks,
statement_id: UUID = Query(alias="statementId"),
Expand Down Expand Up @@ -478,7 +484,10 @@ async def put(
@router.post("", responses=POST_PUT_RESPONSES)
# pylint: disable = too-many-branches
async def post(
current_user: Annotated[AuthenticatedUser, Depends(get_authenticated_user)],
current_user: Annotated[
AuthenticatedUser,
Security(get_authenticated_user, scopes=["statements/write"]),
],
statements: Union[LaxStatement, List[LaxStatement]],
background_tasks: BackgroundTasks,
response: Response,
Expand Down
18 changes: 17 additions & 1 deletion src/ralph/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
from unittest.mock import Mock

get_app_dir = Mock(return_value=".")
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra

from pydantic import AnyHttpUrl, AnyUrl, BaseModel, BaseSettings, Extra, root_validator

from ralph.exceptions import ConfigurationException

from .utils import import_string

Expand Down Expand Up @@ -210,5 +213,18 @@ def LOCALE_ENCODING(self) -> str: # pylint: disable=invalid-name
"""Returns Ralph's default locale encoding."""
return self._CORE.LOCALE_ENCODING

@root_validator(allow_reuse=True)
@classmethod
def check_restriction_compatibility(cls, values):
"""Raise an error if scopes are being used without authority restriction."""
if values.get("LRS_RESTRICT_BY_SCOPES") and not values.get(
"LRS_RESTRICT_BY_AUTHORITY"
):
raise ConfigurationException(
"`LRS_RESTRICT_BY_AUTHORITY` must be set to `True` if using "
"`LRS_RESTRICT_BY_SCOPES=True`"
)
return values


settings = Settings()
Loading

0 comments on commit 401db9e

Please sign in to comment.