Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: add token validation to enforcer and local routes #31

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions horizon/enforcer/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import json
from typing import Optional, Dict

from fastapi import APIRouter, status, Response
from fastapi import APIRouter, Depends, status, Response
from opal_client.policy_store import BasePolicyStoreClient, DEFAULT_POLICY_STORE_GETTER
from opal_client.policy_store.opa_client import fail_silently
from opal_client.logger import logger
from horizon.config import sidecar_config
from horizon.token_utils import JWTBearer

from horizon.enforcer.schemas import AuthorizationQuery, AuthorizationResult

Expand Down Expand Up @@ -66,7 +67,8 @@ def log_query_and_result(query: AuthorizationQuery, response: Response):
)


@router.post("/allowed", response_model=AuthorizationResult, status_code=status.HTTP_200_OK, response_model_exclude_none=True)
@router.post("/allowed", response_model=AuthorizationResult, status_code=status.HTTP_200_OK, response_model_exclude_none=True,
dependencies=[Depends(JWTBearer())])
async def is_allowed(query: AuthorizationQuery):
async def _is_allowed():
return await policy_store.get_data_with_input(path="rbac", input=query)
Expand Down
5 changes: 3 additions & 2 deletions horizon/local/api.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Dict, Any, List, Optional

from fastapi import APIRouter, status, HTTPException
from fastapi import APIRouter, Depends, status, HTTPException
from opal_client.policy_store import BasePolicyStoreClient, DEFAULT_POLICY_STORE_GETTER

from horizon.local.schemas import Message, SyncedRole, SyncedUser
from horizon.token_utils import JWTBearer

def init_local_cache_api_router(policy_store:BasePolicyStoreClient=None):
policy_store = policy_store or DEFAULT_POLICY_STORE_GETTER()
router = APIRouter()
router = APIRouter(dependencies=[Depends(JWTBearer())])

def error_message(msg: str):
return {
Expand Down
87 changes: 87 additions & 0 deletions horizon/token_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import requests
from fastapi import Request, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import jwt, JWTError


AUTH0_DOMAIN = "acalla-dev.us.auth0.com"
API_AUDIENCE = f"https://api.acalla.com/v1/"
ALGORITHMS = ["RS256"]
JWK_TOKENS = requests.get(f"https://{AUTH0_DOMAIN}/.well-known/jwks.json").json()

class JWTBearer(HTTPBearer):
def __init__(self, auto_error: bool = True):
super(JWTBearer, self).__init__(auto_error=auto_error)
async def __call__(self, request: Request):
credentials: HTTPAuthorizationCredentials = await super(
JWTBearer, self
).__call__(request)
if credentials:
if not credentials.scheme == "Bearer":
raise HTTPException(
status_code=403, detail="Invalid authentication scheme."
)
if not type(self)._verify_jwt(credentials.credentials):
raise HTTPException(
status_code=403, detail="Invalid token or expired token."
)
return credentials.credentials
else:
raise HTTPException(status_code=403, detail="Invalid authorization code.")

@classmethod
def _verify_jwt(cls, jwtoken: str) -> bool:
# is_token_valid: bool = False

payload = cls._decode_jwt(jwtoken)
return payload
# except HTTPException:
# payload = None
# if payload:
# is_token_valid = True
# return is_token_valid

@classmethod
def _decode_jwt(cls, jwtoken: str) -> dict:
rsa_key = cls._get_rsa_key(jwtoken)
try:
return jwt.decode(
jwtoken,
rsa_key,
algorithms=ALGORITHMS,
audience=API_AUDIENCE,
issuer=f"https://{AUTH0_DOMAIN}/",
)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=401,
detail="token is expired",
)
except jwt.JWTClaimsError:
raise HTTPException(
status_code=401,
detail="incorrect claims, please check the audience and issuer",
)
except Exception as e:
print(e)
raise HTTPException(
status_code=401,
detail="Unable to parse authentication token." + str(e),
)

@staticmethod
def _get_rsa_key(jwtoken: str) -> dict:
unverified_header = jwt.get_unverified_header(jwtoken)
for key in JWK_TOKENS["keys"]:
if key["kid"] == unverified_header["kid"]:
return {
"kty": key["kty"],
"kid": key["kid"],
"use": key["use"],
"n": key["n"],
"e": key["e"],
}
raise HTTPException(
status_code=401,
detail="Unable to find appropriate key",
)
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ tenacity==6.3.1
Jinja2==3.0.3
logzio-python-handler
rook
ddtrace
ddtrace
jose