diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..6d7e758 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,17 @@ +# http://editorconfig.org + +root = true + +[*] +indent_style = space +indent_size = 4 +trim_trailing_whitespace = true +insert_final_newline = true +charset = utf-8 +end_of_line = lf + +[LICENSE] +insert_final_newline = false + +[Makefile] +indent_style = tab diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..5c07f9e --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1,3 @@ +# These are supported funding model platforms + +github: [aliev] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] diff --git a/.github/ISSUE_TEMPLATE.md b/.github/ISSUE_TEMPLATE.md new file mode 100644 index 0000000..1e6739f --- /dev/null +++ b/.github/ISSUE_TEMPLATE.md @@ -0,0 +1,15 @@ +* aioauth version: +* Python version: +* Operating System: + +### Description + +Describe what you were trying to get done. +Tell us what happened, what went wrong, and what you expected to happen. + +### What I Did + +``` +Paste the command(s) you ran and the output. +If there was a crash, please include the traceback here. +``` diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml new file mode 100644 index 0000000..e5e1856 --- /dev/null +++ b/.github/workflows/python-package.yml @@ -0,0 +1,43 @@ +# This workflow will install Python dependencies, run tests and lint with a variety of Python versions +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: Python package + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.6, 3.7, 3.8] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Setup Node.js environment + uses: actions/setup-node@v1.4.3 + with: + node-version: 12.x + - name: Install dependencies + run: | + python -m pip install --upgrade pip + npm install -g pyright + pip install -e ."[test]" + - name: Type checking with pyright + run: | + pyright src/aioauth tests + - name: Lint with flake8 + run: | + flake8 src/aioauth tests + - name: Test with pytest + run: | + pytest tests diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..edf536b --- /dev/null +++ b/.gitignore @@ -0,0 +1,128 @@ +# Editors +.vscode/ +.idea/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Direnv (https://github.com/direnv/direnv) +.envrc + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ + +# MacOS +.DS_Store + +# Emacs +*~ + +# setuptools_scm +_repo_version.py + +# codegen +codegen/openapi.json +codegen/python/ + +# es +data +pythonenv* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..43b1fa0 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,39 @@ +repos: +- repo: https://github.com/psf/black + rev: stable + hooks: + - id: black + language_version: python3.7 + +- repo: https://github.com/pre-commit/mirrors-isort + rev: v4.3.21 + hooks: + - id: isort + +- repo: https://github.com/PyCQA/flake8 + rev: 3.8.3 + hooks: + - id: flake8 + additional_dependencies: [ + flake8-blind-except, + flake8-builtins, + flake8-comprehensions, + flake8-docstrings, + flake8-mutable, + flake8-print, + flake8-quotes, + flake8-tuple, + ] + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.5.0 + hooks: + - id: mixed-line-ending + args: ['--fix=lf'] + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-merge-conflict + - id: check-json + - id: check-toml + - id: check-xml + - id: check-yaml diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..44292e0 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2020 Ali Aliyev + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..4fe836c --- /dev/null +++ b/Makefile @@ -0,0 +1,69 @@ +.PHONY: clean clean-test clean-pyc clean-build docs help +.DEFAULT_GOAL := help + +define BROWSER_PYSCRIPT +import os, webbrowser, sys + +from urllib.request import pathname2url + +webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) +endef +export BROWSER_PYSCRIPT + +define PRINT_HELP_PYSCRIPT +import re, sys + +for line in sys.stdin: + match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) + if match: + target, help = match.groups() + print("%-20s %s" % (target, help)) +endef +export PRINT_HELP_PYSCRIPT + +BROWSER := python -c "$$BROWSER_PYSCRIPT" + +help: + @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) + +clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts + +clean-build: ## remove build artifacts + rm -fr build/ + rm -fr dist/ + rm -fr .eggs/ + find . -name '*.egg-info' -exec rm -fr {} + + find . -name '*.egg' -exec rm -f {} + + +clean-pyc: ## remove Python file artifacts + find . -name '*.pyc' -exec rm -f {} + + find . -name '*.pyo' -exec rm -f {} + + find . -name '*~' -exec rm -f {} + + find . -name '__pycache__' -exec rm -fr {} + + +clean-test: ## remove test and coverage artifacts + rm -fr .tox/ + rm -f .coverage + rm -fr htmlcov/ + rm -fr .pytest_cache + +lint: ## check style with flake8 + flake8 src/aioauth tests + pyright src/aioauth tests + +test: ## run tests quickly with the default Python + pytest tests + +test-all: ## run tests on every Python version with tox + tox + +release: dist ## package and upload a release + twine upload dist/* + +dist: clean ## builds source and wheel package + python setup.py sdist + python setup.py bdist_wheel + ls -l dist + +install: clean ## install the package to the active Python's site-packages + python setup.py install diff --git a/README.md b/README.md new file mode 100644 index 0000000..3b5a85c --- /dev/null +++ b/README.md @@ -0,0 +1,25 @@ +## Asynchronous OAuth 2.0 framework for Python 3 + +`aioauth` implements [OAuth 2.0 protocol](https://tools.ietf.org/html/rfc6749) and can be used in asynchronous frameworks like [FastAPI / Starlette](https://github.com/tiangolo/fastapi), [aiohttp](https://github.com/aio-libs/aiohttp). It can work with any databases like `MongoDB`, `PostgreSQL`, `MySQL` and ORMs like [gino](https://python-gino.org/), [sqlalchemy](https://www.sqlalchemy.org/), [databases](https://pypi.org/project/databases/) over simple [BaseDB](src/aioauth/db.py) interface. + +## Why this project exists? + +There are few great OAuth frameworks for Python like [oauthlib](https://github.com/oauthlib/oauthlib) and [authlib](https://github.com/lepture/authlib), but they do not support asyncio because rewriting these libraries to asyncio is a big challenge (see issues [here](https://github.com/lepture/authlib/issues/63) and [here](https://github.com/oauthlib/oauthlib/issues/415)). + +## Supported RFCs + +- [x] [The OAuth 2.0 Authorization Framework](https://tools.ietf.org/html/rfc6749) +- [X] [OAuth 2.0 Token Introspection](https://tools.ietf.org/html/rfc7662) +- [X] [Proof Key for Code Exchange by OAuth Public Clients](https://tools.ietf.org/html/rfc7636) + +``` +python -m pip install aioauth +``` + +## Settings and defaults + +| Setting | Default value | Description | +| ------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------- | +| AIOAUTH_TOKEN_EXPIRES_IN | 86400 | Access token lifetime. Default value in seconds. | +| AIOAUTH_AUTHORIZATION_CODE_EXPIRES_IN | 300 | Authorization code lifetime. Default value in seconds. | +| AIOAUTH_INSECURE_TRANSPORT | False | Allow connections over SSL only. When this option is disabled server will raise "HTTP method is not allowed" error. | diff --git a/requirements/base.txt b/requirements/base.txt new file mode 100644 index 0000000..e69de29 diff --git a/requirements/test.txt b/requirements/test.txt new file mode 100644 index 0000000..42411b8 --- /dev/null +++ b/requirements/test.txt @@ -0,0 +1,24 @@ +async-asgi-testclient==1.4.4 +pre-commit==2.4.0 +black==19.10b0 +isort==4.3.21 +flake8==3.8.2 +flake8-black==0.2.0 +flake8-blind-except==0.1.1 +flake8-builtins==1.5.3 +flake8-comprehensions==3.2.3 +flake8-mutable==1.2.0 +flake8-print==3.1.4 +flake8-quotes==3.2.0 +flake8-tuple==0.4.1 +pytest==5.4.3 +pytest-asyncio==0.12.0 +pytest-cov==2.9.0 +pytest-env==0.6.2 +pytest-sugar==0.9.3 +testfixtures==6.14.1 +bump2version==0.5.11 +tox==3.20.0 +wheel==0.33.6 +twine==1.14.0 +watchdog==0.9.0 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..5fb6ce4 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,37 @@ +[bumpversion] +current_version = 2.0.0 +commit = True +tag = True + +[tool:pytest] +addopts = -s --strict -vv --cache-clear --maxfail=1 --cov=aioauth --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail + +[isort] +multi_line_output = 3 +not_skip = __init__.py +include_trailing_comma = True +force_grid_wrap = 0 +use_parentheses = True +line_length = 88 +default_section = THIRDPARTY +known_first_party = ownauth + +[coverage:run] +branch = True +omit = + site-packages + src/aioauth/__version__.py + +[bumpversion:file:src/aioauth/__version__.py] +search = __version__ = '{current_version}' +replace = __version__ = '{new_version}' + +[bdist_wheel] +universal = 1 + +[flake8] +ignore = D10,E203,E501,W503,D205,D400,A001,D210,D401 +max-line-length = 88 +select = A,B,C4,D,E,F,M,Q,T,W,ABS,BLK +exclude = versions/* +inline-quotes = " diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..34874f9 --- /dev/null +++ b/setup.py @@ -0,0 +1,57 @@ +from pathlib import Path + +from setuptools import find_packages, setup + +here = Path(__file__).parent +about = {} + +with open(here / "src" / "aioauth" / "__version__.py", "r") as f: + exec(f.read(), about) + +with open("README.md") as readme_file: + readme = readme_file.read() + + +def read_requirements(path): + try: + with path.open(mode="rt", encoding="utf-8") as fp: + return list(filter(bool, (line.split("#")[0].strip() for line in fp))) + except IndexError: + raise RuntimeError(f"{path} is broken") + + +base_requirements = read_requirements(here / "requirements" / "base.txt") +test_requirements = read_requirements(here / "requirements" / "test.txt") + +setup( + name=about["__title__"], + version=about["__version__"], + description=about["__description__"], + long_description=readme, + long_description_content_type="text/markdown", + author=about["__author__"], + author_email=about["__author_email__"], + url=about["__url__"], + license=about["__license__"], + python_requires=">=3.6.0", + classifiers=[ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Natural Language :: English", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + ], + install_requires=base_requirements, + tests_require=test_requirements, + extras_require={"test": test_requirements}, + include_package_data=True, + keywords="aioauth", + packages=find_packages(where="src"), + package_dir={"": "src"}, + test_suite="tests", + zip_safe=False, + project_urls={"Source": about["__url__"]}, +) diff --git a/src/aioauth/__init__.py b/src/aioauth/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aioauth/__version__.py b/src/aioauth/__version__.py new file mode 100644 index 0000000..a462624 --- /dev/null +++ b/src/aioauth/__version__.py @@ -0,0 +1,8 @@ +__title__ = "aioauth" +__description__ = "Asynchronous OAuth 2.0 framework for Python 3." +__url__ = "https://github.com/aliev/aioauth" +__version__ = "2.0.0" +__author__ = "Ali Aliyev" +__author_email__ = "ali@aliev.me" +__license__ = "The MIT License (MIT)" +__copyright__ = "Copyright 2020 Ali Aliyev" diff --git a/src/aioauth/base/__init__.py b/src/aioauth/base/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/aioauth/base/database.py b/src/aioauth/base/database.py new file mode 100644 index 0000000..40ab5a5 --- /dev/null +++ b/src/aioauth/base/database.py @@ -0,0 +1,135 @@ +import time +from typing import Optional + +from aioauth.types import CodeChallengeMethod, ResponseType + +from ..config import get_settings +from ..models import AuthorizationCode, Client, Token +from ..requests import Request +from ..utils import generate_token + + +class BaseDB: + async def create_token(self, request: Request, client_id: str, scope: str) -> Token: + """Generates Token model instance. + + Generated Token MUST be stored in database. + + Method is used by all core grant types. + Method is used by response types: + - ResponseTypeToken + """ + settings = get_settings() + return Token( + client_id=client_id, + expires_in=settings.TOKEN_EXPIRES_IN, + access_token=generate_token(42), + refresh_token=generate_token(48), + issued_at=int(time.time()), + scope=scope, + revoked=False, + ) + + async def get_token( + self, + request: Request, + client_id: str, + token: Optional[str] = None, + refresh_token: Optional[str] = None, + ) -> Optional[Token]: + """Gets existing token from the database + + Method is used by: + - create_token_introspection_response + Method is used by grant types: + - RefreshTokenGrantType + """ + raise NotImplementedError("Method get_token must be implemented") + + async def create_authorization_code( + self, + request: Request, + client_id: str, + scope: str, + response_type: ResponseType, + redirect_uri: str, + code_challenge_method: CodeChallengeMethod, + code_challenge: str, + ) -> AuthorizationCode: + """Generates AuthorizationCode model instance. + + Generated AuthorizationCode MUST be stored in database. + + Method is used by response types: + - ResponseTypeAuthorizationCode + """ + return AuthorizationCode( + code=generate_token(48), + client_id=client_id, + redirect_uri=redirect_uri, + response_type=response_type, + scope=scope, + auth_time=int(time.time()), + code_challenge_method=code_challenge_method, + code_challenge=code_challenge, + ) + + async def get_client( + self, request: Request, client_id: str, client_secret: Optional[str] = None + ) -> Optional[Client]: + """Gets existing Client from database. + + If client doesn't exists in database this method MUST return None + to indicate to the validator that the requested ``client_id`` does not exist or is invalid. + + Method is used by all core grant types. + Method is used by all core response types. + """ + raise NotImplementedError("Method get_client must be implemented") + + async def authenticate(self, request: Request) -> bool: + """Authenticate user. + + Method is used by grant types: + - PasswordGrantType + """ + raise NotImplementedError("Method authenticate must be implemented") + + async def get_authorization_code( + self, request: Request, client_id: str, code: str + ) -> Optional[AuthorizationCode]: + """Gets existing AuthorizationCode from database. + + If authorization code doesn't exists it MUST return None + to indicate to the validator that the requested authorization code does not exist or is invalid. + + Method is used by grant types: + - AuthorizationCodeGrantType + """ + raise NotImplementedError( + "Method get_authorization_code must be implemented for AuthorizationCodeGrantType" + ) + + async def delete_authorization_code( + self, request: Request, client_id: str, code: str + ): + """Deletes authorization code from database. + + Method is used by grant types: + - AuthorizationCodeGrantType + """ + raise NotImplementedError( + "Method delete_authorization_code must be implemented for AuthorizationCodeGrantType" + ) + + async def revoke_token(self, request: Request, token: str) -> None: + """Revokes token in database. + + This method MUST set `revoked` in True for existing token record. + + Method is used by grant types: + - RefreshTokenGrantType + """ + raise NotImplementedError( + "Method revoke_token must be implemented for RefreshTokenGrantType" + ) diff --git a/src/aioauth/base/endpoint.py b/src/aioauth/base/endpoint.py new file mode 100644 index 0000000..ac77284 --- /dev/null +++ b/src/aioauth/base/endpoint.py @@ -0,0 +1,35 @@ +from typing import Dict, Optional, Type, Union + +from ..grant_type import GrantTypeBase +from ..response_type import ResponseTypeBase +from ..types import EndpointType, GrantType, ResponseType +from .database import BaseDB + + +class BaseEndpoint: + response_type: Dict[Optional[ResponseType], Type[ResponseTypeBase]] = {} + grant_type: Dict[Optional[GrantType], Type[GrantTypeBase]] = {} + available: bool = True + + def __init__( + self, db: BaseDB, available: Optional[bool] = None, + ): + self.db = db + + if available is not None: + self.available = available + + def register( + self, + endpoint_type: EndpointType, + endpoint: Union[ResponseType, GrantType], + endpoint_cls: Union[Type[ResponseTypeBase], Type[GrantTypeBase]], + ): + endpoint_dict = getattr(self, endpoint_type) + endpoint_dict[endpoint] = endpoint_cls + + def unregister( + self, endpoint_type: EndpointType, endpoint: Union[ResponseType, GrantType] + ): + endpoint_dict = getattr(self, endpoint_type) + del endpoint_dict[endpoint] diff --git a/src/aioauth/base/request_validator.py b/src/aioauth/base/request_validator.py new file mode 100644 index 0000000..fd8b0b3 --- /dev/null +++ b/src/aioauth/base/request_validator.py @@ -0,0 +1,28 @@ +from aioauth.structures import CaseInsensitiveDict + +from ..constances import default_headers +from ..errors import InsecureTransportError, MethodNotAllowedError +from ..requests import Request +from ..types import RequestMethod +from ..utils import is_secure_transport +from .database import BaseDB + + +class BaseRequestValidator: + allowed_methods = [ + RequestMethod.GET, + RequestMethod.POST, + ] + + def __init__(self, db: BaseDB): + self.db = db + + async def validate_request(self, request: Request): + if not is_secure_transport(request.url): + raise InsecureTransportError() + + if request.method not in self.allowed_methods: + headers = CaseInsensitiveDict( + {**default_headers, "allow": ", ".join(self.allowed_methods)} + ) + raise MethodNotAllowedError(headers=headers) diff --git a/src/aioauth/config.py b/src/aioauth/config.py new file mode 100644 index 0000000..361b977 --- /dev/null +++ b/src/aioauth/config.py @@ -0,0 +1,25 @@ +import os +from typing import Any, Callable, NamedTuple + + +def get_env(env: str, default_value: Any, to_type: Callable, prefix="AIOAUTH_"): + """Get the value of an environment variable and apply a specific type to it""" + return to_type(os.environ.get(f"{prefix}{env}", default_value)) + + +class Settings(NamedTuple): + TOKEN_EXPIRES_IN: int + AUTHORIZATION_CODE_EXPIRES_IN: int + INSECURE_TRANSPORT: bool + ERROR_URI: str + + +def get_settings(): + return Settings( + TOKEN_EXPIRES_IN=get_env("TOKEN_EXPIRES_IN", 86400, int), + AUTHORIZATION_CODE_EXPIRES_IN=get_env( + "AUTHORIZATION_CODE_EXPIRES_IN", 300, int + ), + INSECURE_TRANSPORT=get_env("INSECURE_TRANSPORT", False, bool), + ERROR_URI=get_env("ERROR_URI", "", str), + ) diff --git a/src/aioauth/constances.py b/src/aioauth/constances.py new file mode 100644 index 0000000..6585eb1 --- /dev/null +++ b/src/aioauth/constances.py @@ -0,0 +1,20 @@ +from .structures import CaseInsensitiveDict + + +def _default_headers() -> CaseInsensitiveDict: + """The authorization server MUST include the HTTP "Cache-Control" + response header field [RFC2616] with a value of "no-store" in any + response containing tokens, credentials, or other sensitive + information, as well as the "Pragma" response header field [RFC2616] + with a value of "no-cache". + """ + return CaseInsensitiveDict( + { + "Content-Type": "application/json", + "Cache-Control": "no-store", + "Pragma": "no-cache", + } + ) + + +default_headers = _default_headers() diff --git a/src/aioauth/endpoints.py b/src/aioauth/endpoints.py new file mode 100644 index 0000000..be337ff --- /dev/null +++ b/src/aioauth/endpoints.py @@ -0,0 +1,69 @@ +from http import HTTPStatus + +from .base.endpoint import BaseEndpoint +from .constances import default_headers +from .grant_type import GrantTypeBase +from .requests import Request +from .response_type import ResponseTypeBase +from .responses import ( + Response, + TokenActiveIntrospectionResponse, + TokenInactiveIntrospectionResponse, +) +from .structures import CaseInsensitiveDict +from .types import ResponseType +from .utils import build_uri, catch_errors_and_unavailability, decode_auth_headers + + +class Endpoint(BaseEndpoint): + @catch_errors_and_unavailability + async def create_token_introspection_response(self, request: Request) -> Response: + client_id, _ = decode_auth_headers(request.headers.get("Authorization", "")) + + token = await self.db.get_token( + request=request, client_id=client_id, token=request.post.token + ) + + token_response = TokenInactiveIntrospectionResponse() + + if token and not token.is_expired: + token_response = TokenActiveIntrospectionResponse( + scope=token.scope, client_id=token.client_id, exp=token.expires_in + ) + + return Response( + content=token_response, status_code=HTTPStatus.OK, headers=default_headers + ) + + @catch_errors_and_unavailability + async def create_token_response(self, request: Request) -> Response: + grant_type_cls = self.grant_type.get(request.post.grant_type, GrantTypeBase) + grant_type_handler = grant_type_cls(self.db) + token_response = await grant_type_handler.create_token_response(request) + return Response( + content=token_response, status_code=HTTPStatus.OK, headers=default_headers + ) + + @catch_errors_and_unavailability + async def create_authorization_code_response(self, request: Request) -> Response: + response_type_cls = self.response_type.get( + request.query.response_type, ResponseTypeBase + ) + response_type_handler = response_type_cls(self.db) + authorization_code_response = await response_type_handler.create_authorization_code_response( + request + ) + response_type = request.query.response_type + response_dict = { + **authorization_code_response._asdict(), + "state": request.query.state, + } + query_params = response_dict if response_type == ResponseType.TYPE_CODE else {} + fragment = response_dict if response_type == ResponseType.TYPE_TOKEN else {} + + location = build_uri(request.query.redirect_uri, query_params, fragment) + + return Response( + status_code=HTTPStatus.FOUND, + headers=CaseInsensitiveDict({"location": location}), + ) diff --git a/src/aioauth/errors.py b/src/aioauth/errors.py new file mode 100644 index 0000000..23f7b85 --- /dev/null +++ b/src/aioauth/errors.py @@ -0,0 +1,150 @@ +from http import HTTPStatus +from typing import Optional +from urllib.parse import urljoin + +from .config import get_settings +from .constances import default_headers +from .structures import CaseInsensitiveDict +from .types import ErrorType + + +class OAuth2Error(Exception): + error: ErrorType + description: str = "" + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST + error_uri: str = "" + headers: CaseInsensitiveDict = default_headers + + def __init__( + self, + description: Optional[str] = None, + headers: Optional[CaseInsensitiveDict] = None, + ): + settings = get_settings() + + if description is not None: + self.description = description + + if headers is not None: + self.headers = headers + + if settings.ERROR_URI: + self.error_uri = urljoin(settings.ERROR_URI, self.error) + + super().__init__(f"({self.error}) {self.description}") + + +class MethodNotAllowedError(OAuth2Error): + description = "HTTP method is not allowed." + status_code: HTTPStatus = HTTPStatus.METHOD_NOT_ALLOWED + error = ErrorType.METHOD_IS_NOT_ALLOWED + + +class InvalidRequestError(OAuth2Error): + """ + The request is missing a required parameter, includes an invalid + parameter value, includes a parameter more than once, or is + otherwise malformed. + """ + + error = ErrorType.INVALID_REQUEST + + +class InvalidClientError(OAuth2Error): + """ + Client authentication failed (e.g. unknown client, no client + authentication included, or unsupported authentication method). + The authorization server MAY return an HTTP 401 (Unauthorized) status + code to indicate which HTTP authentication schemes are supported. + If the client attempted to authenticate via the "Authorization" request + header field, the authorization server MUST respond with an + HTTP 401 (Unauthorized) status code, and include the "WWW-Authenticate" + response header field matching the authentication scheme used by the + client. + """ + + error = ErrorType.INVALID_CLIENT + status_code: HTTPStatus = HTTPStatus.UNAUTHORIZED + + +class InsecureTransportError(OAuth2Error): + description = "OAuth 2 MUST utilize https." + error = ErrorType.INSECURE_TRANSPORT + + +class UnsupportedGrantTypeError(OAuth2Error): + """ + The authorization grant type is not supported by the authorization + server. + """ + + error = ErrorType.UNSUPPORTED_GRANT_TYPE + + +class UnsupportedResponseTypeError(OAuth2Error): + """ + The authorization server does not support obtaining an authorization + code using this method. + """ + + error = ErrorType.UNSUPPORTED_RESPONSE_TYPE + + +class InvalidGrantError(OAuth2Error): + """ + The provided authorization grant (e.g. authorization code, resource + owner credentials) or refresh token is invalid, expired, revoked, does + not match the redirection URI used in the authorization request, or was + issued to another client. + + https://tools.ietf.org/html/rfc6749#section-5.2 + """ + + error = ErrorType.INVALID_GRANT + + +class MismatchingStateError(OAuth2Error): + description = "CSRF Warning! State not equal in request and response." + error = ErrorType.MISMATCHING_STATE + + +class UnauthorizedClientError(OAuth2Error): + """ + The authenticated client is not authorized to use this authorization + grant type. + """ + + error = ErrorType.UNAUTHORIZED_CLIENT + + +class InvalidScopeError(OAuth2Error): + """ + The requested scope is invalid, unknown, or malformed, or + exceeds the scope granted by the resource owner. + + https://tools.ietf.org/html/rfc6749#section-5.2 + """ + + error = ErrorType.INVALID_SCOPE + + +class ServerError(OAuth2Error): + """ + The authorization server encountered an unexpected condition that + prevented it from fulfilling the request. (This error code is needed + because a 500 Internal Server Error HTTP status code cannot be returned + to the client via a HTTP redirect.) + """ + + error = ErrorType.SERVER_ERROR + + +class TemporarilyUnavailableError(OAuth2Error): + """ + The authorization server is currently unable to handle the request + due to a temporary overloading or maintenance of the server. + (This error code is needed because a 503 Service Unavailable HTTP + status code cannot be returned to the client via a HTTP redirect.) + """ + + error = ErrorType.TEMPORARILY_UNAVAILABLE diff --git a/src/aioauth/grant_type.py b/src/aioauth/grant_type.py new file mode 100644 index 0000000..c898c8a --- /dev/null +++ b/src/aioauth/grant_type.py @@ -0,0 +1,158 @@ +from typing import Optional + +from .base.request_validator import BaseRequestValidator +from .errors import ( + InvalidGrantError, + InvalidRequestError, + InvalidScopeError, + MismatchingStateError, + UnauthorizedClientError, + UnsupportedGrantTypeError, +) +from .models import Client +from .requests import Request +from .responses import TokenResponse +from .types import GrantType, RequestMethod +from .utils import decode_auth_headers + + +class GrantTypeBase(BaseRequestValidator): + allowed_methods = [ + RequestMethod.POST, + ] + grant_type: Optional[GrantType] = None + + async def create_token_response(self, request: Request) -> TokenResponse: + client = await self.validate_request(request) + token = await self.db.create_token( + request, client.client_id, request.post.scope + ) + + return TokenResponse( + expires_in=token.expires_in, + refresh_token_expires_in=token.refresh_token_expires_in, + access_token=token.access_token, + refresh_token=token.refresh_token, + scope=token.scope, + token_type=token.token_type, + ) + + async def validate_request(self, request: Request) -> Client: + await super().validate_request(request) + + client_id, client_secret = decode_auth_headers( + request.headers.get("Authorization", "") + ) + + if not request.post.grant_type: + raise InvalidRequestError(description="Request is missing grant type.") + + if self.grant_type != request.post.grant_type: + raise UnsupportedGrantTypeError() + + client = await self.db.get_client( + request, client_id=client_id, client_secret=client_secret + ) + + if not client: + raise InvalidRequestError(description="Invalid client_id parameter value.") + + if not client.check_grant_type(request.post.grant_type): + raise UnauthorizedClientError() + + if not client.check_scope(request.post.scope): + raise InvalidScopeError() + + return client + + +class AuthorizationCodeGrantType(GrantTypeBase): + grant_type: GrantType = GrantType.TYPE_AUTHORIZATION_CODE + + async def validate_request(self, request: Request) -> Client: + client = await super().validate_request(request) + + if not request.post.redirect_uri: + raise InvalidRequestError(description="Mismatching redirect URI.") + + if not client.check_redirect_uri(request.post.redirect_uri): + raise InvalidRequestError(description="Invalid redirect URI.") + + if not request.post.code: + raise InvalidRequestError(description="Missing code parameter.") + + authorization_code = await self.db.get_authorization_code( + request, client.client_id, request.post.code + ) + + if not authorization_code: + raise InvalidGrantError() + + if ( + authorization_code.code_challenge + and authorization_code.code_challenge_method + ): + if not request.post.code_verifier: + raise InvalidRequestError(description="Code verifier required.") + + is_valid_code_challenge = authorization_code.check_code_challenge( + request.post.code_verifier + ) + if not is_valid_code_challenge: + raise MismatchingStateError() + + if authorization_code.is_expired: + raise InvalidGrantError() + + await self.db.delete_authorization_code( + request, client.client_id, request.post.code + ) + + return client + + +class PasswordGrantType(GrantTypeBase): + grant_type: GrantType = GrantType.TYPE_PASSWORD + + async def validate_request(self, request: Request) -> Client: + client = await super().validate_request(request) + + if not request.post.password or not request.post.password: + raise InvalidGrantError(description="Invalid credentials given.") + + user = await self.db.authenticate(request) + + if not user: + raise InvalidGrantError(description="Invalid credentials given.") + + return client + + +class RefreshTokenGrantType(GrantTypeBase): + grant_type: GrantType = GrantType.TYPE_REFRESH_TOKEN + + async def validate_request(self, request: Request) -> Client: + client = await super().validate_request(request) + + if not request.post.refresh_token: + raise InvalidRequestError(description="Missing refresh token parameter.") + + token = await self.db.get_token( + request=request, + client_id=client.client_id, + refresh_token=request.post.refresh_token, + ) + + if not token: + raise InvalidGrantError() + + if token.refresh_token_expired: + raise InvalidGrantError() + + await self.db.revoke_token(request, request.post.refresh_token) + + return client + + +class ClientCredentialsGrantType(GrantTypeBase): + grant_type: GrantType = GrantType.TYPE_CLIENT_CREDENTIALS diff --git a/src/aioauth/models.py b/src/aioauth/models.py new file mode 100644 index 0000000..3f77927 --- /dev/null +++ b/src/aioauth/models.py @@ -0,0 +1,100 @@ +import time +from typing import List, NamedTuple, Optional, Text + +from .config import get_settings +from .types import CodeChallengeMethod, GrantType, ResponseType +from .utils import create_s256_code_challenge, list_to_scope, scope_to_list + + +class ClientMetadata(NamedTuple): + grant_types: List[GrantType] = [] + response_types: List[ResponseType] = [] + redirect_uris: List[str] = [] + scope: Text = "" + + +class Client(NamedTuple): + client_id: Text + client_secret: Text + client_metadata: ClientMetadata + + def check_redirect_uri(self, redirect_uri) -> bool: + return redirect_uri in self.client_metadata.redirect_uris + + def check_grant_type(self, grant_type: GrantType) -> bool: + return grant_type in self.client_metadata.grant_types + + def check_response_type(self, response_type: ResponseType) -> bool: + return response_type in self.client_metadata.response_types + + def get_allowed_scope(self, scope) -> Text: + if not scope: + return "" + allowed = set(self.client_metadata.scope.split()) + scopes = scope_to_list(scope) + return list_to_scope([s for s in scopes if s in allowed]) + + def check_scope(self, scope: str) -> bool: + allowed_scope = self.get_allowed_scope(scope) + return not (set(scope_to_list(scope)) - set(scope_to_list(allowed_scope))) + + +class AuthorizationCode(NamedTuple): + code: Text + client_id: Text + redirect_uri: Text + response_type: ResponseType + scope: Text + auth_time: int + code_challenge: Optional[Text] = None + code_challenge_method: Optional[CodeChallengeMethod] = None + nonce: Optional[Text] = None + + def check_code_challenge(self, code_verifier: str) -> bool: + is_valid_code_challenge = False + + if self.code_challenge_method == CodeChallengeMethod.PLAIN: + # If the "code_challenge_method" was "plain", they are compared directly + is_valid_code_challenge = code_verifier == self.code_challenge + + if self.code_challenge_method == CodeChallengeMethod.S256: + # base64url(sha256(ascii(code_verifier))) == code_challenge + is_valid_code_challenge = ( + create_s256_code_challenge(code_verifier) == self.code_challenge + ) + + return is_valid_code_challenge + + @property + def is_expired(self) -> bool: + settings = get_settings() + return self.auth_time + settings.AUTHORIZATION_CODE_EXPIRES_IN < time.time() + + +class Token(NamedTuple): + access_token: Text + refresh_token: Text + scope: Text + issued_at: int + expires_in: int + client_id: Text + token_type: Text = "Bearer" + revoked: bool = False + + @property + def is_expired(self) -> bool: + return self.token_expires_in < time.time() + + @property + def refresh_token_expires_in(self) -> int: + expires_at = self.issued_at + self.expires_in * 2 + return expires_at + + @property + def token_expires_in(self) -> int: + expires_at = self.issued_at + self.expires_in + return expires_at + + @property + def refresh_token_expired(self) -> bool: + return self.refresh_token_expires_in < time.time() diff --git a/src/aioauth/requests.py b/src/aioauth/requests.py new file mode 100644 index 0000000..541876f --- /dev/null +++ b/src/aioauth/requests.py @@ -0,0 +1,35 @@ +from typing import Any, NamedTuple, Optional + +from .structures import CaseInsensitiveDict +from .types import CodeChallengeMethod, GrantType, RequestMethod, ResponseType + + +class Query(NamedTuple): + client_id: Optional[str] = None + redirect_uri: str = "" + response_type: Optional[ResponseType] = None + state: str = "" + scope: str = "" + code_challenge_method: Optional[CodeChallengeMethod] = None + code_challenge: Optional[str] = None + + +class Post(NamedTuple): + grant_type: Optional[GrantType] = None + redirect_uri: Optional[str] = None + scope: str = "" + username: Optional[str] = None + password: Optional[str] = None + refresh_token: Optional[str] = None + code: Optional[str] = None + token: Optional[str] = None + code_verifier: Optional[str] = None + + +class Request(NamedTuple): + method: RequestMethod + headers: CaseInsensitiveDict = CaseInsensitiveDict() + query: Query = Query() + post: Post = Post() + url: str = "" + user: Optional[Any] = None diff --git a/src/aioauth/response_type.py b/src/aioauth/response_type.py new file mode 100644 index 0000000..80860cf --- /dev/null +++ b/src/aioauth/response_type.py @@ -0,0 +1,110 @@ +from typing import Optional + +from .base.request_validator import BaseRequestValidator +from .errors import ( + InvalidClientError, + InvalidRequestError, + InvalidScopeError, + UnsupportedResponseTypeError, +) +from .models import Client +from .requests import Request +from .responses import AuthorizationCodeResponse, TokenResponse +from .types import CodeChallengeMethod, RequestMethod, ResponseType + + +class ResponseTypeBase(BaseRequestValidator): + response_type: Optional[ResponseType] = None + allowed_methods = [ + RequestMethod.GET, + ] + code_challenge_methods = list(CodeChallengeMethod) + + async def validate_request(self, request: Request) -> Client: + await super().validate_request(request) + + if not request.query.client_id: + raise InvalidRequestError(description="Missing client_id parameter.") + + client = await self.db.get_client( + request=request, client_id=request.query.client_id + ) + + if not client: + raise InvalidRequestError(description="Invalid client_id parameter value.") + + if not request.query.redirect_uri: + raise InvalidRequestError(description="Mismatching redirect URI.") + + if self.response_type != request.query.response_type: + raise UnsupportedResponseTypeError() + + if not client.check_redirect_uri(request.query.redirect_uri): + raise InvalidRequestError(description="Invalid redirect URI.") + + if not request.query.response_type: + raise InvalidRequestError(description="Missing response_type parameter.") + + if request.query.code_challenge_method: + if request.query.code_challenge_method not in self.code_challenge_methods: + raise InvalidRequestError( + description="Transform algorithm not supported." + ) + + if not request.query.code_challenge: + raise InvalidRequestError(description="Code challenge required.") + + if not client.check_response_type(request.query.response_type): + raise UnsupportedResponseTypeError() + + if not client.check_scope(request.query.scope): + raise InvalidScopeError() + + if not request.user: + raise InvalidClientError(description="User is not authorized") + + return client + + async def create_authorization_code_response(self, request: Request) -> Client: + return await self.validate_request(request) + + +class ResponseTypeToken(ResponseTypeBase): + response_type: ResponseType = ResponseType.TYPE_TOKEN + + async def create_authorization_code_response( + self, request: Request + ) -> TokenResponse: + client = await super().create_authorization_code_response(request) + token = await self.db.create_token( + request, client.client_id, request.query.scope + ) + return TokenResponse( + expires_in=token.expires_in, + refresh_token_expires_in=token.refresh_token_expires_in, + access_token=token.access_token, + refresh_token=token.refresh_token, + scope=token.scope, + token_type=token.token_type, + ) + + +class ResponseTypeAuthorizationCode(ResponseTypeBase): + response_type: ResponseType = ResponseType.TYPE_CODE + + async def create_authorization_code_response( + self, request: Request + ) -> AuthorizationCodeResponse: + client = await super().create_authorization_code_response(request) + authorization_code = await self.db.create_authorization_code( + request, + client.client_id, + request.query.scope, + request.query.response_type, # type: ignore + request.query.redirect_uri, + request.query.code_challenge_method, # type: ignore + request.query.code_challenge, # type: ignore + ) + return AuthorizationCodeResponse( + code=authorization_code.code, scope=authorization_code.scope, + ) diff --git a/src/aioauth/responses.py b/src/aioauth/responses.py new file mode 100644 index 0000000..0ec982c --- /dev/null +++ b/src/aioauth/responses.py @@ -0,0 +1,86 @@ +from http import HTTPStatus +from typing import NamedTuple, Optional, Union + +from .constances import default_headers +from .structures import CaseInsensitiveDict +from .types import ErrorType + + +class ErrorResponse(NamedTuple): + """Response for error. + + Used by response_types. + Used by grant_types. + """ + + error: ErrorType + description: str + error_uri: str = "" + + +class AuthorizationCodeResponse(NamedTuple): + """Response for authorization_code. + + Used by response_types: + - ResponseTypeAuthorizationCode + """ + + code: str + scope: str + + +class TokenResponse(NamedTuple): + """Response for token. + + Used by grant_types. + Used by response_types: + - ResponseTypeToken + """ + + expires_in: int + refresh_token_expires_in: int + access_token: str + refresh_token: str + scope: str + token_type: str = "Bearer" + + +class TokenActiveIntrospectionResponse(NamedTuple): + """Response for a valid access token. + + Used by token introspection endpoint. + """ + + scope: str + client_id: str + exp: int + active: bool = True + + +class TokenInactiveIntrospectionResponse(NamedTuple): + """For an invalid, revoked or expired token. + + Used by token introspection endpoint. + """ + + active: bool = False + + +class Response(NamedTuple): + """General response class. + + Used by: + - Endpoint + """ + + content: Optional[ + Union[ + ErrorResponse, + TokenResponse, + AuthorizationCodeResponse, + TokenActiveIntrospectionResponse, + TokenInactiveIntrospectionResponse, + ] + ] = None + status_code: HTTPStatus = HTTPStatus.OK + headers: CaseInsensitiveDict = default_headers diff --git a/src/aioauth/structures.py b/src/aioauth/structures.py new file mode 100644 index 0000000..c9c541f --- /dev/null +++ b/src/aioauth/structures.py @@ -0,0 +1,11 @@ +from collections import UserDict + + +class CaseInsensitiveDict(UserDict): + """A case-insensitive ``dict``-like object.""" + + def __setitem__(self, key, value): + super().__setitem__(key.lower(), value) + + def __getitem__(self, key): + return super().__getitem__(key.lower()) diff --git a/src/aioauth/types.py b/src/aioauth/types.py new file mode 100644 index 0000000..a8713c0 --- /dev/null +++ b/src/aioauth/types.py @@ -0,0 +1,43 @@ +from enum import Enum + + +class ErrorType(str, Enum): + INVALID_REQUEST = "invalid_request" + INVALID_CLIENT = "invalid_client" + INVALID_GRANT = "invalid_grant" + INVALID_SCOPE = "invalid_scope" + UNAUTHORIZED_CLIENT = "unauthorized_client" + UNSUPPORTED_GRANT_TYPE = "unsupported_grant_type" + UNSUPPORTED_RESPONSE_TYPE = "unsupported_response_type" + INSECURE_TRANSPORT = "insecure_transport" + MISMATCHING_STATE = "mismatching_state" + METHOD_IS_NOT_ALLOWED = "method_is_not_allowed" + SERVER_ERROR = "server_error" + TEMPORARILY_UNAVAILABLE = "temporarily_unavailable" + + +class GrantType(str, Enum): + TYPE_AUTHORIZATION_CODE = "authorization_code" + TYPE_PASSWORD = "password" + TYPE_CLIENT_CREDENTIALS = "client_credentials" + TYPE_REFRESH_TOKEN = "refresh_token" + + +class ResponseType(str, Enum): + TYPE_TOKEN = "token" + TYPE_CODE = "code" + + +class EndpointType(str, Enum): + GRANT_TYPE = "grant_type" + RESPONSE_TYPE = "response_type" + + +class RequestMethod(str, Enum): + GET = "GET" + POST = "POST" + + +class CodeChallengeMethod(str, Enum): + PLAIN = "plain" + S256 = "S256" diff --git a/src/aioauth/utils.py b/src/aioauth/utils.py new file mode 100644 index 0000000..11d87b7 --- /dev/null +++ b/src/aioauth/utils.py @@ -0,0 +1,170 @@ +import base64 +import binascii +import functools +import hashlib +import logging +import random +import string +from base64 import b64decode, b64encode +from typing import Callable, Dict, List, Optional, Set, Text, Tuple, Union +from urllib.parse import quote, urlencode, urlparse, urlunsplit + +from .config import get_settings +from .errors import ( + InvalidClientError, + OAuth2Error, + ServerError, + TemporarilyUnavailableError, +) +from .responses import ErrorResponse, Response +from .structures import CaseInsensitiveDict + +UNICODE_ASCII_CHARACTER_SET = string.ascii_letters + string.digits + + +log = logging.getLogger(__name__) + + +def is_secure_transport(uri: str) -> bool: + """Check if the uri is over ssl.""" + settings = get_settings() + + if settings.INSECURE_TRANSPORT: + return True + return uri.lower().startswith("https://") + + +def get_authorization_scheme_param( + authorization_header_value: Text, +) -> Tuple[Text, Text]: + if not authorization_header_value: + return "", "" + scheme, _, param = authorization_header_value.partition(" ") + return scheme, param + + +def list_to_scope(scope: Optional[List] = None) -> Text: + """Convert a list of scopes to a space separated string.""" + if isinstance(scope, str) or scope is None: + return "" + elif isinstance(scope, (set, tuple, list)): + return " ".join([str(s) for s in scope]) + else: + raise ValueError( + "Invalid scope (%s), must be string, tuple, set, or list." % scope + ) + + +def scope_to_list(scope: Union[Text, List, Set, Tuple]) -> List: + """Convert a space separated string to a list of scopes.""" + if isinstance(scope, (tuple, list, set)): + return [str(s) for s in scope] + elif scope is None: + return [] + else: + return scope.strip().split(" ") + + +def generate_token(length: int = 30, chars: str = UNICODE_ASCII_CHARACTER_SET) -> str: + """Generates a non-guessable OAuth token + + OAuth (1 and 2) does not specify the format of tokens except that they + should be strings of random characters. Tokens should not be guessable + and entropy when generating the random characters is important. Which is + why SystemRandom is used instead of the default random.choice method. + """ + rand = random.SystemRandom() + return "".join(rand.choice(chars) for _ in range(length)) + + +def build_uri( + url: str, query_params: Optional[Dict] = None, fragment: Optional[Dict] = None +) -> str: + """Build uri string from given url, query_params and fragment""" + if query_params is None: + query_params = {} + + if fragment is None: + fragment = {} + + parsed_url = urlparse(url) + uri = urlunsplit( + ( + parsed_url.scheme, + parsed_url.netloc, + parsed_url.path, + urlencode(query_params, quote_via=quote), + urlencode(fragment, quote_via=quote), + ) + ) + return uri + + +def encode_auth_headers(client_id: str, client_secret: str) -> CaseInsensitiveDict: + authorization = b64encode(f"{client_id}:{client_secret}".encode("ascii")) + return CaseInsensitiveDict(Authorization=f"basic {authorization.decode()}") + + +def decode_auth_headers(authorization: str) -> Tuple[str, str]: + """Decode an encrypted HTTP basic authentication string. Returns a tuple of + the form (client_id, client_secret), and raises a InvalidClientError exception if + nothing could be decoded. + """ + headers = CaseInsensitiveDict({"WWW-Authenticate": "Basic"}) + + scheme, param = get_authorization_scheme_param(authorization) + if not authorization or scheme.lower() != "basic": + raise InvalidClientError(headers=headers) + + try: + data = b64decode(param).decode("ascii") + except (ValueError, UnicodeDecodeError, binascii.Error): + raise InvalidClientError(headers=headers) + + client_id, separator, client_secret = data.partition(":") + + if not separator: + raise InvalidClientError(headers=headers) + + return client_id, client_secret + + +def create_s256_code_challenge(code_verifier: str) -> str: + """Create S256 code_challenge with the given code_verifier. + + Implements: + base64url(sha256(ascii(code_verifier))) + """ + code_verifier_bytes = code_verifier.encode("utf-8") + data = hashlib.sha256(code_verifier_bytes).digest() + return base64.urlsafe_b64encode(data).rstrip(b"=").decode() + + +def catch_errors_and_unavailability(f) -> Callable: + @functools.wraps(f) + async def wrapper(endpoint, *args, **kwargs) -> Optional[Response]: + if not endpoint.available: + error = TemporarilyUnavailableError() + content = ErrorResponse(error=error.error, description=error.description) + return Response( + content=content, status_code=error.status_code, headers=error.headers + ) + + try: + response = await f(endpoint, *args, **kwargs) + return response + except OAuth2Error as exc: + content = ErrorResponse(error=exc.error, description=exc.description) + log.debug(exc) + return Response( + content=content, status_code=exc.status_code, headers=exc.headers + ) + except Exception: + error = ServerError() + log.exception("Exception caught while processing request.") + content = ErrorResponse(error=error.error, description=error.description) + return Response( + content=content, status_code=error.status_code, headers=error.headers, + ) + + return wrapper diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/classes.py b/tests/classes.py new file mode 100644 index 0000000..b57c005 --- /dev/null +++ b/tests/classes.py @@ -0,0 +1,130 @@ +from typing import Dict, List, Optional + +from aioauth.base.database import BaseDB +from aioauth.models import AuthorizationCode, Client, Token +from aioauth.requests import Request +from aioauth.types import CodeChallengeMethod, ResponseType +from tests.utils import set_values + +from .models import Defaults + + +class DB(BaseDB): + storage: Dict[str, List] + defaults: Defaults + + def _get_by_client_secret(self, client_id: str, client_secret: str): + clients: List[Client] = self.storage.get("clients", []) + + for client in clients: + if client.client_id == client_id and client.client_secret == client_secret: + return client + + def _get_by_client_id(self, client_id: str): + clients: List[Client] = self.storage.get("clients", []) + + for client in clients: + if client.client_id == client_id: + return client + + async def get_client( + self, request: Request, client_id: str, client_secret: Optional[str] = None + ) -> Optional[Client]: + if client_secret is not None: + return self._get_by_client_secret(client_id, client_secret) + + return self._get_by_client_id(client_id) + + async def create_token(self, request: Request, client_id: str, scope: str) -> Token: + token = await super().create_token(request, client_id, scope) + self.storage["tokens"].append(token) + return token + + async def revoke_token(self, request: Request, token: str) -> None: + tokens: List[Token] = self.storage.get("tokens", []) + for key, token_ in enumerate(tokens): + if token_.refresh_token == token: + tokens[key] = set_values(token_, {"revoked": True}) + + async def get_token( + self, + request: Request, + client_id: str, + token: Optional[str] = None, + refresh_token: Optional[str] = None, + ) -> Optional[Token]: + tokens: List[Token] = self.storage.get("tokens", []) + for token_ in tokens: + if ( + refresh_token is not None + and refresh_token == token_.refresh_token + and client_id == token_.client_id + ): + return token_ + if ( + token is not None + and token == token_.access_token + and client_id == token_.client_id + ): + return token_ + + async def authenticate(self, request: Request) -> Optional[bool]: + if ( + request.post.username == self.defaults.username + and request.post.password == self.defaults.password + ): + return True + + async def create_authorization_code( + self, + request: Request, + client_id: str, + scope: str, + response_type: ResponseType, + redirect_uri: str, + code_challenge_method: CodeChallengeMethod, + code_challenge: str, + ) -> AuthorizationCode: + authorization_code = await super().create_authorization_code( + request, + client_id, + scope, + response_type, + redirect_uri, + code_challenge_method, + code_challenge, + ) + self.storage["authorization_codes"].append(authorization_code) + return authorization_code + + async def get_authorization_code( + self, request: Request, client_id: str, code: str + ) -> Optional[AuthorizationCode]: + authorization_codes: List[AuthorizationCode] = self.storage.get( + "authorization_codes", [] + ) + for authorization_code in authorization_codes: + if ( + authorization_code.code == code + and authorization_code.client_id == client_id + ): + return authorization_code + + async def delete_authorization_code( + self, request: Request, client_id: str, code: str, + ): + authorization_codes: List[AuthorizationCode] = self.storage.get( + "authorization_codes", [] + ) + for authorization_code in authorization_codes: + if ( + authorization_code.client_id == client_id + and authorization_code.code == code + ): + authorization_codes.remove(authorization_code) + + +def get_db_class(defaults: Defaults, storage: Dict[str, List]): + DB.storage = storage + DB.defaults = defaults + return DB diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0eda69b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,126 @@ +import time +from typing import Dict, Type + +import pytest +from aioauth.base.database import BaseDB +from aioauth.config import get_settings +from aioauth.endpoints import Endpoint +from aioauth.grant_type import ( + AuthorizationCodeGrantType, + ClientCredentialsGrantType, + PasswordGrantType, + RefreshTokenGrantType, +) +from aioauth.models import AuthorizationCode, Client, ClientMetadata, Token +from aioauth.response_type import ResponseTypeAuthorizationCode, ResponseTypeToken +from aioauth.types import CodeChallengeMethod, EndpointType, GrantType, ResponseType +from aioauth.utils import generate_token + +from .classes import get_db_class +from .models import Defaults + + +@pytest.fixture +def defaults() -> Defaults: + return Defaults( + client_id=generate_token(48), + client_secret=generate_token(48), + code=generate_token(5), + refresh_token=generate_token(48), + access_token=generate_token(42), + username="root", + password="toor", + redirect_uri="https://ownauth.com/callback", + scope="read write", + ) + + +@pytest.fixture +def storage(defaults: Defaults) -> Dict: + settings = get_settings() + + client_metadata = ClientMetadata( + grant_types=[ + GrantType.TYPE_AUTHORIZATION_CODE, + GrantType.TYPE_CLIENT_CREDENTIALS, + GrantType.TYPE_REFRESH_TOKEN, + GrantType.TYPE_PASSWORD, + ], + redirect_uris=[defaults.redirect_uri], + response_types=[ResponseType.TYPE_CODE, ResponseType.TYPE_TOKEN], + scope=defaults.scope, + ) + + client = Client( + client_id=defaults.client_id, + client_secret=defaults.client_secret, + client_metadata=client_metadata, + ) + + authorization_code = AuthorizationCode( + code=defaults.code, + client_id=defaults.client_id, + response_type=ResponseType.TYPE_CODE, + auth_time=int(time.time()), + redirect_uri=defaults.redirect_uri, + scope=defaults.scope, + code_challenge_method=CodeChallengeMethod.PLAIN, + ) + + token = Token( + client_id=defaults.client_id, + expires_in=settings.TOKEN_EXPIRES_IN, + access_token=defaults.access_token, + refresh_token=defaults.refresh_token, + issued_at=int(time.time()), + scope=defaults.scope, + ) + + return { + "tokens": [token], + "authorization_codes": [authorization_code], + "clients": [client], + } + + +@pytest.fixture +def db_class(defaults: Defaults, storage) -> Type[BaseDB]: + return get_db_class(defaults, storage) + + +@pytest.fixture +def db(db_class: Type[BaseDB]): + return db_class() + + +@pytest.fixture +def endpoint(db: BaseDB) -> Endpoint: + endpoint = Endpoint(db=db) + # Register response type endpoints + endpoint.register( + EndpointType.RESPONSE_TYPE, ResponseType.TYPE_TOKEN, ResponseTypeToken, + ) + endpoint.register( + EndpointType.RESPONSE_TYPE, + ResponseType.TYPE_CODE, + ResponseTypeAuthorizationCode, + ) + + # Register grant type endpoints + endpoint.register( + EndpointType.GRANT_TYPE, + GrantType.TYPE_AUTHORIZATION_CODE, + AuthorizationCodeGrantType, + ) + endpoint.register( + EndpointType.GRANT_TYPE, + GrantType.TYPE_CLIENT_CREDENTIALS, + ClientCredentialsGrantType, + ) + endpoint.register( + EndpointType.GRANT_TYPE, GrantType.TYPE_PASSWORD, PasswordGrantType, + ) + endpoint.register( + EndpointType.GRANT_TYPE, GrantType.TYPE_REFRESH_TOKEN, RefreshTokenGrantType, + ) + return endpoint diff --git a/tests/models.py b/tests/models.py new file mode 100644 index 0000000..788188d --- /dev/null +++ b/tests/models.py @@ -0,0 +1,13 @@ +from typing import NamedTuple, Text + + +class Defaults(NamedTuple): + client_id: Text + client_secret: Text + code: Text + refresh_token: Text + access_token: Text + username: Text + password: Text + redirect_uri: Text + scope: Text diff --git a/tests/test_db.py b/tests/test_db.py new file mode 100644 index 0000000..76683de --- /dev/null +++ b/tests/test_db.py @@ -0,0 +1,42 @@ +from typing import Dict, List + +import pytest +from aioauth.base.database import BaseDB +from aioauth.models import AuthorizationCode, Client, Token +from aioauth.requests import Request +from aioauth.types import RequestMethod + + +@pytest.mark.asyncio +async def test_db(storage: Dict[str, List]): + db = BaseDB() + request = Request(method=RequestMethod.POST) + client: Client = storage["clients"][0] + token: Token = storage["tokens"][0] + authorization_code: AuthorizationCode = storage["authorization_codes"][0] + + with pytest.raises(NotImplementedError): + await db.get_token( + request=request, + client_id=client.client_id, + token=token.access_token, + refresh_token=token.refresh_token, + ) + with pytest.raises(NotImplementedError): + await db.get_client( + request=request, + client_id=client.client_id, + client_secret=client.client_secret, + ) + with pytest.raises(NotImplementedError): + await db.authenticate(request=request) + with pytest.raises(NotImplementedError): + await db.get_authorization_code( + request=request, client_id=client.client_id, code=authorization_code.code + ) + with pytest.raises(NotImplementedError): + await db.delete_authorization_code( + request=request, client_id=client.client_id, code=authorization_code.code + ) + with pytest.raises(NotImplementedError): + await db.revoke_token(request=request, token=token.access_token) diff --git a/tests/test_endpoint.py b/tests/test_endpoint.py new file mode 100644 index 0000000..5f182cc --- /dev/null +++ b/tests/test_endpoint.py @@ -0,0 +1,123 @@ +import time +from http import HTTPStatus +from typing import Dict, List, Optional, Type + +import pytest +from aioauth.base.database import BaseDB +from aioauth.config import get_settings +from aioauth.endpoints import Endpoint +from aioauth.models import Token +from aioauth.requests import Post, Request +from aioauth.types import EndpointType, ErrorType, GrantType, RequestMethod +from aioauth.utils import ( + catch_errors_and_unavailability, + encode_auth_headers, + generate_token, +) + +from .models import Defaults + + +@pytest.mark.asyncio +async def test_internal_server_error(): + class EndpointClass: + available: Optional[bool] = True + + def __init__(self, available: Optional[bool] = None): + if available is not None: + self.available = available + + @catch_errors_and_unavailability + async def endpoint(self): + raise Exception() + + e = EndpointClass() + response = await e.endpoint() + assert response.status_code == HTTPStatus.BAD_REQUEST + + +@pytest.mark.asyncio +async def test_invalid_token(endpoint: Endpoint, defaults: Defaults): + client_id = defaults.client_id + client_secret = defaults.client_secret + request_url = "https://localhost" + token = "invalid token" + + post = Post(token=token) + request = Request( + url=request_url, + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + response = await endpoint.create_token_introspection_response(request) + assert not response.content.active + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_expired_token( + endpoint: Endpoint, storage: Dict[str, List], defaults: Defaults +): + settings = get_settings() + token = Token( + client_id=defaults.client_id, + expires_in=settings.TOKEN_EXPIRES_IN, + access_token=generate_token(42), + refresh_token=generate_token(48), + issued_at=int(time.time() - settings.TOKEN_EXPIRES_IN), + scope=defaults.scope, + ) + + client_id = defaults.client_id + client_secret = defaults.client_secret + + storage["tokens"].append(token) + + post = Post(token=token.access_token) + request = Request( + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + + response = await endpoint.create_token_introspection_response(request) + assert response.status_code == HTTPStatus.OK + assert not response.content.active + + +@pytest.mark.asyncio +async def test_valid_token( + endpoint: Endpoint, storage: Dict[str, List], defaults: Defaults +): + client_id = defaults.client_id + client_secret = defaults.client_secret + + token = storage["tokens"][0] + + post = Post(token=token.access_token) + request = Request( + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + + response = await endpoint.create_token_introspection_response(request) + assert response.status_code == HTTPStatus.OK + assert response.content.active + + +@pytest.mark.asyncio +async def test_unregister_endpoint(endpoint: Endpoint): + assert endpoint.grant_type.get(GrantType.TYPE_AUTHORIZATION_CODE) is not None + endpoint.unregister(EndpointType.GRANT_TYPE, GrantType.TYPE_AUTHORIZATION_CODE) + assert endpoint.grant_type.get(GrantType.TYPE_AUTHORIZATION_CODE) is None + + +@pytest.mark.asyncio +async def test_endpoint_availability(db_class: Type[BaseDB]): + endpoint = Endpoint(db=db_class(), available=False) + request = Request(method=RequestMethod.POST) + response = await endpoint.create_token_introspection_response(request) + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content.error == ErrorType.TEMPORARILY_UNAVAILABLE diff --git a/tests/test_flow.py b/tests/test_flow.py new file mode 100644 index 0000000..b00b8c4 --- /dev/null +++ b/tests/test_flow.py @@ -0,0 +1,279 @@ +from http import HTTPStatus +from urllib.parse import parse_qsl, urlparse + +import pytest +from aioauth.base.database import BaseDB +from aioauth.constances import default_headers +from aioauth.endpoints import Endpoint +from aioauth.requests import Post, Query, Request +from aioauth.types import CodeChallengeMethod, GrantType, RequestMethod, ResponseType +from aioauth.utils import ( + create_s256_code_challenge, + encode_auth_headers, + generate_token, +) + +from .conftest import Defaults +from .utils import check_request_validators + + +@pytest.mark.asyncio +async def test_authorization_code_flow_plan_code_challenge( + endpoint: Endpoint, defaults: Defaults, db: BaseDB +): + code_challenge = generate_token(128) + client_id = defaults.client_id + client_secret = defaults.client_secret + scope = defaults.scope + state = generate_token(10) + redirect_uri = defaults.redirect_uri + request_url = "https://localhost" + user = "username" + + query = Query( + client_id=defaults.client_id, + response_type=ResponseType.TYPE_CODE, + redirect_uri=redirect_uri, + scope=scope, + state=state, + code_challenge_method=CodeChallengeMethod.PLAIN, + code_challenge=code_challenge, + ) + + request = Request( + url=request_url, query=query, method=RequestMethod.GET, user=user, + ) + + await check_request_validators(request, endpoint.create_authorization_code_response) + response = await endpoint.create_authorization_code_response(request) + assert response.status_code == HTTPStatus.FOUND + location = response.headers["location"] + location = urlparse(location) + query = dict(parse_qsl(location.query)) + assert query["state"] == state + assert query["scope"] == scope + assert await db.get_authorization_code(request, client_id, query["code"]) + assert "code" in query + + location = response.headers["location"] + location = urlparse(location) + query = dict(parse_qsl(location.query)) + code = query["code"] + + post = Post( + grant_type=GrantType.TYPE_AUTHORIZATION_CODE, + redirect_uri=defaults.redirect_uri, + code=code, + code_verifier=code_challenge, + scope=scope, + ) + + request = Request( + url=request_url, + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + + response = await endpoint.create_token_response(request) + assert response.status_code == HTTPStatus.OK + assert response.headers == default_headers + assert response.content.scope == scope + assert response.content.token_type == "Bearer" + # Check that token was created in db + assert await db.get_token( + request, + client_id, + response.content.access_token, + response.content.refresh_token, + ) + + access_token = response.content.access_token + refresh_token = response.content.refresh_token + + post = Post(grant_type=GrantType.TYPE_REFRESH_TOKEN, refresh_token=refresh_token,) + + request = Request( + url=request_url, + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + await check_request_validators(request, endpoint.create_token_response) + response = await endpoint.create_token_response(request) + + assert response.status_code == HTTPStatus.OK + assert response.content.access_token != access_token + assert response.content.refresh_token != refresh_token + # Check that token was created in db + assert await db.get_token( + request, + client_id, + response.content.access_token, + response.content.refresh_token, + ) + # Check that previous token was revoken + token_in_db = await db.get_token(request, client_id, access_token, refresh_token) + assert token_in_db.revoked + + +@pytest.mark.asyncio +async def test_authorization_code_flow_pkce_code_challenge( + endpoint: Endpoint, defaults: Defaults, db: BaseDB +): + client_id = defaults.client_id + client_secret = defaults.client_secret + code_verifier = generate_token(128) + scope = defaults.scope + code_challenge = create_s256_code_challenge(code_verifier) + request_url = "https://localhost" + user = "username" + state = generate_token(10) + + query = Query( + client_id=defaults.client_id, + response_type=ResponseType.TYPE_CODE, + redirect_uri=defaults.redirect_uri, + scope=scope, + state=state, + code_challenge_method=CodeChallengeMethod.S256, + code_challenge=code_challenge, + ) + + request = Request( + url=request_url, query=query, method=RequestMethod.GET, user=user, + ) + response = await endpoint.create_authorization_code_response(request) + assert response.status_code == HTTPStatus.FOUND + location = response.headers["location"] + location = urlparse(location) + query = dict(parse_qsl(location.query)) + assert query["state"] == state + assert query["scope"] == scope + assert "code" in query + code = query["code"] + + post = Post( + grant_type=GrantType.TYPE_AUTHORIZATION_CODE, + redirect_uri=defaults.redirect_uri, + code=code, + scope=scope, + code_verifier=code_verifier, + ) + + request = Request( + url=request_url, + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + + await check_request_validators(request, endpoint.create_token_response) + + code_record = await db.get_authorization_code(request, client_id, code) + assert code_record + + response = await endpoint.create_token_response(request) + assert response.status_code == HTTPStatus.OK + assert response.headers == default_headers + assert response.content.scope == scope + assert response.content.token_type == "Bearer" + + code_record = await db.get_authorization_code(request, client_id, code) + assert not code_record + + +@pytest.mark.asyncio +async def test_implicit_flow(endpoint: Endpoint, defaults: Defaults): + request_url = "https://localhost" + state = generate_token(10) + scope = defaults.scope + user = "username" + + query = Query( + client_id=defaults.client_id, + response_type=ResponseType.TYPE_TOKEN, + redirect_uri=defaults.redirect_uri, + scope=scope, + state=state, + ) + + request = Request( + url=request_url, query=query, method=RequestMethod.GET, user=user, + ) + + response = await endpoint.create_authorization_code_response(request) + assert response.status_code == HTTPStatus.FOUND + location = response.headers["location"] + location = urlparse(location) + fragment = dict(parse_qsl(location.fragment)) + assert fragment["state"] == state + assert fragment["scope"] == scope + + +@pytest.mark.asyncio +async def test_password_grant_type(endpoint: Endpoint, defaults: Defaults): + client_id = defaults.client_id + client_secret = defaults.client_secret + request_url = "https://localhost" + + post = Post( + grant_type=GrantType.TYPE_PASSWORD, + username=defaults.username, + password=defaults.password, + ) + + request = Request( + post=post, + url=request_url, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + + await check_request_validators(request, endpoint.create_token_response) + response = await endpoint.create_token_response(request) + assert response.status_code == HTTPStatus.OK + + +@pytest.mark.asyncio +async def test_authorization_code_flow(endpoint: Endpoint, defaults: Defaults): + client_id = defaults.client_id + client_secret = defaults.client_secret + request_url = "https://localhost" + user = "username" + + query = Query( + client_id=defaults.client_id, + response_type=ResponseType.TYPE_CODE, + redirect_uri=defaults.redirect_uri, + scope=defaults.scope, + state=generate_token(10), + ) + + request = Request( + url=request_url, query=query, method=RequestMethod.GET, user=user, + ) + + response = await endpoint.create_authorization_code_response(request) + assert response.status_code == HTTPStatus.FOUND + + location = response.headers["location"] + location = urlparse(location) + query = dict(parse_qsl(location.query)) + code = query["code"] + + post = Post( + grant_type=GrantType.TYPE_AUTHORIZATION_CODE, + redirect_uri=defaults.redirect_uri, + code=code, + ) + + request = Request( + url=request_url, + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + + response = await endpoint.create_token_response(request) + assert response.status_code == HTTPStatus.OK diff --git a/tests/test_request_validator.py b/tests/test_request_validator.py new file mode 100644 index 0000000..61d9f0a --- /dev/null +++ b/tests/test_request_validator.py @@ -0,0 +1,236 @@ +import time +from http import HTTPStatus +from typing import Dict, List + +import pytest +from aioauth.config import get_settings +from aioauth.endpoints import Endpoint +from aioauth.models import Client +from aioauth.requests import Post, Query, Request +from aioauth.types import ( + CodeChallengeMethod, + ErrorType, + GrantType, + RequestMethod, + ResponseType, +) +from aioauth.utils import ( + create_s256_code_challenge, + encode_auth_headers, + generate_token, +) + +from .models import Defaults +from .utils import set_values + + +@pytest.mark.asyncio +async def test_insecure_transport_error(endpoint: Endpoint): + request_url = "http://localhost" + + request = Request(url=request_url, method=RequestMethod.GET,) + + response = await endpoint.create_authorization_code_response(request) + assert response.status_code == HTTPStatus.BAD_REQUEST + + +@pytest.mark.asyncio +async def test_allowed_methods(endpoint: Endpoint): + request_url = "https://localhost" + + request = Request(url=request_url, method=RequestMethod.POST,) + + response = await endpoint.create_authorization_code_response(request) + assert response.status_code == HTTPStatus.METHOD_NOT_ALLOWED + + +@pytest.mark.asyncio +async def test_invalid_client_credentials(endpoint: Endpoint, defaults: Defaults): + client_id = defaults.client_id + request_url = "https://localhost" + + post = Post( + grant_type=GrantType.TYPE_PASSWORD, + username=defaults.username, + password=defaults.password, + ) + + request = Request( + post=post, + url=request_url, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, "client_secret"), + ) + + response = await endpoint.create_token_response(request) + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content.error == ErrorType.INVALID_REQUEST + + +@pytest.mark.asyncio +async def test_invalid_scope(endpoint: Endpoint, defaults: Defaults): + client_id = defaults.client_id + client_secret = defaults.client_secret + request_url = "https://localhost" + + post = Post( + grant_type=GrantType.TYPE_PASSWORD, + username=defaults.username, + password=defaults.password, + scope="test test", + ) + + request = Request( + post=post, + url=request_url, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + + response = await endpoint.create_token_response(request) + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content.error == ErrorType.INVALID_SCOPE + + +@pytest.mark.asyncio +async def test_invalid_grant_type(endpoint: Endpoint, defaults: Defaults, storage): + client: Client = storage["clients"][0] + + client_metadata = set_values( + client.client_metadata, {"grant_types": [GrantType.TYPE_AUTHORIZATION_CODE]} + ) + + client = set_values(client, {"client_metadata": client_metadata}) + + storage["clients"][0] = client + + client_id = defaults.client_id + client_secret = defaults.client_secret + request_url = "https://localhost" + + post = Post( + grant_type=GrantType.TYPE_PASSWORD, + username=defaults.username, + password=defaults.password, + scope="test test", + ) + + request = Request( + post=post, + url=request_url, + method=RequestMethod.POST, + headers=encode_auth_headers(client_id, client_secret), + ) + + response = await endpoint.create_token_response(request) + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content.error == ErrorType.UNAUTHORIZED_CLIENT + + +@pytest.mark.asyncio +async def test_invalid_response_type(endpoint: Endpoint, defaults: Defaults, storage): + code_verifier = generate_token(128) + code_challenge = create_s256_code_challenge(code_verifier) + request_url = "https://localhost" + user = "username" + + client = storage["clients"][0] + + client_metadata = set_values( + client.client_metadata, {"response_types": [ResponseType.TYPE_TOKEN]} + ) + client = set_values(client, {"client_metadata": client_metadata}) + + storage["clients"][0] = client + + query = Query( + client_id=defaults.client_id, + response_type=ResponseType.TYPE_CODE, + redirect_uri=defaults.redirect_uri, + scope=defaults.scope, + state=generate_token(10), + code_challenge_method=CodeChallengeMethod.S256, + code_challenge=code_challenge, + ) + + request = Request( + url=request_url, query=query, method=RequestMethod.GET, user=user, + ) + response = await endpoint.create_authorization_code_response(request) + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content.error == ErrorType.UNSUPPORTED_RESPONSE_TYPE + + +@pytest.mark.asyncio +async def test_anonymous_user(endpoint: Endpoint, defaults: Defaults, storage): + code_verifier = generate_token(128) + code_challenge = create_s256_code_challenge(code_verifier) + request_url = "https://localhost" + + query = Query( + client_id=defaults.client_id, + response_type=ResponseType.TYPE_CODE, + redirect_uri=defaults.redirect_uri, + scope=defaults.scope, + state=generate_token(10), + code_challenge_method=CodeChallengeMethod.S256, + code_challenge=code_challenge, + ) + + request = Request(url=request_url, query=query, method=RequestMethod.GET) + response = await endpoint.create_authorization_code_response(request) + assert response.status_code == HTTPStatus.UNAUTHORIZED + assert response.content.error == ErrorType.INVALID_CLIENT + + +@pytest.mark.asyncio +async def test_expired_authorization_code( + endpoint: Endpoint, defaults: Defaults, storage: Dict[str, List] +): + request_url = "https://localhost" + + settings = get_settings() + + authorization_code = storage["authorization_codes"][0] + storage["authorization_codes"][0] = set_values( + authorization_code, + {"auth_time": time.time() - settings.AUTHORIZATION_CODE_EXPIRES_IN}, + ) + post = Post( + grant_type=GrantType.TYPE_AUTHORIZATION_CODE, + redirect_uri=defaults.redirect_uri, + code=storage["authorization_codes"][0].code, + ) + + request = Request( + url=request_url, + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers(defaults.client_id, defaults.client_secret), + ) + response = await endpoint.create_token_response(request) + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content.error == ErrorType.INVALID_GRANT + + +@pytest.mark.asyncio +async def test_expired_refresh_token( + endpoint: Endpoint, defaults: Defaults, storage: Dict[str, List] +): + settings = get_settings() + token = storage["tokens"][0] + refresh_token = token.refresh_token + storage["tokens"][0] = set_values( + token, {"issued_at": time.time() - (settings.TOKEN_EXPIRES_IN * 2)} + ) + request_url = "https://localhost" + post = Post(grant_type=GrantType.TYPE_REFRESH_TOKEN, refresh_token=refresh_token,) + request = Request( + url=request_url, + post=post, + method=RequestMethod.POST, + headers=encode_auth_headers(defaults.client_id, defaults.client_secret), + ) + response = await endpoint.create_token_response(request) + assert response.status_code == HTTPStatus.BAD_REQUEST + assert response.content.error == ErrorType.INVALID_GRANT diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..66b3c3c --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,61 @@ +from urllib.parse import urljoin + +import pytest +from aioauth.errors import InvalidClientError +from aioauth.utils import ( + build_uri, + decode_auth_headers, + get_authorization_scheme_param, + is_secure_transport, + list_to_scope, + scope_to_list, +) + + +def test_is_secure_transport(monkeypatch): + monkeypatch.setenv("AIOAUTH_INSECURE_TRANSPORT", "1") + + is_secure = is_secure_transport("https://google.com") + assert is_secure + + is_secure = is_secure_transport("http://google.com") + assert is_secure + + +def test_get_authorization_scheme_param(): + assert get_authorization_scheme_param("") == ("", "") + + +def test_list_to_scope(): + assert list_to_scope("") == "" # type: ignore + assert list_to_scope(["read", "write"]) == "read write" + with pytest.raises(ValueError): + list_to_scope(1) # type: ignore + + +def test_scope_to_list(): + assert scope_to_list("read write") == ["read", "write"] + assert scope_to_list(["read", "write"]) == ["read", "write"] + assert scope_to_list(None) == [] # type: ignore + + +def test_build_uri(): + build_uri("https://google.com") == "https://google.com" + + +def test_decode_auth_headers(): + with pytest.raises(InvalidClientError): + decode_auth_headers("") + + with pytest.raises(InvalidClientError): + decode_auth_headers("authorization") + + +def test_base_error_uri(monkeypatch): + ERROR_URI = "https://google.com" + monkeypatch.setenv("AIOAUTH_ERROR_URI", ERROR_URI) + + try: + raise InvalidClientError() + except InvalidClientError as exc: + assert urljoin(ERROR_URI, exc.error) == exc.error_uri diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 0000000..aa383b0 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,235 @@ +from http import HTTPStatus +from typing import Any, Callable, Dict, Union + +from aioauth.constances import default_headers +from aioauth.requests import Post, Query, Request +from aioauth.responses import ErrorResponse, Response +from aioauth.types import ErrorType, RequestMethod + +EMPTY_KEYS = { + RequestMethod.GET: { + "client_id": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, + description="Missing client_id parameter.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "response_type": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, + description="Missing response_type parameter.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "redirect_uri": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, + description="Mismatching redirect URI.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "code_challenge": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, description="Code challenge required.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + }, + RequestMethod.POST: { + "grant_type": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, + description="Request is missing grant type.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "redirect_uri": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, + description="Mismatching redirect URI.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "code": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, description="Missing code parameter.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "refresh_token": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, + description="Missing refresh token parameter.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "code_verifier": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, description="Code verifier required.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "username": Response( + content=ErrorResponse( + error=ErrorType.INVALID_GRANT, description="Invalid credentials given.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "password": Response( + content=ErrorResponse( + error=ErrorType.INVALID_GRANT, description="Invalid credentials given.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + }, +} + +INVALID_KEYS = { + RequestMethod.GET: { + "client_id": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, + description="Invalid client_id parameter value.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "response_type": Response( + content=ErrorResponse( + error=ErrorType.UNSUPPORTED_RESPONSE_TYPE, description="", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "redirect_uri": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, description="Invalid redirect URI.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "code_challenge_method": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, + description="Transform algorithm not supported.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "scope": Response( + content=ErrorResponse(error=ErrorType.INVALID_SCOPE, description="",), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + }, + RequestMethod.POST: { + "grant_type": Response( + content=ErrorResponse( + error=ErrorType.UNSUPPORTED_GRANT_TYPE, description="", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "redirect_uri": Response( + content=ErrorResponse( + error=ErrorType.INVALID_REQUEST, description="Invalid redirect URI.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "code": Response( + content=ErrorResponse(error=ErrorType.INVALID_GRANT, description="",), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "code_verifier": Response( + content=ErrorResponse( + error=ErrorType.MISMATCHING_STATE, + description="CSRF Warning! State not equal in request and response.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "refresh_token": Response( + content=ErrorResponse(error=ErrorType.INVALID_GRANT, description="",), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "username": Response( + content=ErrorResponse( + error=ErrorType.INVALID_GRANT, description="Invalid credentials given.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + "password": Response( + content=ErrorResponse( + error=ErrorType.INVALID_GRANT, description="Invalid credentials given.", + ), + status_code=HTTPStatus.BAD_REQUEST, + headers=default_headers, + ), + }, +} + + +def get_keys(query: Union[Query, Post]) -> Dict[str, Any]: + """Converts dataclass object to dict and returns dict without empty values""" + return {key: value for key, value in query._asdict().items() if bool(value)} + + +def set_values(model, values): + """Sets NamedTuple instance value and returns new NamedTuple""" + return model.__class__(**{**model._asdict(), **values}) + + +async def check_query_values( + request: Request, responses, query_dict: Dict, endpoint_func, value +): + keys = set(query_dict.keys()) & set(responses.keys()) + + for key in keys: + request_ = request + + if request_.method == RequestMethod.POST: + post = set_values(request_.post, {key: value}) + request_ = set_values(request_, {"post": post}) + + if request_.method == RequestMethod.GET: + query = set_values(request_.query, {key: value}) + request_ = set_values(request_, {"query": query}) + + response_expected = responses[key] + response_actual = await endpoint_func(request_) + + assert response_expected == response_actual + + +async def check_request_validators( + request: Request, endpoint_func: Callable, +): + query_dict = {} + + if request.method == RequestMethod.POST: + query_dict = get_keys(request.post) + + if request.method == RequestMethod.GET: + query_dict = get_keys(request.query) + + responses = EMPTY_KEYS[request.method] + await check_query_values(request, responses, query_dict, endpoint_func, None) + + responses = INVALID_KEYS[request.method] + await check_query_values(request, responses, query_dict, endpoint_func, "invalid") diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..1b76d3b --- /dev/null +++ b/tox.ini @@ -0,0 +1,19 @@ +[tox] +envlist = py36, py37, py38, flake8 + +[testenv:flake8] +basepython = python +deps = flake8 +commands = flake8 src/aioauth tests + +[testenv] +setenv = + PYTHONPATH = {toxinidir} +deps = + -r{toxinidir}/requirements/test.txt +; If you want to make tox run the tests with the same versions, create a +; requirements.txt with the pinned versions and uncomment the following line: +; -r{toxinidir}/requirements.txt +commands = + pip install -U pip + pytest --basetemp={envtmpdir} tests