Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
closes #1
  • Loading branch information
alukach committed Aug 22, 2024
1 parent 2c65bb0 commit ab9ba1e
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Lint
on:
push:
paths:
- "eoapi/**"
- "**/*.py"

jobs:
pre-commit:
Expand Down
25 changes: 25 additions & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: Test

on:
push:
paths:
- "**/*.py"
- "pyproject.toml"

jobs:
pytest:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
cache: "pip"

- name: Install dependencies
run: pip install -e ".[testing]"

- name: Run tests
run: pytest
19 changes: 11 additions & 8 deletions eoapi/auth_utils/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,28 +69,31 @@ def create_auth_token_dependency(
"""

def auth_token(
token_str: Annotated[str, Security(auth_scheme)],
auth_header: Annotated[str, Security(auth_scheme)],
required_scopes: security.SecurityScopes,
):
token_parts = token_str.split(" ")
# Extract token from header
token_parts = auth_header.split(" ")
if len(token_parts) != 2 or token_parts[0].lower() != "bearer":
logger.error(f"Invalid token: {auth_header}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authorization header",
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
else:
[_, token] = token_parts
[_, token] = token_parts

# Parse & validate token
try:
key = jwks_client.get_signing_key_from_jwt(token).key
payload = jwt.decode(
token,
jwks_client.get_signing_key_from_jwt(token).key,
key,
algorithms=["RS256"],
# NOTE: Audience validation MUST match audience claim if set in token (https://pyjwt.readthedocs.io/en/stable/changelog.html?highlight=audience#id40)
audience=allowed_jwt_audiences,
)
except jwt.exceptions.InvalidTokenError as e:
except (jwt.exceptions.InvalidTokenError, jwt.exceptions.DecodeError) as e:
logger.exception(f"InvalidTokenError: {e=}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
Expand Down Expand Up @@ -124,7 +127,7 @@ def apply_auth_dependencies(
"""
# Ignore paths without dependants, e.g. /api, /api.html, /docs/oauth2-redirect
if not hasattr(api_route, "dependant"):
logger.warn(
logger.warning(
f"Route {api_route} has no dependant, not apply auth dependency"
)
return
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ lint = [
testing = [
"pytest>=6.0",
"coverage",
"jwcrypto>=1.5.6",
]
217 changes: 217 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import pytest
import jwt
import json
from typing import Any, Dict
from fastapi import FastAPI, HTTPException, Security, status, testclient
from unittest.mock import patch, MagicMock
from eoapi.auth_utils import OpenIdConnectAuth
from cryptography.hazmat.primitives.asymmetric import rsa
from jwcrypto.jwt import JWK
from jwcrypto.jwt import JWT


@pytest.fixture
def test_key() -> "JWK":
return JWK.generate(
kty="RSA", size=2048, kid="test", use="sig", e="AQAB", alg="RS256"
)


@pytest.fixture
def public_key(test_key: "JWK") -> Dict[str, Any]:
return test_key.export_public(as_dict=True)


@pytest.fixture
def private_key(test_key: "JWK") -> Dict[str, Any]:
return test_key.export_private(as_dict=True)


@pytest.fixture(autouse=True)
def mock_jwks(public_key: "rsa.RSAPrivateKey"):
mock_oidc_config = {"jwks_uri": "https://example.com/jwks"}

mock_jwks = {"keys": [public_key]}

with (
patch("urllib.request.urlopen") as mock_urlopen,
patch("jwt.PyJWKClient.fetch_data") as mock_fetch_data,
):
mock_oidc_config_response = MagicMock()
mock_oidc_config_response.read.return_value = json.dumps(
mock_oidc_config
).encode()
mock_oidc_config_response.status = 200

mock_urlopen.return_value.__enter__.return_value = mock_oidc_config_response
mock_fetch_data.return_value = mock_jwks
yield mock_urlopen


@pytest.fixture
def token_builder(test_key: "JWK"):
def build_token(payload: Dict[str, Any], key=None) -> str:
jwt_token = JWT(
header={k: test_key.get(k) for k in ["alg", "kid"]},
claims=payload,
)
jwt_token.make_signed_token(key or test_key)
return jwt_token.serialize()

return build_token


@pytest.fixture
def test_app():
app = FastAPI()

@app.get("/test-route")
def test():
return {"message": "Hello World"}

return app


@pytest.fixture
def test_client(test_app):
return testclient.TestClient(test_app)


def test_oidc_auth_initialization(mock_jwks: MagicMock):
"""
Auth object is initialized with the correct dependencies.
"""
openid_configuration_url = "https://example.com/.well-known/openid-configuration"
auth = OpenIdConnectAuth(openid_configuration_url=openid_configuration_url)
assert auth.jwks_client is not None
assert auth.auth_scheme is not None
assert auth.valid_token_dependency is not None
mock_jwks.assert_called_once_with(openid_configuration_url)


def test_auth_token_valid(token_builder):
"""
Auth token dependency returns the token payload when the token is valid.
"""
token = token_builder({"scope": "test_scope"})

auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

token_payload = auth.valid_token_dependency(
auth_header=f"Bearer {token}", required_scopes=Security([])
)
assert token_payload["scope"] == "test_scope"


def test_auth_token_invalid_audience(token_builder):
"""
Auth token dependency throws 401 when the token audience is invalid.
"""
token = token_builder({"scope": "test_scope", "aud": "test_audience"})

auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

with pytest.raises(HTTPException) as exc_info:
auth.valid_token_dependency(
auth_header=f"Bearer {token}", required_scopes=Security([])
)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Could not validate credentials"
assert isinstance(exc_info.value.__cause__, jwt.exceptions.InvalidAudienceError)


def test_auth_token_invalid_signature(token_builder):
"""
Auth token dependency throws 401 when the token signature is invalid.
"""
other_key = JWK.generate(
kty="RSA", size=2048, kid="test", use="sig", e="AQAB", alg="RS256"
)
token = token_builder({"scope": "test_scope", "aud": "test_audience"}, other_key)

auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

with pytest.raises(HTTPException) as exc_info:
auth.valid_token_dependency(
auth_header=f"Bearer {token}", required_scopes=Security([])
)

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Could not validate credentials"
assert isinstance(exc_info.value.__cause__, jwt.exceptions.InvalidSignatureError)


@pytest.mark.parametrize(
"token",
[
"foo",
"Bearer foo",
"Bearer foo.bar.xyz",
"Basic foo",
],
)
def test_auth_token_invalid_token(token):
"""
Auth token dependency throws 401 when the token is invalid.
"""
auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

with pytest.raises(HTTPException) as exc_info:
auth.valid_token_dependency(auth_header=token, required_scopes=Security([]))

assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED
assert exc_info.value.detail == "Could not validate credentials"


def test_apply_auth_dependencies(test_app, test_client):
auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

for route in test_app.routes:
auth.apply_auth_dependencies(
api_route=route, required_token_scopes=["test_scope"]
)

resp = test_client.get("/test-route")
assert resp.json() == {"detail": "Not authenticated"}
assert resp.status_code == status.HTTP_403_FORBIDDEN


@pytest.mark.parametrize(
"required_sent_response",
[
("a", "b", status.HTTP_401_UNAUTHORIZED),
("a b c", "a b", status.HTTP_401_UNAUTHORIZED),
("a", "a", status.HTTP_200_OK),
(None, None, status.HTTP_200_OK),
(None, "a", status.HTTP_200_OK),
("a b c", "d c b a", status.HTTP_200_OK),
],
)
def test_reject_wrong_scope(
test_app, test_client, token_builder, required_sent_response
):
auth = OpenIdConnectAuth(
openid_configuration_url="https://example.com/.well-known/openid-configuration"
)

scope_required, scope_sent, expected_status = required_sent_response
for route in test_app.routes:
auth.apply_auth_dependencies(
api_route=route,
required_token_scopes=scope_required.split(" ") if scope_required else None,
)

token = token_builder({"scope": scope_sent})
resp = test_client.get("/test-route", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == expected_status

0 comments on commit ab9ba1e

Please sign in to comment.