-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
closes #1
- Loading branch information
Showing
5 changed files
with
255 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ name: Lint | |
on: | ||
push: | ||
paths: | ||
- "eoapi/**" | ||
- "**/*.py" | ||
|
||
jobs: | ||
pre-commit: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
name: Test | ||
|
||
on: | ||
push: | ||
paths: | ||
- "**/*.py" | ||
- "pyproject.toml" | ||
|
||
jobs: | ||
pytest: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
|
||
- name: Set up Python | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: "3.11" | ||
cache: "pip" | ||
|
||
- name: Install dependencies | ||
run: pip install -e ".[testing]" | ||
|
||
- name: Run tests | ||
run: pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,4 +34,5 @@ lint = [ | |
testing = [ | ||
"pytest>=6.0", | ||
"coverage", | ||
"jwcrypto>=1.5.6", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,217 @@ | ||
import pytest | ||
import jwt | ||
import json | ||
from typing import Any, Dict | ||
from fastapi import FastAPI, HTTPException, Security, status, testclient | ||
from unittest.mock import patch, MagicMock | ||
from eoapi.auth_utils import OpenIdConnectAuth | ||
from cryptography.hazmat.primitives.asymmetric import rsa | ||
from jwcrypto.jwt import JWK | ||
from jwcrypto.jwt import JWT | ||
|
||
|
||
@pytest.fixture | ||
def test_key() -> "JWK": | ||
return JWK.generate( | ||
kty="RSA", size=2048, kid="test", use="sig", e="AQAB", alg="RS256" | ||
) | ||
|
||
|
||
@pytest.fixture | ||
def public_key(test_key: "JWK") -> Dict[str, Any]: | ||
return test_key.export_public(as_dict=True) | ||
|
||
|
||
@pytest.fixture | ||
def private_key(test_key: "JWK") -> Dict[str, Any]: | ||
return test_key.export_private(as_dict=True) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def mock_jwks(public_key: "rsa.RSAPrivateKey"): | ||
mock_oidc_config = {"jwks_uri": "https://example.com/jwks"} | ||
|
||
mock_jwks = {"keys": [public_key]} | ||
|
||
with ( | ||
patch("urllib.request.urlopen") as mock_urlopen, | ||
patch("jwt.PyJWKClient.fetch_data") as mock_fetch_data, | ||
): | ||
mock_oidc_config_response = MagicMock() | ||
mock_oidc_config_response.read.return_value = json.dumps( | ||
mock_oidc_config | ||
).encode() | ||
mock_oidc_config_response.status = 200 | ||
|
||
mock_urlopen.return_value.__enter__.return_value = mock_oidc_config_response | ||
mock_fetch_data.return_value = mock_jwks | ||
yield mock_urlopen | ||
|
||
|
||
@pytest.fixture | ||
def token_builder(test_key: "JWK"): | ||
def build_token(payload: Dict[str, Any], key=None) -> str: | ||
jwt_token = JWT( | ||
header={k: test_key.get(k) for k in ["alg", "kid"]}, | ||
claims=payload, | ||
) | ||
jwt_token.make_signed_token(key or test_key) | ||
return jwt_token.serialize() | ||
|
||
return build_token | ||
|
||
|
||
@pytest.fixture | ||
def test_app(): | ||
app = FastAPI() | ||
|
||
@app.get("/test-route") | ||
def test(): | ||
return {"message": "Hello World"} | ||
|
||
return app | ||
|
||
|
||
@pytest.fixture | ||
def test_client(test_app): | ||
return testclient.TestClient(test_app) | ||
|
||
|
||
def test_oidc_auth_initialization(mock_jwks: MagicMock): | ||
""" | ||
Auth object is initialized with the correct dependencies. | ||
""" | ||
openid_configuration_url = "https://example.com/.well-known/openid-configuration" | ||
auth = OpenIdConnectAuth(openid_configuration_url=openid_configuration_url) | ||
assert auth.jwks_client is not None | ||
assert auth.auth_scheme is not None | ||
assert auth.valid_token_dependency is not None | ||
mock_jwks.assert_called_once_with(openid_configuration_url) | ||
|
||
|
||
def test_auth_token_valid(token_builder): | ||
""" | ||
Auth token dependency returns the token payload when the token is valid. | ||
""" | ||
token = token_builder({"scope": "test_scope"}) | ||
|
||
auth = OpenIdConnectAuth( | ||
openid_configuration_url="https://example.com/.well-known/openid-configuration" | ||
) | ||
|
||
token_payload = auth.valid_token_dependency( | ||
auth_header=f"Bearer {token}", required_scopes=Security([]) | ||
) | ||
assert token_payload["scope"] == "test_scope" | ||
|
||
|
||
def test_auth_token_invalid_audience(token_builder): | ||
""" | ||
Auth token dependency throws 401 when the token audience is invalid. | ||
""" | ||
token = token_builder({"scope": "test_scope", "aud": "test_audience"}) | ||
|
||
auth = OpenIdConnectAuth( | ||
openid_configuration_url="https://example.com/.well-known/openid-configuration" | ||
) | ||
|
||
with pytest.raises(HTTPException) as exc_info: | ||
auth.valid_token_dependency( | ||
auth_header=f"Bearer {token}", required_scopes=Security([]) | ||
) | ||
|
||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED | ||
assert exc_info.value.detail == "Could not validate credentials" | ||
assert isinstance(exc_info.value.__cause__, jwt.exceptions.InvalidAudienceError) | ||
|
||
|
||
def test_auth_token_invalid_signature(token_builder): | ||
""" | ||
Auth token dependency throws 401 when the token signature is invalid. | ||
""" | ||
other_key = JWK.generate( | ||
kty="RSA", size=2048, kid="test", use="sig", e="AQAB", alg="RS256" | ||
) | ||
token = token_builder({"scope": "test_scope", "aud": "test_audience"}, other_key) | ||
|
||
auth = OpenIdConnectAuth( | ||
openid_configuration_url="https://example.com/.well-known/openid-configuration" | ||
) | ||
|
||
with pytest.raises(HTTPException) as exc_info: | ||
auth.valid_token_dependency( | ||
auth_header=f"Bearer {token}", required_scopes=Security([]) | ||
) | ||
|
||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED | ||
assert exc_info.value.detail == "Could not validate credentials" | ||
assert isinstance(exc_info.value.__cause__, jwt.exceptions.InvalidSignatureError) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"token", | ||
[ | ||
"foo", | ||
"Bearer foo", | ||
"Bearer foo.bar.xyz", | ||
"Basic foo", | ||
], | ||
) | ||
def test_auth_token_invalid_token(token): | ||
""" | ||
Auth token dependency throws 401 when the token is invalid. | ||
""" | ||
auth = OpenIdConnectAuth( | ||
openid_configuration_url="https://example.com/.well-known/openid-configuration" | ||
) | ||
|
||
with pytest.raises(HTTPException) as exc_info: | ||
auth.valid_token_dependency(auth_header=token, required_scopes=Security([])) | ||
|
||
assert exc_info.value.status_code == status.HTTP_401_UNAUTHORIZED | ||
assert exc_info.value.detail == "Could not validate credentials" | ||
|
||
|
||
def test_apply_auth_dependencies(test_app, test_client): | ||
auth = OpenIdConnectAuth( | ||
openid_configuration_url="https://example.com/.well-known/openid-configuration" | ||
) | ||
|
||
for route in test_app.routes: | ||
auth.apply_auth_dependencies( | ||
api_route=route, required_token_scopes=["test_scope"] | ||
) | ||
|
||
resp = test_client.get("/test-route") | ||
assert resp.json() == {"detail": "Not authenticated"} | ||
assert resp.status_code == status.HTTP_403_FORBIDDEN | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"required_sent_response", | ||
[ | ||
("a", "b", status.HTTP_401_UNAUTHORIZED), | ||
("a b c", "a b", status.HTTP_401_UNAUTHORIZED), | ||
("a", "a", status.HTTP_200_OK), | ||
(None, None, status.HTTP_200_OK), | ||
(None, "a", status.HTTP_200_OK), | ||
("a b c", "d c b a", status.HTTP_200_OK), | ||
], | ||
) | ||
def test_reject_wrong_scope( | ||
test_app, test_client, token_builder, required_sent_response | ||
): | ||
auth = OpenIdConnectAuth( | ||
openid_configuration_url="https://example.com/.well-known/openid-configuration" | ||
) | ||
|
||
scope_required, scope_sent, expected_status = required_sent_response | ||
for route in test_app.routes: | ||
auth.apply_auth_dependencies( | ||
api_route=route, | ||
required_token_scopes=scope_required.split(" ") if scope_required else None, | ||
) | ||
|
||
token = token_builder({"scope": scope_sent}) | ||
resp = test_client.get("/test-route", headers={"Authorization": f"Bearer {token}"}) | ||
assert resp.status_code == expected_status |