From 6207afa3f754d21f9b8b63c04cf5e4b54960c859 Mon Sep 17 00:00:00 2001
From: Giovanni Barillari
Date: Mon, 14 Oct 2024 23:44:22 +0200
Subject: [PATCH] Support Emmett 2.6 (#25)

---
 .github/workflows/publish.yml      |  28 +-
 .github/workflows/tests.yml        |  24 +-
 .gitignore                         |   4 +-
 Makefile                           |  19 ++
 README.md                          |   3 -
 emmett_rest/__init__.py            |   2 +-
 emmett_rest/__version__.py         |   2 +-
 emmett_rest/ext.py                 | 117 ++++---
 emmett_rest/helpers.py             |  49 ++-
 emmett_rest/openapi/api.py         |  61 ++--
 emmett_rest/openapi/generation.py  | 485 ++++++++---------------------
 emmett_rest/openapi/helpers.py     |  10 +-
 emmett_rest/openapi/mod.py         |  42 +--
 emmett_rest/openapi/schemas.py     |  10 +-
 emmett_rest/parsers.py             |  31 +-
 emmett_rest/queries/errors.py      |  14 +-
 emmett_rest/queries/helpers.py     |  31 +-
 emmett_rest/queries/parser.py      | 171 ++++------
 emmett_rest/queries/validation.py  |  89 +++---
 emmett_rest/rest.py                | 428 +++++++------------------
 emmett_rest/serializers.py         |  14 +-
 emmett_rest/typing.py              |  10 +-
 emmett_rest/wrappers.py            |  26 +-
 pyproject.toml                     | 100 ++++--
 tests/conftest.py                  |  33 +-
 tests/test_endpoints.py            | 139 +++------
 tests/test_endpoints_additional.py |  56 ++--
 tests/test_envelopes.py            |  63 ++--
 tests/test_meta.py                 |  67 ++--
 tests/test_queries.py              | 348 +++++++--------------
 30 files changed, 914 insertions(+), 1562 deletions(-)
 create mode 100644 Makefile

diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index da55b56..234020a 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -7,20 +7,24 @@ on:
 jobs:
   publish:
     runs-on: ubuntu-latest
+    environment:
+      name: pypi
+      url: https://pypi.org/p/emmett-rest
+    permissions:
+      id-token: write
     steps:
-    - uses: actions/checkout@v3
+    - uses: actions/checkout@v4
     - name: Set up Python
-      uses: actions/setup-python@v4
+      uses: actions/setup-python@v5
       with:
-        python-version: '3.10'
-    - name: Install and configure Poetry
-      uses: gi0baro/setup-poetry-bin@v1
-      with:
-        virtualenvs-in-project: true
-    - name: Publish
+        python-version: 3.12
+    - name: Install uv
+      uses: astral-sh/setup-uv@v3
+    - name: Build distributions
       run: |
-        poetry build
-        poetry publish
-      env:
-        POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }}
+        uv build
+    - name: Publish package to pypi
+      uses: pypa/gh-action-pypi-publish@release/v1
+      with:
+        skip-existing: true
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index dba2e3f..bc57f94 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -3,17 +3,19 @@ name: Tests
 on:
   push:
     branches:
-      - "**"
-    tags-ignore:
-      - "**"
+      - "master"
   pull_request:
+    types: [opened, synchronize]
+    branches:
+      - master
 
 jobs:
   Linux:
     runs-on: ubuntu-latest
     strategy:
+      fail-fast: false
       matrix:
-        python-version: [3.8, 3.9, '3.10', '3.11']
+        python-version: [3.8, 3.9, '3.10', '3.11', '3.12', '3.13']
 
     services:
       postgres:
         image: postgres:14
         env:
           POSTGRES_USER: postgres
           POSTGRES_PASSWORD: postgres
           POSTGRES_DB: test
         ports:
           - 5432:5432
         options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
 
     steps:
-    - uses: actions/checkout@v3
+    - uses: actions/checkout@v4
     - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v4
+      uses: actions/setup-python@v5
       with:
         python-version: ${{ matrix.python-version }}
-    - name: Install and configure Poetry
-      uses: gi0baro/setup-poetry-bin@v1.3
-      with:
-        virtualenvs-in-project: true
+    - name: Install uv
+      uses: astral-sh/setup-uv@v3
     - name: Install dependencies
       run: |
-        poetry install -v
+        uv sync --dev
     - name: Test
       run: |
-        poetry run pytest -v tests
+        uv run pytest -v tests
diff --git a/.gitignore b/.gitignore
index f47107c..427f783 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,8 +8,8 @@ __pycache__
 build/*
 dist/*
-Emmett_REST.egg-info/*
-poetry.lock
+*.egg-info/*
+uv.lock
 tests/databases/*
 tests/logs/*
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..fc647d6
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,19 @@
+.DEFAULT_GOAL := all
+pysources = emmett_rest tests
+
+.PHONY: format
+format:
+	ruff check --fix $(pysources)
+	ruff format $(pysources)
+
+.PHONY: lint
+lint:
+	ruff check $(pysources)
+	ruff format --check $(pysources)
+
+.PHONY: test
+test:
+	pytest -v tests
+
+.PHONY: all
+all: format lint test
diff --git a/README.md b/README.md
index 545c464..57170cd 100644
--- a/README.md
+++ b/README.md
@@ -2,9 +2,6 @@
 
 Emmett-REST is a REST extension for [Emmett framework](https://emmett.sh).
 
-[![pip version](https://img.shields.io/pypi/v/emmett-rest.svg?style=flat)](https://pypi.python.org/pypi/Emmett-REST)
-![Tests Status](https://github.com/emmett-framework/rest/workflows/Tests/badge.svg)
-
 ## In a nutshell
 
 ```python
diff --git a/emmett_rest/__init__.py b/emmett_rest/__init__.py
index bd51e2a..2a74572 100644
--- a/emmett_rest/__init__.py
+++ b/emmett_rest/__init__.py
@@ -1,4 +1,4 @@
 from .ext import REST
+from .parsers import Parser, parse_params, parse_params_with_parser
 from .rest import RESTModule
-from .parsers import Parser, parse_params_with_parser, parse_params
 from .serializers import Serializer, serialize
diff --git a/emmett_rest/__version__.py b/emmett_rest/__version__.py
index 5197c5f..e4adfb8 100644
--- a/emmett_rest/__version__.py
+++ b/emmett_rest/__version__.py
@@ -1 +1 @@
-__version__ = "1.5.2"
+__version__ = "1.6.0"
diff --git a/emmett_rest/ext.py b/emmett_rest/ext.py
index e4e6dd0..bea7bb5 100644
--- a/emmett_rest/ext.py
+++ b/emmett_rest/ext.py
@@ -1,12 +1,12 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.ext
-    ---------------
+emmett_rest.ext
+---------------
 
-    Provides REST extension for Emmett
+Provides REST extension for Emmett
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from typing import Any, Dict, List, Optional, Type, Union
@@ -16,66 +16,55 @@
 from emmett.orm.models import MetaModel
 
 from .openapi.mod import OpenAPIModule
-from .rest import AppModule, RESTModule
 from .parsers import Parser
+from .rest import AppModule, RESTModule
 from .serializers import Serializer
-from .wrappers import (
-    wrap_method_on_obj,
-    wrap_module_from_app,
-    wrap_module_from_module,
-    wrap_module_from_modulegroup
-)
+from .wrappers import wrap_method_on_obj, wrap_module_from_app, wrap_module_from_module, wrap_module_from_modulegroup
 
 
 class REST(Extension):
-    default_config = dict(
-        default_module_class=RESTModule,
-        default_serializer=Serializer,
-        default_parser=Parser,
-        page_param='page',
-        pagesize_param='page_size',
-        sort_param='sort_by',
-        query_param='where',
-        min_pagesize=1,
-        max_pagesize=50,
-        default_pagesize=20,
-        default_sort=None,
-        base_path='/',
-        id_path='/',
-        list_envelope='data',
-        single_envelope=False,
-        groups_envelope='data',
-        use_envelope_on_parse=False,
-        serialize_meta=True,
-        meta_envelope='meta',
-        default_enabled_methods=[
-            'index', 'create', 'read', 'update', 'delete'
-        ],
-        default_disabled_methods=[],
-        use_save=True,
-        use_destroy=True
-    )
+    default_config = {
+        "default_module_class": RESTModule,
+        "default_serializer": Serializer,
+        "default_parser": Parser,
+        "page_param": "page",
+        "pagesize_param": "page_size",
+        "sort_param": "sort_by",
+        "query_param": "where",
+        "min_pagesize": 1,
+        "max_pagesize": 50,
+        "default_pagesize": 20,
+        "default_sort": None,
+        "base_path": "/",
+        "id_path": "/",
+        "list_envelope": "data",
+        "single_envelope": False,
+        "groups_envelope": "data",
+        "use_envelope_on_parse": False,
+        "serialize_meta": True,
+        "meta_envelope": "meta",
+        "default_enabled_methods": ["index", "create", "read", "update", "delete"],
+        "default_disabled_methods": [],
+        "use_save": True,
+        "use_destroy": True,
+    }
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        from .serializers import serialize
         from .parsers import parse_params
+        from .serializers import serialize
+
         self._serialize = serialize
         self._parse_params = parse_params
 
     @listen_signal(Signals.before_database)
     def _configure_models_attr(self):
-        MetaModel._inheritable_dict_attrs_.append(
-            ('rest_rw', {'id': (True, False)})
-        )
+        MetaModel._inheritable_dict_attrs_.append(("rest_rw", {"id": (True, False)}))
 
     def on_load(self):
-        setattr(AppModule, 'rest_module', wrap_module_from_module(self))
-        setattr(AppModuleGroup, 'rest_module', wrap_module_from_modulegroup(self))
-        self.app.rest_module = wrap_method_on_obj(
-            wrap_module_from_app(self),
-            self.app
-        )
+        AppModule.rest_module = wrap_module_from_module(self)
+        AppModuleGroup.rest_module = wrap_module_from_modulegroup(self)
+        self.app.rest_module = wrap_method_on_obj(wrap_module_from_app(self), self.app)
 
     @property
     def module(self):
@@ -109,7 +98,7 @@ def docs_module(
         url_prefix: Optional[str] = None,
         hostname: Optional[str] = None,
         module_class: Optional[Type[OpenAPIModule]] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ):
         module_class = module_class or OpenAPIModule
         return module_class.from_app(
@@ -127,19 +116,19 @@ def docs_module(
             pipeline=[],
             injectors=[],
             opts={
-                'title': title,
-                'version': version,
-                'modules_tree_prefix': modules_tree_prefix,
-                'description': description,
-                'tags': tags,
-                'servers': servers,
-                'terms_of_service': terms_of_service,
-                'contact': contact,
-                'license_info': license_info,
-                'security_schemes': security_schemes,
-                'produce_schemas': produce_schemas,
-                'expose_ui': expose_ui,
-                'ui_path': ui_path
+                "title": title,
+                "version": version,
+                "modules_tree_prefix": modules_tree_prefix,
+                "description": description,
+                "tags": tags,
+                "servers": servers,
+                "terms_of_service": terms_of_service,
+                "contact": contact,
+                "license_info": license_info,
+                "security_schemes": security_schemes,
+                "produce_schemas": produce_schemas,
+                "expose_ui": expose_ui,
+                "ui_path": ui_path,
             },
-            **kwargs
+            **kwargs,
         )
diff --git a/emmett_rest/helpers.py b/emmett_rest/helpers.py
index 18db1b2..324c431 100644
--- a/emmett_rest/helpers.py
+++ b/emmett_rest/helpers.py
@@ -1,21 +1,23 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.helpers
-    -------------------
+emmett_rest.helpers
+-------------------
 
-    Provides helpers
+Provides helpers
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from __future__ import annotations
+
 from typing import TYPE_CHECKING, TypeVar
 
 from emmett import request, response
 from emmett.pipeline import Pipe
 from emmett.routing.router import RoutingCtx
 
+
 if TYPE_CHECKING:
     from .rest import RESTModule
 
@@ -44,37 +46,33 @@ def __init__(self, mod):
 
 class SetFetcher(ModulePipe):
     async def pipe_request(self, next_pipe, **kwargs):
-        kwargs['dbset'] = self.mod._fetcher_method()
+        kwargs["dbset"] = self.mod._fetcher_method()
         return await next_pipe(**kwargs)
 
 
 class RecordFetcher(ModulePipe):
     async def pipe_request(self, next_pipe, **kwargs):
         self.fetch_record(kwargs)
-        if not kwargs['row']:
+        if not kwargs["row"]:
             response.status = 404
             return self.mod.error_404()
         return await next_pipe(**kwargs)
 
     def fetch_record(self, kwargs):
-        kwargs['row'] = self.mod._select_method(
-            kwargs['dbset'].where(self.mod.model.id == kwargs['rid']))
-        del kwargs['rid']
-        del kwargs['dbset']
+        kwargs["row"] = self.mod._select_method(kwargs["dbset"].where(self.mod.model.id == kwargs["rid"]))
+        del kwargs["rid"]
+        del kwargs["dbset"]
 
 
 class FieldPipe(ModulePipe):
-    def __init__(self, mod, accepted_attr_name, arg='field'):
+    def __init__(self, mod, accepted_attr_name, arg="field"):
         super().__init__(mod)
         self.accepted_attr_name = accepted_attr_name
         self.arg_name = arg
         self.set_accepted()
 
     def set_accepted(self):
-        self._accepted_dict = {
-            val: self.mod.model.table[val]
-            for val in getattr(self.mod, self.accepted_attr_name)
-        }
+        self._accepted_dict = {val: self.mod.model.table[val] for val in getattr(self.mod, self.accepted_attr_name)}
 
     async def pipe_request(self, next_pipe, **kwargs):
         field = self._accepted_dict.get(kwargs[self.arg_name])
@@ -86,13 +84,7 @@ async def pipe_request(self, next_pipe, **kwargs):
 
 
 class FieldsPipe(ModulePipe):
-    def __init__(
-        self,
-        mod,
-        accepted_attr_name,
-        query_param_name='fields',
-        arg='fields'
-    ):
+    def __init__(self, mod, accepted_attr_name, query_param_name="fields", arg="fields"):
         super().__init__(mod)
         self.accepted_attr_name = accepted_attr_name
         self.param_name = query_param_name
@@ -104,11 +96,8 @@ def set_accepted(self):
 
     def parse_fields(self):
         pfields = (
-            (
-                isinstance(request.query_params[self.param_name], str) and
-                request.query_params[self.param_name]
-            ) or ''
-        ).split(',')
+            (isinstance(request.query_params[self.param_name], str) and request.query_params[self.param_name]) or ""
+        ).split(",")
         sfields = self._accepted_set & set(pfields)
         return [self.mod.model.table[key] for key in sfields]
 
@@ -116,8 +105,6 @@ async def pipe_request(self, next_pipe, **kwargs):
         fields = self.parse_fields()
         if not fields:
             response.status = 400
-            return self.mod.build_error_400({
-                self.param_name: 'invalid value'
-            })
+            return self.mod.build_error_400({self.param_name: "invalid value"})
         kwargs[self.arg_name] = fields
         return await next_pipe(**kwargs)
diff --git a/emmett_rest/openapi/api.py b/emmett_rest/openapi/api.py
index 372df21..bb1c97a 100644
--- a/emmett_rest/openapi/api.py
+++ b/emmett_rest/openapi/api.py
@@ -1,35 +1,25 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.openapi.api
-    -----------------------
+emmett_rest.openapi.api
+-----------------------
 
-    Provides OpenAPI user-facing helpers
+Provides OpenAPI user-facing helpers
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from __future__ import annotations
 
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    List,
-    Optional,
-    Tuple,
-    Type,
-    TypeVar,
-    Union
-)
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 from pydantic import BaseModel as Schema
 
+
 if TYPE_CHECKING:
     from ..parsers import Parser
-    from ..serializers import Serializer
     from ..rest import RESTModule
+    from ..serializers import Serializer
 
 T = TypeVar("T")
 
@@ -65,101 +55,100 @@ def schema(self, obj: Schema) -> Callable[[T], T]:
         def deco(f: T) -> T:
             f._openapi_def_schema = obj
             return f
+
         return deco
 
     def fields(self, **specs: Union[Type, Tuple[Type, Any]]) -> Callable[[T], T]:
         def deco(f: T) -> T:
             f._openapi_def_fields = {**getattr(f, "_openapi_def_fields", {}), **specs}
             return f
+
         return deco
 
     def request(
         self,
         content: Optional[str] = None,
         fields: Dict[str, Union[Type, Tuple[Type, Any]]] = {},
-        files: List[str] = []
+        files: List[str] = [],
    ) -> Callable[[T], T]:
         def deco(f: T) -> T:
-            f._openapi_def_request = {
-                "content": content or "application/json",
-                "fields": fields,
-                "files": files
-            }
+            f._openapi_def_request = {"content": content or "application/json", "fields": fields, "files": files}
             return f
+
         return deco
 
     def response(
         self,
         status_code: int = 200,
         content: Optional[str] = None,
-        fields: Dict[str, Union[Type, Tuple[Type, Any]]] = {}
+        fields: Dict[str, Union[Type, Tuple[Type, Any]]] = {},
     ) -> Callable[[T], T]:
         def deco(f: T) -> T:
             f._openapi_def_responses = {
                 **getattr(f, "_openapi_def_responses", {}),
-                **{
-                    str(status_code): {
-                        "content": content or "application/json",
-                        "fields": fields
-                    }
-                }
+                **{str(status_code): {"content": content or "application/json", "fields": fields}},
             }
             return f
+
         return deco
 
     def response_default_errors(self, *error_codes: int) -> Callable[[T], T]:
         def deco(f: T) -> T:
             f._openapi_def_response_codes = [str(err) for err in error_codes]
             return f
+
         return deco
 
     def parser(self, parser: Parser):
         def deco(f: T) -> T:
             f._openapi_def_parser = parser
             return f
+
         return deco
 
     def serializer(self, serializer: Serializer):
         def deco(f: T) -> T:
             f._openapi_def_serializer = serializer
             return f
+
         return deco
 
 
 class OpenAPIDescribe:
-    def __call__(
-        self,
-        summary: str,
-        description: str = ""
-    ) -> Callable[[T], T]:
+    def __call__(self, summary: str, description: str = "") -> Callable[[T], T]:
         def deco(f: T) -> T:
             f._openapi_desc_summary = summary
             f._openapi_desc_description = description
             return f
+
         return deco
 
     def summary(self, description: str) -> Callable[[T], T]:
         def deco(f: T) -> T:
             f._openapi_desc_summary = description
             return f
+
         return deco
 
     def description(self, description: str) -> Callable[[T], T]:
         def deco(f: T) -> T:
             f._openapi_desc_description = description
             return f
+
         return deco
 
     def request(self, description: str) -> Callable[[T], T]:
         def deco(f: T) -> T:
             f._openapi_desc_request = description
             return f
+
         return deco
 
     def response(self, description: str) -> Callable[[T], T]:
         def deco(f: T) -> T:
             f._openapi_desc_response = description
             return f
+
         return deco
diff --git a/emmett_rest/openapi/generation.py b/emmett_rest/openapi/generation.py
index c98e1bd..27fa78b 100644
--- a/emmett_rest/openapi/generation.py
+++ b/emmett_rest/openapi/generation.py
@@ -1,18 +1,17 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.openapi.generation
-    ------------------------------
+emmett_rest.openapi.generation
+------------------------------
 
-    Provides OpenAPI generation functions
+Provides OpenAPI generation functions
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 import datetime
 import decimal
 import re
-
 from collections import defaultdict
 from enum import Enum
 from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union, get_type_hints
@@ -22,24 +21,18 @@
 from pydantic.fields import FieldInfo, ModelField
 from pydantic.schema import field_schema, model_process_schema
 
+from ..parsers import Parser
 from ..rest import RESTModule
 from ..serializers import Serializer
-from ..parsers import Parser
 from .schemas import OpenAPI
 
+
 REF_PREFIX = "#/components/schemas/"
 
 _re_path_param = re.compile(r"<(\w+)\:(\w+)>")
 _re_path_param_optional = re.compile(r"\(([^<]+)?<(\w+)\:(\w+)>\)\?")
 _pydantic_baseconf = BaseConfig()
-_path_param_types_map = {
-    "alpha": str,
-    "any": str,
-    "date": str,
-    "float": float,
-    "int": int,
-    "str": str
-}
+_path_param_types_map = {"alpha": str, "any": str, "date": str, "float": float, "int": int, "str": str}
 _model_field_types_map = {
     "id": int,
     "string": str,
@@ -59,7 +52,7 @@
     "password": str,
     "upload": str,
     "list:string": List[str],
-    "list:int": List[int]
+    "list:int": List[int],
 }
 _def_summaries = {
     "index": "List {entity}",
@@ -69,7 +62,7 @@
     "delete": "Delete {entity}",
     "sample": "List random {entity}",
     "group": "Group {entity}",
-    "stats": "Retrieve {entity} stats"
+    "stats": "Retrieve {entity} stats",
 }
 _def_descriptions = {
     "index": "Returns a list of {entity}",
@@ -79,7 +72,7 @@
     "delete": "Delete specific {entity} with the given identifier",
     "sample": "Returns random selection of {entity}",
     "group": "Counts {entity} grouped by the given attribute",
-    "stats": "Returns {entity} stats for the specified attributes"
+    "stats": "Returns {entity} stats for the specified attributes",
 }
 
@@ -104,36 +97,11 @@ class StatModel(BaseModel):
     avg: Union[int, float]
 
 
-_error_schema = model_process_schema(
-    ErrorsModel,
-    model_name_map={},
-    ref_prefix=None
-)[0]
+_error_schema = model_process_schema(ErrorsModel, model_name_map={}, ref_prefix=None)[0]
 _def_errors = {
-    "400": {
-        "description": "Bad request",
-        "content": {
-            "application/json": {
-                "schema": _error_schema
-            }
-        }
-    },
-    "404": {
-        "description": "Resource not found",
-        "content": {
-            "application/json": {
-                "schema": _error_schema
-            }
-        }
-    },
-    "422": {
-        "description": "Unprocessable request",
-        "content": {
-            "application/json": {
-                "schema": _error_schema
-            }
-        }
-    }
+    "400": {"description": "Bad request", "content": {"application/json": {"schema": _error_schema}}},
+    "404": {"description": "Resource not found", "content": {"application/json": {"schema": _error_schema}}},
+    "422": {"description": "Unprocessable request", "content": {"application/json": {"schema": _error_schema}}},
 }
 
@@ -145,7 +113,7 @@ def _defs_from_item(obj: Any, key: str):
             rv.update(_defs_from_pydantic_model(obj, parent=key))
         elif issubclass(obj, Enum):
             rv[obj].append(key)
-    except:
+    except Exception:
         pass
     return rv
 
@@ -176,10 +144,7 @@ def _denormalize_schema(schema: Dict[str, Any], defs: Dict[str, Dict[str, Any]])
             schema["anyOf"][idx] = defs[element["$ref"][14:]]
 
 
-def _index_default_query_parameters(
-    module: RESTModule,
-    sort_enabled: bool = True
-) -> List[Dict[str, Any]]:
+def _index_default_query_parameters(module: RESTModule, sort_enabled: bool = True) -> List[Dict[str, Any]]:
     rv = []
     model_map = {}
 
@@ -197,7 +162,7 @@
             model_config=_pydantic_baseconf,
             required=False,
             default=1,
-            field_info=FieldInfo(ge=1)
+            field_info=FieldInfo(ge=1),
         ),
         ModelField(
             name=module.ext.config.pagesize_param,
@@ -207,11 +172,9 @@
             required=False,
             default=module.ext.config.default_pagesize,
             field_info=FieldInfo(
-                description="Size of the page",
-                ge=module.ext.config.min_pagesize,
-                le=module.ext.config.max_pagesize
-            )
-        )
+                description="Size of the page", ge=module.ext.config.min_pagesize, le=module.ext.config.max_pagesize
+            ),
+        ),
     ]
 
     if sort_enabled:
@@ -229,12 +192,12 @@
                         "Descendant sorting applied with -{parameter} notation. "
                         "Multiple values should be separated by comma."
                     )
-                )
+                ),
             )
         )
 
     if condition_fields:
-        where_model = create_model('Condition', **condition_fields)
+        where_model = create_model("Condition", **condition_fields)
         fields.append(
             ModelField(
                 name=module.ext.config.query_param,
@@ -242,29 +205,18 @@
                 class_validators=None,
                 model_config=_pydantic_baseconf,
                 required=False,
-                field_info=FieldInfo(
-                    description=(
-                        "Filter results using the provided query object."
-                    )
-                )
+                field_info=FieldInfo(description=("Filter results using the provided query object.")),
             )
         )
-        model_map[where_model] = 'Condition'
+        model_map[where_model] = "Condition"
 
     for field in fields:
-        schema, defs, _ = field_schema(
-            field, model_name_map=model_map, ref_prefix=None
-        )
+        schema, defs, _ = field_schema(field, model_name_map=model_map, ref_prefix=None)
         if field.name in enums:
             schema["items"]["enum"] = enums[field.name]
         elif field.name == module.ext.config.query_param:
             schema["allOf"][0] = defs["Condition"]
-        rv.append({
-            "name": field.name,
-            "in": "query",
-            "required": field.required,
-            "schema": schema
-        })
+        rv.append({"name": field.name, "in": "query", "required": field.required, "schema": schema})
 
     return rv
 
@@ -275,22 +227,19 @@ def _stats_default_query_parameters(module: RESTModule) -> List[Dict[str, Any]]:
 
     fields = [
         ModelField(
-            name='fields',
+            name="fields",
             type_=List[str],
             class_validators=None,
             model_config=_pydantic_baseconf,
             required=True,
             field_info=FieldInfo(
-                description=(
-                    "Add specified attribute(s) to stats. "
-                    "Multiple values should be separated by comma."
-                )
-            )
+                description=("Add specified attribute(s) to stats. " "Multiple values should be separated by comma.")
+            ),
         )
     ]
 
     if condition_fields:
-        where_model = create_model('Condition', **condition_fields)
+        where_model = create_model("Condition", **condition_fields)
         fields.append(
             ModelField(
                 name=module.ext.config.query_param,
@@ -298,36 +247,23 @@ def _stats_default_query_parameters(module: RESTModule) -> List[Dict[str, Any]]:
                 class_validators=None,
                 model_config=_pydantic_baseconf,
                 required=False,
-                field_info=FieldInfo(
-                    description=(
-                        "Filter results using the provided query object."
-                    )
-                )
+                field_info=FieldInfo(description=("Filter results using the provided query object.")),
             )
         )
-        model_map[where_model] = 'Condition'
+        model_map[where_model] = "Condition"
 
     for field in fields:
-        schema, defs, _ = field_schema(
-            field, model_name_map=model_map, ref_prefix=None
-        )
+        schema, defs, _ = field_schema(field, model_name_map=model_map, ref_prefix=None)
         if field.name == "fields":
             schema["items"]["enum"] = module.stats_allowed_fields
         if field.name == module.ext.config.query_param:
             schema["allOf"][0] = defs["Condition"]
-        rv.append({
-            "name": field.name,
-            "in": "query",
-            "required": field.required,
-            "schema": schema
-        })
+        rv.append({"name": field.name, "in": "query", "required": field.required, "schema": schema})
 
     return rv
 
 
 def build_schema_from_fields(
-    module: RESTModule,
-    fields: Dict[str, Any],
-    hints_check: Optional[Set[str]] = None
+    module: RESTModule, fields: Dict[str, Any], hints_check: Optional[Set[str]] = None
 ) -> Tuple[Dict[str, Any], Type[BaseModel]]:
     hints_check = hints_check if hints_check is not None else set(fields.keys())
     schema_fields, hints_defs, fields_choices = {}, defaultdict(list), {}
@@ -345,20 +281,16 @@
         if choices:
             fields_choices[key] = choices
     for key in set(schema_fields.keys()) & hints_check:
-        for type_arg in [schema_fields[key][0]] + list(getattr(
-            schema_fields[key][0], "__args__", []
-        )):
+        for type_arg in [schema_fields[key][0]] + list(getattr(schema_fields[key][0], "__args__", [])):
             for ikey, ival in _defs_from_item(type_arg, key).items():
                 hints_defs[ikey].extend(ival)
     model = create_model(module.model.__name__, **schema_fields)
     schema, defs, nested = model_process_schema(
-        model,
-        model_name_map={key: key.__name__ for key in hints_defs.keys()},
-        ref_prefix=None
+        model, model_name_map={key: key.__name__ for key in hints_defs.keys()}, ref_prefix=None
     )
     for def_schema in defs.values():
         _denormalize_schema(def_schema, defs)
-    for key, value in schema["properties"].items():
+    for value in schema["properties"].values():
         _denormalize_schema(value, defs)
     for key, choices in fields_choices.items():
         schema["properties"][key]["enum"] = choices
@@ -398,10 +330,7 @@ def __init__(
         self.security_schemes = security_schemes or {}
 
     def fields_from_model(
-        self,
-        model: Any,
-        model_fields: Dict[str, Any],
-        fields: List[str]
+        self, model: Any, model_fields: Dict[str, Any], fields: List[str]
     ) -> Dict[str, Tuple[Type, Any, List[Any]]]:
         rv = {}
         for key in fields:
@@ -418,25 +347,18 @@ def fields_from_model(
             rv[key] = (
                 _model_field_types_map.get(ftype, Any),
                 Field(default_factory=model_fields[key].default)
-                if callable(model_fields[key].default) else model_fields[key].default,
-                choices
+                if callable(model_fields[key].default)
+                else model_fields[key].default,
+                choices,
             )
         return rv
 
     def build_schema_from_parser(
-        self,
-        module: RESTModule,
-        parser: Parser,
-        model_fields: Optional[Dict[str, Any]] = None
+        self, module: RESTModule, parser: Parser, model_fields: Optional[Dict[str, Any]] = None
     ) -> Tuple[Dict[str, Any], Type[BaseModel]]:
-        model_fields = model_fields or {
-            key: module.model.table[key]
-            for key in module.model._instance_()._fieldset_all
-        }
-        fields, hints_check = self.fields_from_model(
-            module.model, model_fields, parser.attributes
-        ), set()
-        for key, defdata in getattr(parser, '_openapi_def_fields', {}).items():
+        model_fields = model_fields or {key: module.model.table[key] for key in module.model._instance_()._fieldset_all}
+        fields, hints_check = self.fields_from_model(module.model, model_fields, parser.attributes), set()
+        for key, defdata in getattr(parser, "_openapi_def_fields", {}).items():
             if isinstance(defdata, (list, tuple)):
                 type_hint, type_default = defdata
             else:
@@ -447,27 +369,19 @@ def build_schema_from_parser(
         return build_schema_from_fields(module, fields, hints_check)
 
     def build_schema_from_serializer(
-        self,
-        module: RESTModule,
-        serializer: Serializer,
-        model_fields: Optional[Dict[str, Any]] = None
+        self, module: RESTModule, serializer: Serializer, model_fields: Optional[Dict[str, Any]] = None
     ) -> Tuple[Dict[str, Any], Type[BaseModel]]:
-        model_fields = model_fields or {
-            key: module.model.table[key]
-            for key in module.model._instance_()._fieldset_all
-        }
-        fields, hints_check = self.fields_from_model(
-            module.model, model_fields, serializer.attributes
-        ), set()
+        model_fields = model_fields or {key: module.model.table[key] for key in module.model._instance_()._fieldset_all}
+        fields, hints_check = self.fields_from_model(module.model, model_fields, serializer.attributes), set()
         for key in serializer._attrs_override_:
-            type_hint = get_type_hints(getattr(serializer, key)).get('return', Any)
+            type_hint = get_type_hints(getattr(serializer, key)).get("return", Any)
             type_hint_opt = False
             for type_arg in getattr(type_hint, "__args__", []):
                 if issubclass(type_arg, type(None)):
                     type_hint_opt = True
             fields[key] = (type_hint, None if type_hint_opt else ...)
             hints_check.add(key)
-        for key, defdata in getattr(serializer, '_openapi_def_fields', {}).items():
+        for key, defdata in getattr(serializer, "_openapi_def_fields", {}).items():
             if isinstance(defdata, (list, tuple)):
                 type_hint, type_default = defdata
             else:
@@ -479,92 +393,66 @@ def build_schema_from_serializer(
 
     def build_definitions(self, module: RESTModule) -> Dict[str, Any]:
         serializers, parsers = {}, {}
-        model_fields = {
-            key: module.model.table[key]
-            for key in module.model._instance_()._fieldset_all
-        }
+        model_fields = {key: module.model.table[key] for key in module.model._instance_()._fieldset_all}
 
         for serializer_name, serializer in {
             "__default__": module.serializer,
-            **module._openapi_specs["serializers"]
+            **module._openapi_specs["serializers"],
         }.items():
             if serializer in serializers:
                 continue
             data = serializers[serializer] = {}
-            serializer_schema, serializer_model = self.build_schema_from_serializer(
-                module, serializer, model_fields
-            )
-            data.update(
-                name=serializer_name,
-                model=serializer_model,
-                schema=serializer_schema
-            )
+            serializer_schema, serializer_model = self.build_schema_from_serializer(module, serializer, model_fields)
+            data.update(name=serializer_name, model=serializer_model, schema=serializer_schema)
 
-        for parser_name, parser in {
-            "__default__": module.parser,
-            **module._openapi_specs["parsers"]
-        }.items():
+        for parser_name, parser in {"__default__": module.parser, **module._openapi_specs["parsers"]}.items():
             if parser in parsers:
                 continue
             data = parsers[parser] = {}
-            parser_schema, parser_model = self.build_schema_from_parser(
-                module, parser, model_fields
-            )
-            data.update(
-                name=parser_name,
-                model=parser_model,
-                schema=parser_schema
-            )
+            parser_schema, parser_model = self.build_schema_from_parser(module, parser, model_fields)
+            data.update(name=parser_name, model=parser_model, schema=parser_schema)
 
         return {
             "module": module.name,
             "model": module.model.__name__,
             "serializers": serializers,
             "parsers": parsers,
-            "schema": serializers[module.serializer]["schema"]
+            "schema": serializers[module.serializer]["schema"],
         }
 
     def build_operation_metadata(
-        self,
-        module: RESTModule,
-        modules_tags: Dict[str, str],
-        route_kind: str,
-        method: str
+        self, module: RESTModule, modules_tags: Dict[str, str], route_kind: str, method: str
     ) -> Dict[str, Any]:
-        entity_name = (
-            module._openapi_specs.get("entity_name") or
-            module.name.rsplit(".", 1)[-1]
-        )
+        entity_name = module._openapi_specs.get("entity_name") or module.name.rsplit(".", 1)[-1]
         return {
             "summary": _def_summaries[route_kind].format(entity=entity_name),
             "description": _def_descriptions[route_kind].format(entity=entity_name),
             "operationId": f"{module.name}.{route_kind}.{method}".replace(".", "_"),
-            "tags": [modules_tags[module.name]]
+            "tags": [modules_tags[module.name]],
         }
 
     def build_operation_parameters(
-        self,
-        module: RESTModule,
-        path_kind: str,
-        path_params: Dict[str, Dict[str, Any]]
+        self, module: RESTModule, path_kind: str, path_params: Dict[str, Dict[str, Any]]
     ) -> List[Dict[str, Any]]:
         rv = []
         for pname, pdata in path_params.items():
-            rv.append({
-                "name": pname,
-                "in": "path",
-                "required": not pdata["optional"],
-                "schema": field_schema(
-                    ModelField(
-                        name=pname,
-                        type_=_path_param_types_map[pdata["type"]],
-                        class_validators=None,
-                        model_config=_pydantic_baseconf,
-                        required=not pdata["optional"]
-                    ),
-                    model_name_map={},
-                    ref_prefix=REF_PREFIX
-                )[0]
-            })
+            rv.append(
+                {
+                    "name": pname,
+                    "in": "path",
+                    "required": not pdata["optional"],
+                    "schema": field_schema(
+                        ModelField(
+                            name=pname,
+                            type_=_path_param_types_map[pdata["type"]],
+                            class_validators=None,
+                            model_config=_pydantic_baseconf,
+                            required=not pdata["optional"],
+                        ),
+                        model_name_map={},
+                        ref_prefix=REF_PREFIX,
+                    )[0],
+                }
+            )
         if path_kind == "index":
             rv.extend(_index_default_query_parameters(module))
         elif path_kind == "sample":
@@ -576,10 +464,7 @@
             rv.extend(_stats_default_query_parameters(module))
         return rv
 
-    def build_operation_common_responses(
-        self,
-        path_kind: str
-    ) -> Dict[str, Any]:
+    def build_operation_common_responses(self, path_kind: str) -> Dict[str, Any]:
         rv = {}
         if path_kind in ["read", "update", "delete"]:
             rv["404"] = _def_errors["404"]
@@ -589,18 +474,14 @@
             rv["400"] = _def_errors["400"]
         return rv
 
-    def build_index_schema(
-        self,
-        module: RESTModule,
-        item_schema: Dict[str, Any]
-    ) -> Dict[str, Any]:
+    def build_index_schema(self, module: RESTModule, item_schema: Dict[str, Any]) -> Dict[str, Any]:
         fields = {module.list_envelope: (List[Dict[str, Any]], ...)}
         if module.serialize_meta:
             fields[module.meta_envelope] = (MetaModel, ...)
         schema, defs, nested = model_process_schema(
             create_model(f"{module.__class__.__name__}Index", **fields),
             model_name_map={MetaModel: "Meta"},
-            ref_prefix=None
+            ref_prefix=None,
         )
         schema["properties"][module.list_envelope]["items"] = item_schema
         if module.serialize_meta:
@@ -615,12 +496,10 @@ def build_group_schema(self, module: RESTModule) -> Dict[str, Any]:
         schema, defs, nested = model_process_schema(
             create_model(f"{module.__class__.__name__}Group", **fields),
             model_name_map={MetaModel: "Meta"},
-            ref_prefix=None
+            ref_prefix=None,
         )
         schema["properties"][module.groups_envelope]["items"] = model_process_schema(
-            GroupModel,
-            model_name_map={},
-            ref_prefix=None
+            GroupModel, model_name_map={}, ref_prefix=None
         )[0]
         schema["properties"][module.groups_envelope]["items"]["title"] = "Group"
         if module.serialize_meta:
@@ -633,7 +512,7 @@ def build_stats_schema(self, module: RESTModule) -> Dict[str, Any]:
         schema, defs, nested = model_process_schema(
             create_model(f"{module.__class__.__name__}Stat", **fields),
             model_name_map={StatModel: "Stat"},
-            ref_prefix=None
+            ref_prefix=None,
         )
         schema["properties"]["stats"]["additionalProperties"] = defs["Stat"]
         return schema["properties"]["stats"]
@@ -643,17 +522,15 @@ def build_paths(
         module: RESTModule,
         modules_tags: Dict[str, str],
         serializers: Dict[Serializer, Dict[str, Any]],
-        parsers: Dict[Parser, Dict[str, Any]]
+        parsers: Dict[Parser, Dict[str, Any]],
     ) -> Dict[str, Dict[str, Dict[str, Any]]]:
         rv: Dict[str, Dict[str, Dict[str, Any]]] = {}
-        mod_name = module.name.rsplit('.', 1)[-1]
+        mod_name = module.name.rsplit(".", 1)[-1]
         entity_name = module._openapi_specs.get("entity_name") or mod_name
         mod_prefix: str = module.url_prefix or "/"
-        path_prefix: str = (
-            module.app._router_http._prefix_main + (
-                f"/{mod_prefix}" if not mod_prefix.startswith("/") else mod_prefix
-            )
+        path_prefix: str = module.app._router_http._prefix_main + (
+            f"/{mod_prefix}" if not mod_prefix.startswith("/") else mod_prefix
         )
 
         for path_kind in set(module.enabled_methods) & {
@@ -664,7 +541,7 @@ def build_paths(
             "delete",
             "sample",
             "group",
-            "stats"
+            "stats",
         }:
             path_relative, methods = module._methods_map[path_kind]
             if not isinstance(methods, list):
@@ -676,7 +553,7 @@ def build_paths(
             for ptype, pname in _re_path_param.findall(path_scoped):
                 path_params[pname] = {"type": ptype, "optional": False}
                 path_scoped = path_scoped.replace(f"<{ptype}:{pname}>", f"{{{pname}}}")
-            for _, ptype, pname in _re_path_param_optional.findall(path_scoped):
+            for _, _, pname in _re_path_param_optional.findall(path_scoped):
                 path_params[pname]["optional"] = True
 
             rv[path_scoped] = rv.get(path_scoped) or {}
@@ -685,96 +562,54 @@ def build_paths(
             parser_obj = module._openapi_specs["parsers"].get(path_kind, module.parser)
 
             for method in methods:
-                operation = self.build_operation_metadata(
-                    module, modules_tags, path_kind, method
-                )
-                operation_parameters = self.build_operation_parameters(
-                    module, path_kind, path_params
-                )
+                operation = self.build_operation_metadata(module, modules_tags, path_kind, method)
+                operation_parameters = self.build_operation_parameters(module, path_kind, path_params)
                 operation_responses = self.build_operation_common_responses(path_kind)
                 if operation_parameters:
                     operation["parameters"] = operation_parameters
 
-                if (
-                    path_kind in ["create", "update"] or (
-                        path_kind == "delete" and
-                        "delete" in module._openapi_specs["parsers"]
-                    )
+                if path_kind in ["create", "update"] or (
+                    path_kind == "delete" and "delete" in module._openapi_specs["parsers"]
                 ):
                     operation["requestBody"] = {
-                        "content": {
-                            "application/json": {
-                                "schema": parsers[parser_obj]["schema"]
-                            }
-                        }
+                        "content": {"application/json": {"schema": parsers[parser_obj]["schema"]}}
                     }
 
                 if path_kind in ["create", "read", "update"]:
                     serializer_obj = module._openapi_specs["serializers"][path_kind]
                     response_code = "201" if path_kind == "create" else "200"
-                    descriptions = {
-                        "create": "Resource created",
-                        "read": "Resource",
-                        "update": "Resource updated"
-                    }
+                    descriptions = {"create": "Resource created", "read": "Resource", "update": "Resource updated"}
                     operation_responses[response_code] = {
                         "description": descriptions[path_kind],
-                        "content": {
-                            "application/json": {
-                                "schema": serializers[serializer_obj]["schema"]
-                            }
-                        }
+                        "content": {"application/json": {"schema": serializers[serializer_obj]["schema"]}},
                     }
                 elif path_kind in ["index", "sample"]:
                     operation_responses["200"] = {
-                        "description": (
-                            "Resource list" if path_kind == "index" else
-                            "Resource random list"
-                        ),
+                        "description": ("Resource list" if path_kind == "index" else "Resource random list"),
                         "content": {
                             "application/json": {
-                                "schema": self.build_index_schema(
-                                    module,
-                                    serializers[serializer_obj]["schema"]
-                                )
+                                "schema": self.build_index_schema(module, serializers[serializer_obj]["schema"])
                             }
-                        }
+                        },
                     }
                 elif path_kind == "delete":
                     operation_responses["200"] = {
                         "description": "Resource deleted",
-                        "content": {
-                            "application/json": {
-                                "schema": {
-                                    "type": "object",
-                                    "properties": {}
-                                }
-                            }
-                        }
+                        "content": {"application/json": {"schema": {"type": "object", "properties": {}}}},
                     }
                 elif path_kind == "group":
                     operation_responses["200"] = {
                         "description": "Resource groups",
-                        "content": {
-                            "application/json": {
-                                "schema": self.build_group_schema(module)
-                            }
-                        }
+                        "content": {"application/json": {"schema": self.build_group_schema(module)}},
                    }
                 elif path_kind == "stats":
                     operation_responses["200"] = {
                         "description": "Resource stats",
-                        "content": {
-                            "application/json": {
-                                "schema": self.build_stats_schema(module)
-                            }
-                        }
+                        "content": {"application/json": {"schema": self.build_stats_schema(module)}},
                     }
 
                 if operation_responses:
                     operation["responses"] = operation_responses
                 rv[path_scoped][method] = operation
 
-        for path_name, path_target, path_data in (
-            module._openapi_specs["additional_routes"]
-        ):
+        for path_name, path_target, path_data in module._openapi_specs["additional_routes"]:
             methods = path_data.methods
             for path_relative in path_data.paths:
                 path_scoped: str = path_prefix + path_relative
@@ -783,10 +618,8 @@ def build_paths(
                 path_params = {}
                 for ptype, pname in _re_path_param.findall(path_scoped):
                     path_params[pname] = {"type": ptype, "optional": False}
-                    path_scoped = path_scoped.replace(
-                        f"<{ptype}:{pname}>", f"{{{pname}}}"
-                    )
-                for _, ptype, pname in _re_path_param_optional.findall(path_scoped):
+                    path_scoped = path_scoped.replace(f"<{ptype}:{pname}>", f"{{{pname}}}")
+                for _, _, pname in _re_path_param_optional.findall(path_scoped):
                     path_params[pname]["optional"] = True
 
                 rv[path_scoped] = rv.get(path_scoped) or {}
@@ -794,98 +627,47 @@ def build_paths(
                 for method in methods:
                     operation = {
                         "summary": getattr(
-                            path_target,
-                            "_openapi_desc_summary",
-                            f"{{name}} {path_name.rsplit('.', 1)[-1]}"
-                        ).format(name=entity_name),
-                        "description": getattr(
-                            path_target,
-                            "_openapi_desc_description",
-                            ""
+                            path_target, "_openapi_desc_summary", f"{{name}} {path_name.rsplit('.', 1)[-1]}"
                         ).format(name=entity_name),
+                        "description": getattr(path_target, "_openapi_desc_description", "").format(name=entity_name),
                         "operationId": f"{path_name}.{method}".replace(".", "_"),
-                        "tags": [modules_tags[module.name]]
+                        "tags": [modules_tags[module.name]],
                     }
-                    operation_parameters = self.build_operation_parameters(
-                        module, "custom", path_params
-                    )
+                    operation_parameters = self.build_operation_parameters(module, "custom", path_params)
                     operation_responses = {}
                     if operation_parameters:
                         operation["parameters"] = operation_parameters
 
-                    operation_request = getattr(
-                        path_target, "_openapi_def_request", None
-                    )
+                    operation_request = getattr(path_target, "_openapi_def_request", None)
                     if operation_request:
-                        schema = build_schema_from_fields(
-                            module,
-                            operation_request["fields"]
-                        )[0]
+                        schema = build_schema_from_fields(module, operation_request["fields"])[0]
                         for file_param in operation_request["files"]:
-                            schema["properties"][file_param] = {
-                                "type": "string",
-                                "format": "binary"
-                            }
-                        operation["requestBody"] = {
-                            "content": {
-                                operation_request["content"]: {
-                                    "schema": schema
-                                }
-                            }
-                        }
+                            schema["properties"][file_param] = {"type": "string", "format": "binary"}
+                        operation["requestBody"] = {"content": {operation_request["content"]: {"schema": schema}}}
                     else:
-                        parser = getattr(
-                            path_target, "_openapi_def_parser", module.parser
-                        )
+                        parser = getattr(path_target, "_openapi_def_parser", module.parser)
                         if parser in parsers:
                             schema = parsers[parser]["schema"]
                         else:
                             schema = self.build_schema_from_parser(module, parser)[0]
-                        operation["requestBody"] = {
-                            "content": {
-                                "application/json": {
-                                    "schema": schema
-                                }
-                            }
-                        }
+                        operation["requestBody"] = {"content": {"application/json": {"schema": schema}}}
 
                     operation_responses = {}
-                    defined_responses = getattr(
-                        path_target, "_openapi_def_responses", None
-                    )
+                    defined_responses = getattr(path_target, "_openapi_def_responses", None)
                     if defined_responses:
                         for status_code, defined_response in defined_responses.items():
-                            schema = build_schema_from_fields(
-                                module,
-                                defined_response["fields"]
-                            )[0]
+                            schema = build_schema_from_fields(module, defined_response["fields"])[0]
                             operation_responses[status_code] = {
-                                "content": {
-                                    defined_response["content"]: {
-                                        "schema": schema
-                                    }
-                                }
+                                "content": {defined_response["content"]: {"schema": schema}}
                             }
                     else:
-                        serializer = getattr(
-                            path_target, "_openapi_def_serializer", module.serializer
-                        )
+                        serializer = getattr(path_target, "_openapi_def_serializer", module.serializer)
                         if serializer in serializers:
                             schema = serializers[serializer]["schema"]
                         else:
-                            schema = self.build_schema_from_serializer(
-                                module, serializer
-                            )[0]
-                        operation_responses["200"] = {
-                            "content": {
-                                "application/json": {
-                                    "schema": schema
-                                }
-                            }
-                        }
-                    defined_resp_errors = getattr(
-                        path_target, "_openapi_def_response_codes", []
-                    )
+                            schema = self.build_schema_from_serializer(module, serializer)[0]
+                        operation_responses["200"] = {"content": {"application/json": {"schema": schema}}}
+                    defined_resp_errors = getattr(path_target, "_openapi_def_response_codes", [])
                     for status_code in defined_resp_errors:
                         operation_responses[status_code] = _def_errors[status_code]
 
@@ -909,14 +691,7 @@ def __call__(self, produce_schemas: bool = False) -> Dict[str, Any]:
             #     "name": module.name,
            #     "description": module.name.split(".")[-1].title()
             # })
-            paths.update(
-                self.build_paths(
-                    module,
-                    self.modules_tags,
-                    defs["serializers"],
-                    defs["parsers"]
-                )
-            )
+            paths.update(self.build_paths(module, self.modules_tags, defs["serializers"], defs["parsers"]))
             definitions[module.name] = defs
 
         if definitions and produce_schemas:
             components["schemas"] = {
@@ -948,7 +723,7 @@ def build_schema(
     contact: Optional[Dict[str, Union[str, Any]]] = None,
     license_info: Optional[Dict[str, Union[str, Any]]] = None,
     security_schemes: Optional[Dict[str, Any]] = None,
-    generator_cls: Optional[Type[OpenAPIGenerator]] = None
+    generator_cls: Optional[Type[OpenAPIGenerator]] = None,
 ) -> Dict[str, Any]:
     generator_cls = generator_cls or OpenAPIGenerator
     generator = generator_cls(
@@ -963,6 +738,6 @@ def build_schema(
         terms_of_service=terms_of_service,
         contact=contact,
         license_info=license_info,
-        security_schemes=security_schemes
+        security_schemes=security_schemes,
     )
     return generator(produce_schemas=produce_schemas)
diff --git a/emmett_rest/openapi/helpers.py b/emmett_rest/openapi/helpers.py
index 9f1ba9c..35a7cfb 100644
--- a/emmett_rest/openapi/helpers.py
+++ b/emmett_rest/openapi/helpers.py
@@ -1,12 +1,12 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.openapi.helpers
-    ---------------------------
+emmett_rest.openapi.helpers
+---------------------------
 
-    Provides OpenAPI internal helpers
+Provides OpenAPI internal helpers
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from enum import Enum
diff --git a/emmett_rest/openapi/mod.py b/emmett_rest/openapi/mod.py
index 1f8a6c8..c6b6c52 100644
--- a/emmett_rest/openapi/mod.py
+++ b/emmett_rest/openapi/mod.py
@@ -1,21 +1,22 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.openapi.mod
-    -----------------------
+emmett_rest.openapi.mod
+-----------------------
 
-    Provides OpenAPI application module
+Provides OpenAPI application module
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from importlib.resources import read_text
 from typing import Any, Dict, List, Optional, Union
 
 from emmett import App, AppModule, response, url
-from emmett.cache import RamCache, RouteCacheRule
+from emmett.cache import RamCache
 from emmett.tools.service import JSONServicePipe
 from emmett.utils import cachedprop
+from emmett_core.routing.cache import RouteCacheRule
 
 from ..rest import RESTModule
 from .generation import build_schema
@@ -23,7 +24,8 @@
 
 
 class OpenAPIModule(AppModule):
-    def __init__(self,
+    def __init__(
+        self,
         app: App,
         name: str,
         import_name: str,
@@ -42,16 +44,9 @@ def __init__(self,
         ui_path: str = "/docs",
         url_prefix: Optional[str] = None,
         hostname: Optional[str] = None,
-        **kwargs: Any
+        **kwargs: Any,
     ):
-        super().__init__(
-            app,
-            name,
-            import_name,
-            url_prefix=url_prefix,
-            hostname=hostname,
-            **kwargs
-        )
+        super().__init__(app, name, import_name, url_prefix=url_prefix, hostname=hostname, **kwargs)
         self._cache = RamCache()
         self.title = title
         self.description = description
@@ -81,14 +76,14 @@ def _define_routes(self, expose_ui: bool, ui_path: str):
             name="schema_json",
             methods="get",
             pipeline=[JSONServicePipe()],
-            cache=RouteCacheRule(self._cache) if not bool(self.app.debug) else None
+            cache=RouteCacheRule(self._cache) if not bool(self.app.debug) else None,
         )(self._get_spec)
         self.route(
             "/openapi.yaml",
             name="schema_yaml",
             methods="get",
             pipeline=[YAMLPipe()],
-            cache=RouteCacheRule(self._cache) if not bool(self.app.debug) else None
+            cache=RouteCacheRule(self._cache) if not bool(self.app.debug) else None,
         )(self._get_spec)
         if expose_ui:
             self.route(
@@ -96,7 +91,7 @@ def _define_routes(self, expose_ui: bool, ui_path: str):
                 name="ui",
                 methods="get",
                 output="str",
-                cache=RouteCacheRule(self._cache) if not bool(self.app.debug) else None
+                cache=RouteCacheRule(self._cache) if not bool(self.app.debug) else None,
             )(self._ui_stoplight)
 
     async def _get_spec(self):
@@ -112,7 +107,7 @@ async def _get_spec(self):
             terms_of_service=self.terms_of_service,
             contact=self.contact,
             license_info=self.license_info,
-            security_schemes=self.security_schemes
+            security_schemes=self.security_schemes,
         )
 
     @cachedprop
@@ -126,7 +121,7 @@ def _default_description(self):
             file_path="__emmett_rest__/openapi/description.md",
             context={
                 "title": self.title,
-            }
+            },
         )
 
     async def _ui_stoplight(self):
@@ -134,10 +129,7 @@ async def _ui_stoplight(self):
         return self.app.templater._render(
             self._stoplight_template,
             file_path="__emmett_rest__/openapi/stoplight.html",
-            context={
-                "title": self.title,
-                "openapi_url": url(f"{self.name}.schema_yaml")
-            }
+            context={"title": self.title, "openapi_url": url(f"{self.name}.schema_yaml")},
         )
 
     def regroup(self, module_name: str, destination: str):
diff --git a/emmett_rest/openapi/schemas.py b/emmett_rest/openapi/schemas.py
index 91def91..0834c51 100644
--- a/emmett_rest/openapi/schemas.py
+++ b/emmett_rest/openapi/schemas.py
@@ -1,12 +1,12 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.openapi.schemas
-    ---------------------------
+emmett_rest.openapi.schemas
+---------------------------
 
-    Provides OpenAPI schemas
+Provides OpenAPI schemas
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from __future__ import annotations
diff --git a/emmett_rest/parsers.py b/emmett_rest/parsers.py
index a4d37af..89700b6 100644
--- a/emmett_rest/parsers.py
+++ b/emmett_rest/parsers.py
@@ -1,12 +1,12 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.parsers
-    -------------------
+emmett_rest.parsers
+-------------------
 
-    Provides REST de-serialization tools
+Provides REST de-serialization tools
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from collections import OrderedDict
@@ -17,7 +17,7 @@
 
 
 class VParserDefinition:
-    __slots__ = ['param', 'f']
+    __slots__ = ["param", "f"]
 
     def __init__(self, param):
         self.param = param
@@ -28,7 +28,7 @@ def __call__(self, f):
 
 
 class ProcParserDefinition:
-    __slots__ = ['f', '_inst_count_']
+    __slots__ = ["f", "_inst_count_"]
     _all_inst_count_ = 0
 
     def __init__(self):
@@ -58,9 +58,9 @@ def __new__(cls, name, bases, attrs):
         new_class._declared_vparsers_ = declared_vparsers
         new_class._declared_procs_ = declared_procs
         for base in reversed(new_class.__mro__[1:]):
-            if hasattr(base, '_declared_vparsers_'):
+            if hasattr(base, "_declared_vparsers_"):
                 all_vparsers.update(base._declared_vparsers_)
-            if hasattr(base, '_declared_procs_'):
+            if hasattr(base, "_declared_procs_"):
                 all_procs.update(base._declared_procs_)
         all_vparsers.update(declared_vparsers)
         all_procs.update(declared_procs)
@@ -97,7 +97,7 @@ def __init__(self, model):
         writable_map = {}
         for fieldname in self._model.table.fields:
             writable_map[fieldname] = self._model.table[fieldname].writable
-        if hasattr(self._model, 'rest_rw'):
+        if hasattr(self._model, "rest_rw"):
             self.attributes = []
             for key, value in self._model.rest_rw.items():
                 if isinstance(value, tuple):
@@ -113,11 +113,8 @@ def __init__(self, model):
             if el in self.attributes:
                 self.attributes.remove(el)
         _attrs_override_ = []
-        for key in (
-            set(dir(self)) - set(self._all_vparsers_.keys()) -
-            set(self._all_procs_.keys())
-        ):
-            if not key.startswith('_') and callable(getattr(self, key)):
+        for key in set(dir(self)) - set(self._all_vparsers_.keys()) - set(self._all_procs_.keys()):
+            if not key.startswith("_") and callable(getattr(self, key)):
                 _attrs_override_.append(key)
         self._attrs_override_ = _attrs_override_
         self._init()
@@ -162,7 +159,5 @@ def parse_params_with_parser(parser_instance, **extras):
 
 
 async def parse_params(*accepted_params, **kwargs):
-    params = _envelope_filter(
-        await request.body_params, kwargs.get('envelope')
-    )
+    params = _envelope_filter(await request.body_params, kwargs.get("envelope"))
     return _parse(set(accepted_params), params)
diff --git a/emmett_rest/queries/errors.py b/emmett_rest/queries/errors.py
index 3631a70..334b8dd 100644
--- a/emmett_rest/queries/errors.py
+++ b/emmett_rest/queries/errors.py
@@ -1,12 +1,12 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.queries.errors
-    --------------------------
+emmett_rest.queries.errors
+--------------------------
 
-    Provides REST query language exception classes
+Provides REST query language exception classes
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from __future__ import annotations
@@ -18,8 +18,8 @@ def __init__(self, **kwargs):
         super().__init__(self.gen_msg())
 
     def init(self, **kwargs):
-        self.op = kwargs['op']
-        self.value = kwargs['value']
+        self.op = kwargs["op"]
+        self.value = kwargs["value"]
 
     def gen_msg(self) -> str:
         return "Invalid {} condition: {!r}".format(self.op, self.value)
diff --git a/emmett_rest/queries/helpers.py b/emmett_rest/queries/helpers.py
index 49ea4e0..f056534 100644
--- a/emmett_rest/queries/helpers.py
+++ b/emmett_rest/queries/helpers.py
@@ -1,25 +1,27 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.queries.helpers
-    ---------------------------
+emmett_rest.queries.helpers
+---------------------------
 
-    Provides REST query language helpers
+Provides REST query language helpers
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from __future__ import annotations
 
+from typing import Any, Dict
+
 from emmett import request, response
 from emmett.parsers import Parsers
-from typing import Any, Dict
 
 from ..helpers import ModulePipe
 from .errors import QueryError
 from .parser import parse_conditions as _parse_conditions
 
-_json_load = Parsers.get_for('json')
+
+_json_load = Parsers.get_for("json")
 
 
 class JSONQueryPipe(ModulePipe):
@@ -34,21 +36,16 @@ def set_accepted(self):
     async def pipe_request(self, next_pipe, **kwargs):
         if request.query_params[self.query_param] and self._accepted_set:
             try:
-                input_condition = self._parse_where_param(
-                    request.query_params[self.query_param]
-                )
+                input_condition = self._parse_where_param(request.query_params[self.query_param])
             except ValueError:
                 response.status = 400
-                return self.mod.error_400({self.query_param: 'invalid value'})
+                return self.mod.error_400({self.query_param: "invalid value"})
             try:
-                dbset = _parse_conditions(
-                    self.mod.model, kwargs['dbset'],
-                    input_condition, self._accepted_set
-                )
+                dbset = _parse_conditions(self.mod.model, kwargs["dbset"], input_condition, self._accepted_set)
             except QueryError as exc:
                 response.status = 400
                 return self.mod.error_400({self.query_param: exc.gen_msg()})
-            kwargs['dbset'] = dbset
+            kwargs["dbset"] = dbset
         return await next_pipe(**kwargs)
 
     @staticmethod
@@ -59,5 +56,5 @@ def _parse_where_param(param: str) -> Dict[str, Any]:
             param = _json_load(param)
             assert isinstance(param, dict)
         except Exception:
-            raise ValueError('Invalid param')
+            raise ValueError("Invalid param")
         return param
diff --git a/emmett_rest/queries/parser.py b/emmett_rest/queries/parser.py
index e790885..8dac687 100644
--- a/emmett_rest/queries/parser.py
+++ b/emmett_rest/queries/parser.py
@@ -1,22 +1,22 @@
 # -*- coding: utf-8 -*-
 """
-    emmett_rest.queries.parser
-    --------------------------
+emmett_rest.queries.parser
+--------------------------
 
-    Provides REST query language parser
+Provides REST query language parser
 
-    :copyright: 2017 Giovanni Barillari
-    :license: BSD-3-Clause
+:copyright: 2017 Giovanni Barillari
+:license: BSD-3-Clause
 """
 
 from __future__ import annotations
 
 import operator
+from functools import reduce
+from typing import Any, Callable, Dict, Optional, Set, Union
 
 from emmett import sdict
 from emmett.orm.objects import Expression, Query, Set as DBSet
-from functools import reduce
-from typing import Any, Callable, Dict, Optional, Set, Union
 
 from ..typing import ModelType
 from .errors import QueryError
@@ -24,60 +24,32 @@
 
 
 _query_operators = {
-    '$and': operator.and_,
-    '$or': operator.or_,
-    '$not': lambda field, value: operator.inv(value),
-    '$eq': operator.eq,
-    '$ne': operator.ne,
-    '$lt': operator.lt,
-    '$gt': operator.gt,
-    '$le': operator.le,
-    '$ge': operator.ge,
-    '$lte': operator.le,
-    '$gte': operator.ge,
-    '$in': lambda field, value: operator.methodcaller('belongs', value)(field),
-    '$exists': lambda field, value: (
-        operator.ne(field, None) if value else operator.eq(field, None)
-    ),
-    '$contains': lambda field, value: (
-        operator.methodcaller('contains', value, case_sensitive=True)(field)
-    ),
-    '$icontains': lambda field, value: (
-        operator.methodcaller('contains', value, case_sensitive=False)(field)
-    ),
-    '$like': lambda field, value: (
-        operator.methodcaller('like', value, case_sensitive=True)(field)
-    ),
-    '$ilike': lambda field, value: (
-        operator.methodcaller('like', value, case_sensitive=False)(field)
-    ),
-    '$regex': lambda field, value: (
-        operator.methodcaller('contains', value, case_sensitive=True)(field)
-    ),
-    '$iregex': lambda field, value: (
-        operator.methodcaller('contains', value, case_sensitive=False)(field)
-    ),
-    '$geo.contains': lambda field, value: (
-        operator.methodcaller('st_contains', value)(field)
-    ),
-    '$geo.equals': lambda field, value: (
-        operator.methodcaller('st_equals', value)(field)
-    ),
-    '$geo.intersects': lambda field, value: (
-        operator.methodcaller('st_intersects', value)(field)
-    ),
-    '$geo.overlaps': lambda field, value: (
-        operator.methodcaller('st_overlaps', value)(field)
-    ),
-    '$geo.touches': lambda field, value: (
-        operator.methodcaller('st_touches', value)(field)
-    ),
-    '$geo.within': lambda field, value: (
-        operator.methodcaller('st_within', value)(field)
-    ),
-    '$geo.dwithin': lambda field, value: (
-        operator.methodcaller('st_dwithin', value[0], value[1])(field)
-    )
+    "$and": operator.and_,
+    "$or": operator.or_,
+    "$not": lambda field, value: operator.inv(value),
+    "$eq": operator.eq,
+    "$ne": operator.ne,
+    "$lt": operator.lt,
+    "$gt": operator.gt,
+    "$le": operator.le,
+    "$ge": operator.ge,
+    "$lte": operator.le,
+    "$gte": operator.ge,
+    "$in": lambda field, value: operator.methodcaller("belongs", value)(field),
+    "$exists": lambda field, value: (operator.ne(field, None) if value else operator.eq(field, None)),
+    "$contains": lambda field, value: (operator.methodcaller("contains", value, case_sensitive=True)(field)),
+    "$icontains": lambda field, value: (operator.methodcaller("contains", value, case_sensitive=False)(field)),
+    "$like": lambda field, value: (operator.methodcaller("like", value, case_sensitive=True)(field)),
+    "$ilike": lambda field, value: (operator.methodcaller("like", value, case_sensitive=False)(field)),
+    "$regex": lambda field, value: (operator.methodcaller("contains", value, case_sensitive=True)(field)),
+    "$iregex": lambda field, value: (operator.methodcaller("contains", value, case_sensitive=False)(field)),
+    "$geo.contains": lambda field, value: (operator.methodcaller("st_contains", value)(field)),
+    "$geo.equals": lambda field, value: (operator.methodcaller("st_equals", value)(field)),
+    "$geo.intersects": lambda field, value: (operator.methodcaller("st_intersects", value)(field)),
+    "$geo.overlaps": lambda field, value: (operator.methodcaller("st_overlaps", value)(field)),
+    "$geo.touches": lambda field, value: (operator.methodcaller("st_touches", value)(field)),
+    "$geo.within": lambda field, value: (operator.methodcaller("st_within", value)(field)),
+    "$geo.dwithin": lambda field, value: (operator.methodcaller("st_dwithin", value[0], value[1])(field)),
 }
 
@@ -86,13 +58,13 @@ def _glue_op_parser(key: str, value: Any, ctx: sdict) -> Expression:
         raise QueryError(op=key, value=value)
     op = _query_operators[key]
     return reduce(
-        lambda a, b: op(a, b) if a and b else None, map(
+        lambda a, b: op(a, b) if a and b else None,
+        map(  # noqa: C417
             lambda v: _conditions_parser(
-                ctx.op_set, ctx.op_validators, ctx.op_parsers,
-                ctx.model, v, ctx.accepted_set,
-                parent=key
-            ), value
-        )
+                ctx.op_set, ctx.op_validators, ctx.op_parsers, ctx.model, v, ctx.accepted_set, parent=key
+            ),
+            value,
+        ),
     )
 
@@ -101,9 +73,7 @@ def _dict_op_parser(key: str, value: Any, ctx: sdict) -> Expression:
         raise QueryError(op=key, value=value)
     op = _query_operators[key]
     inner = _conditions_parser(
-        ctx.op_set, ctx.op_validators, ctx.op_parsers,
-        ctx.model, value, ctx.accepted_set,
-        parent=key
+        ctx.op_set, ctx.op_validators, ctx.op_parsers, ctx.model, value, ctx.accepted_set, parent=key
     )
     return op(None, inner)
 
@@ -120,11 +90,7 @@ def _generic_op_parser(key: str, value: Any, ctx: sdict) -> Expression:
 
 
 op_parsers = {key: _generic_op_parser for key in op_validators.keys()}
-op_parsers.update({
-    '$or': _glue_op_parser,
-    '$and': _glue_op_parser,
-    '$not': _dict_op_parser
-})
+op_parsers.update({"$or": _glue_op_parser, "$and": _glue_op_parser, "$not": _dict_op_parser})
 
 
 def _conditions_parser(
@@ -134,66 +100,49 @@ def _conditions_parser(
     model: ModelType,
     query_dict: Dict[str, Any],
     accepted_set: Set[str],
-    parent: Optional[str] = None
+    parent: Optional[str] = None,
 ) -> Union[Query, None]:
-    query, ctx = None, sdict(
-        op_set=op_set,
-        op_validators=op_validators,
-        op_parsers=op_parsers,
-        model=model,
-        accepted_set=accepted_set,
-        parent=parent
+    query, ctx = (
+        None,
+        sdict(
+            op_set=op_set,
+            op_validators=op_validators,
+            op_parsers=op_parsers,
+            model=model,
+            accepted_set=accepted_set,
+            parent=parent,
+        ),
     )
     query_key_set = set(query_dict.keys())
     step_conditions, inner_conditions = [], []
     for key in query_key_set & op_set:
         step_conditions.append(op_parsers[key](key, query_dict[key], ctx))
     if step_conditions:
-        step_query = reduce(
-            lambda a, b: operator.and_(a, b) if a and b else None,
-            step_conditions
-        )
+        step_query = reduce(lambda a, b: operator.and_(a, b) if a and b else None, step_conditions)
         query = query & step_query if query else step_query
     for key in accepted_set & query_key_set:
         value = query_dict[key]
         if not isinstance(value, dict):
-            value = {'$eq': value}
+            value = {"$eq": value}
         if not value:
             continue
         inner_conditions.append(
-            _conditions_parser(
op_set, op_validators, op_parsers, - model, - value, accepted_set, parent=key - ) + _conditions_parser(op_set, op_validators, op_parsers, model, value, accepted_set, parent=key) ) if inner_conditions: - inner_query = reduce( - lambda a, b: operator.and_(a, b) if a and b else None, - inner_conditions - ) + inner_query = reduce(lambda a, b: operator.and_(a, b) if a and b else None, inner_conditions) query = query & inner_query if query else inner_query return query def _build_scoped_conditions_parser( - op_validators: Dict[str, Callable[[Any], Any]], - op_parsers: Dict[str, Callable[[str, Any, sdict], Any]] + op_validators: Dict[str, Callable[[Any], Any]], op_parsers: Dict[str, Callable[[str, Any, sdict], Any]] ) -> Callable[[ModelType, DBSet, Dict[str, Any], Set[str]], DBSet]: op_set = set(op_validators.keys()) - def scoped( - model: ModelType, - dbset: DBSet, - query_dict: Dict[str, Any], - accepted_set: Set[str] - ) -> DBSet: - return dbset.where( - _conditions_parser( - op_set, op_validators, op_parsers, - model, query_dict, accepted_set - ) - ) + def scoped(model: ModelType, dbset: DBSet, query_dict: Dict[str, Any], accepted_set: Set[str]) -> DBSet: + return dbset.where(_conditions_parser(op_set, op_validators, op_parsers, model, query_dict, accepted_set)) + return scoped diff --git a/emmett_rest/queries/validation.py b/emmett_rest/queries/validation.py index 044efa3..cf46c86 100644 --- a/emmett_rest/queries/validation.py +++ b/emmett_rest/queries/validation.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett_rest.queries.validation - ------------------------------ +emmett_rest.queries.validation +------------------------------ - Provides REST query language validation +Provides REST query language validation - :copyright: 2017 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2017 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -16,12 +16,8 @@ from emmett.orm import geo -_geo_helpers = { - 'POINT': geo.Point, - 'LINE': geo.Line, - 'LINESTRING': geo.Line, - 'POLYGON': geo.Polygon -} + +_geo_helpers = {"POINT": geo.Point, "LINE": geo.Line, "LINESTRING": geo.Line, "POLYGON": geo.Polygon} validate_default = lambda v: v @@ -38,6 +34,7 @@ def op_validation_generator(*types) -> Callable[[Any], Any]: def op_validator(v: Any) -> Any: assert isinstance(v, types) return v + return op_validator @@ -49,49 +46,49 @@ def validate_glue(v: Any) -> List[Dict[str, Any]]: def validate_geo(v: Any) -> Any: - assert isinstance(v, dict) and set(v.keys()) == {'type', 'coordinates'} - objkey = v['type'] + assert isinstance(v, dict) and set(v.keys()) == {"type", "coordinates"} + objkey = v["type"] geohelper = _geo_helpers.get(objkey.upper()) - assert geohelper and isinstance(v['coordinates'], list) + assert geohelper and isinstance(v["coordinates"], list) try: - return geohelper(*_tuplify_list(v['coordinates'])) + return geohelper(*_tuplify_list(v["coordinates"])) except Exception: raise AssertionError def validate_geo_dwithin(v: Any) -> Any: - assert isinstance(v, dict) and set(v.keys()) == {'geometry', 'distance'} - assert v['distance'] - obj = validate_geo(v['geometry']) - return (obj, v['distance']) + assert isinstance(v, dict) and set(v.keys()) == {"geometry", "distance"} + assert v["distance"] + obj = validate_geo(v["geometry"]) + return (obj, v["distance"]) op_validators = { - '$and': validate_glue, - '$or': validate_glue, - '$eq': validate_default, - '$not': op_validation_generator(dict), - '$ne': validate_default, - '$in': op_validation_generator(list), 
- '$nin': op_validation_generator(list), - '$lt': op_validation_generator(int, float, datetime), - '$gt': op_validation_generator(int, float, datetime), - '$le': op_validation_generator(int, float, datetime), - '$ge': op_validation_generator(int, float, datetime), - '$lte': op_validation_generator(int, float, datetime), - '$gte': op_validation_generator(int, float, datetime), - '$exists': op_validation_generator(bool), - '$like': validate_default, - '$ilike': validate_default, - '$contains': validate_default, - '$icontains': validate_default, - '$regex': validate_default, - '$iregex': validate_default, - '$geo.contains': validate_geo, - '$geo.equals': validate_geo, - '$geo.intersects': validate_geo, - '$geo.overlaps': validate_geo, - '$geo.touches': validate_geo, - '$geo.within': validate_geo, - '$geo.dwithin': validate_geo_dwithin + "$and": validate_glue, + "$or": validate_glue, + "$eq": validate_default, + "$not": op_validation_generator(dict), + "$ne": validate_default, + "$in": op_validation_generator(list), + "$nin": op_validation_generator(list), + "$lt": op_validation_generator(int, float, datetime), + "$gt": op_validation_generator(int, float, datetime), + "$le": op_validation_generator(int, float, datetime), + "$ge": op_validation_generator(int, float, datetime), + "$lte": op_validation_generator(int, float, datetime), + "$gte": op_validation_generator(int, float, datetime), + "$exists": op_validation_generator(bool), + "$like": validate_default, + "$ilike": validate_default, + "$contains": validate_default, + "$icontains": validate_default, + "$regex": validate_default, + "$iregex": validate_default, + "$geo.contains": validate_geo, + "$geo.equals": validate_geo, + "$geo.intersects": validate_geo, + "$geo.overlaps": validate_geo, + "$geo.touches": validate_geo, + "$geo.within": validate_geo, + "$geo.dwithin": validate_geo_dwithin, } diff --git a/emmett_rest/rest.py b/emmett_rest/rest.py index 40dcf0d..d9f6566 100644 --- a/emmett_rest/rest.py +++ b/emmett_rest/rest.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- """ - emmett_rest.rest - ---------------- +emmett_rest.rest +---------------- - Provides main REST logics +Provides main REST logics - :copyright: 2017 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2017 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations import operator - from functools import reduce from typing import Any, Callable, Dict, List, Optional, Union @@ -24,28 +23,16 @@ from emmett.routing.router import RoutingCtxGroup from emmett.tools.service import JSONServicePipe +from .helpers import FieldPipe, FieldsPipe, RecordFetcher, RESTRoutingCtx, SetFetcher from .openapi.api import ModuleOpenAPI -from .helpers import RecordFetcher, SetFetcher, FieldPipe, FieldsPipe, RESTRoutingCtx -from .parsers import ( - parse_params as _parse_params, - parse_params_with_parser as _parse_params_wparser -) +from .parsers import parse_params as _parse_params, parse_params_with_parser as _parse_params_wparser from .queries import JSONQueryPipe from .serializers import serialize as _serialize from .typing import ModelType, ParserType, SerializerType class RESTModule(AppModule): - _all_methods = { - 'index', - 'create', - 'read', - 'update', - 'delete', - 'group', - 'stats', - 'sample' - } + _all_methods = {"index", "create", "read", "update", "delete", "group", "stats", "sample"} @classmethod def from_app( @@ -71,7 +58,7 @@ def from_app( id_path: Optional[str] = None, url_prefix: Optional[str] = None, hostname: Optional[str] = None, - opts: 
Dict[str, Any] = {} + opts: Dict[str, Any] = {}, ) -> RESTModule: return cls( ext, @@ -95,7 +82,7 @@ def from_app( id_path=id_path, url_prefix=url_prefix, hostname=hostname, - **opts + **opts, ) @classmethod @@ -123,17 +110,14 @@ def from_module( id_path: Optional[str] = None, url_prefix: Optional[str] = None, hostname: Optional[str] = None, - opts: Dict[str, Any] = {} + opts: Dict[str, Any] = {}, ) -> RESTModule: - if '.' in name: - raise RuntimeError( - "Nested app modules' names should not contains dots" - ) - name = mod.name + '.' + name - if url_prefix and not url_prefix.startswith('/'): - url_prefix = '/' + url_prefix - module_url_prefix = (mod.url_prefix + (url_prefix or '')) \ - if mod.url_prefix else url_prefix + if "." in name: + raise RuntimeError("Nested app modules' names should not contains dots") + name = mod.name + "." + name + if url_prefix and not url_prefix.startswith("/"): + url_prefix = "/" + url_prefix + module_url_prefix = (mod.url_prefix + (url_prefix or "")) if mod.url_prefix else url_prefix hostname = hostname or mod.hostname return cls( ext, @@ -158,7 +142,7 @@ def from_module( url_prefix=module_url_prefix, hostname=hostname, pipeline=mod.pipeline, - **opts + **opts, ) def __init__( @@ -185,7 +169,7 @@ def __init__( url_prefix: Optional[str] = None, hostname: Optional[str] = None, pipeline: List[Pipe] = [], - **kwargs: Any + **kwargs: Any, ): if len(model._instance_()._fieldset_pk) > 1: raise RuntimeError("Emmett-REST doesn't support multiple PKs models") @@ -206,9 +190,7 @@ def __init__( #: service pipe injection add_service_pipe = True super_pipeline = list(pipeline) - if any( - isinstance(pipe, JSONServicePipe) for pipe in ext.app.pipeline - ) or any( + if any(isinstance(pipe, JSONServicePipe) for pipe in ext.app.pipeline) or any( isinstance(pipe, JSONServicePipe) for pipe in super_pipeline ): add_service_pipe = False @@ -216,104 +198,53 @@ def __init__( super_pipeline.insert(0, JSONServicePipe()) #: initialize super().__init__( - ext.app, - name, - import_name, - url_prefix=url_prefix, - hostname=hostname, - pipeline=super_pipeline, - **kwargs + ext.app, name, import_name, url_prefix=url_prefix, hostname=hostname, pipeline=super_pipeline, **kwargs ) self.ext = ext self._pagination = sdict() - for key in ( - 'page_param', 'pagesize_param', - 'min_pagesize', 'max_pagesize', 'default_pagesize' - ): + for key in ("page_param", "pagesize_param", "min_pagesize", "max_pagesize", "default_pagesize"): self._pagination[key] = self.ext.config[key] self._sort_param = self.ext.config.sort_param - self.default_sort = ( - default_sort or - self.ext.config.default_sort or - model.table._id.name - ) + self.default_sort = default_sort or self.ext.config.default_sort or model.table._id.name self._path_base = base_path or self.ext.config.base_path self._path_rid = id_path or self.ext.config.id_path - self._serializer_class = serializer or \ - self.ext.config.default_serializer + self._serializer_class = serializer or self.ext.config.default_serializer self._parser_class = parser or self.ext.config.default_parser self._parsing_params_kwargs = {} self.model = model - self.use_save = ( - use_save if use_save is not None else - self.ext.config.use_save - ) - self.use_destroy = ( - use_destroy if use_destroy is not None else - self.ext.config.use_destroy - ) + self.use_save = use_save if use_save is not None else self.ext.config.use_save + self.use_destroy = use_destroy if use_destroy is not None else self.ext.config.use_destroy self.serializer = self._serializer_class(self.model) 
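A short registration sketch showing how the constructor arguments above surface through `app.rest_module`; the `Task` model and app wiring are hypothetical, and any argument left out falls back to the extension config, as the code above shows:

```python
# Hypothetical registration sketch: how __init__ arguments map to module behavior.
from emmett import App
from emmett.orm import Database, Field, Model
from emmett_rest import REST


class Task(Model):
    title = Field()
    done = Field.bool(default=False)


app = App(__name__)
app.config.db.uri = "sqlite:memory"
db = Database(app, auto_migrate=True)
db.define_models(Task)
app.use_extension(REST)

tasks = app.rest_module(
    __name__, "tasks", Task,
    url_prefix="tasks",
    enabled_methods=["index", "read", "create"],  # subset of the config defaults
    default_sort="title",  # else the config default, else the model's PK
    use_save=True,  # route create/update through row save()
)
```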
self.parser = self._parser_class(self.model) - self.enabled_methods = list(self._all_methods & set( - list( - enabled_methods if enabled_methods is not None else - self.ext.config.default_enabled_methods - ) - )) - self.disabled_methods = list(self._all_methods & set( - list( - disabled_methods if disabled_methods is not None else - self.ext.config.default_disabled_methods - ) - )) - self.list_envelope = list_envelope or self.ext.config.list_envelope - self.single_envelope = ( - single_envelope if single_envelope is not None else - self.ext.config.single_envelope + self.enabled_methods = list( + self._all_methods + & set(enabled_methods if enabled_methods is not None else self.ext.config.default_enabled_methods) ) - self.meta_envelope = ( - meta_envelope if meta_envelope is not None else - self.ext.config.meta_envelope - ) - self.groups_envelope = ( - groups_envelope if groups_envelope is not None else - self.ext.config.groups_envelope + self.disabled_methods = list( + self._all_methods + & set(disabled_methods if disabled_methods is not None else self.ext.config.default_disabled_methods) ) + self.list_envelope = list_envelope or self.ext.config.list_envelope + self.single_envelope = single_envelope if single_envelope is not None else self.ext.config.single_envelope + self.meta_envelope = meta_envelope if meta_envelope is not None else self.ext.config.meta_envelope + self.groups_envelope = groups_envelope if groups_envelope is not None else self.ext.config.groups_envelope self.use_envelope_on_parse = ( - use_envelope_on_parse if use_envelope_on_parse is not None else - self.ext.config.use_envelope_on_parse - ) - self.serialize_meta = ( - serialize_meta if serialize_meta is not None else - self.ext.config.serialize_meta + use_envelope_on_parse if use_envelope_on_parse is not None else self.ext.config.use_envelope_on_parse ) + self.serialize_meta = serialize_meta if serialize_meta is not None else self.ext.config.serialize_meta self._queryable_fields = [] self._sortable_fields = [] self._sortable_dict = {} self._groupable_fields = [] self._statsable_fields = [] self._json_query_pipe = JSONQueryPipe(self) - self._group_field_pipe = FieldPipe(self, '_groupable_fields') - self._stats_field_pipe = FieldsPipe(self, '_statsable_fields') + self._group_field_pipe = FieldPipe(self, "_groupable_fields") + self._stats_field_pipe = FieldsPipe(self, "_statsable_fields") self.allowed_sorts = [self.default_sort] self._openapi_specs = { - 'serializers': { - key: self.serializer for key in [ - 'index', - 'create', - 'read', - 'update', - 'delete', - 'sample' - ] - }, - 'parsers': { - key: self.parser for key in [ - 'create', - 'update' - ] - }, - 'additional_routes': [] + "serializers": {key: self.serializer for key in ["index", "create", "read", "update", "delete", "sample"]}, + "parsers": {key: self.parser for key in ["create", "update"]}, + "additional_routes": [], } self.openapi = ModuleOpenAPI(self) self._init_pipelines() @@ -328,39 +259,27 @@ def _init_pipelines(self): self.read_pipeline = [SetFetcher(self), RecordFetcher(self)] self.update_pipeline = [SetFetcher(self)] self.delete_pipeline = [SetFetcher(self)] - self.group_pipeline = [ - self._group_field_pipe, - SetFetcher(self), - self._json_query_pipe - ] - self.stats_pipeline = [ - self._stats_field_pipe, - SetFetcher(self), - self._json_query_pipe - ] + self.group_pipeline = [self._group_field_pipe, SetFetcher(self), self._json_query_pipe] + self.stats_pipeline = [self._stats_field_pipe, SetFetcher(self), self._json_query_pipe] 
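The per-method pipelines initialized above are the ones the route decorators later extend; a sketch of appending a custom pipe to a single endpoint follows. The pipe and handler are hypothetical, `tasks` is the module from the previous sketch, and `dbset` is the keyword injected by `SetFetcher`:

```python
# Sketch: extending the default index pipeline with a custom pipe.
from emmett import Pipe


class AuditPipe(Pipe):
    async def pipe_request(self, next_pipe, **kwargs):
        # inspect or narrow kwargs["dbset"] before the endpoint runs
        return await next_pipe(**kwargs)


@tasks.index(pipeline=[AuditPipe()])
async def task_index(dbset):
    pagination = tasks.get_pagination()
    rows = dbset.select(paginate=pagination, orderby=tasks.get_sort())
    return tasks.serialize_many(rows, dbset, pagination)
```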
self.sample_pipeline = [SetFetcher(self), self._json_query_pipe] def init(self): pass def _after_initialize(self): - self.list_envelope = self.list_envelope or 'data' + self.list_envelope = self.list_envelope or "data" #: adjust single row serialization based on envelope self.serialize_many = ( - self.serialize_with_list_envelope_and_meta if self.serialize_meta - else self.serialize_with_list_envelope + self.serialize_with_list_envelope_and_meta if self.serialize_meta else self.serialize_with_list_envelope ) if self.single_envelope: self.serialize_one = self.serialize_with_single_envelope if self.use_envelope_on_parse: self.parser.envelope = self.single_envelope - self._parsing_params_kwargs = {'envelope': self.single_envelope} + self._parsing_params_kwargs = {"envelope": self.single_envelope} else: self.serialize_one = self.serialize - self.pack_data = ( - self.pack_with_list_envelope_and_meta if self.serialize_meta - else self.pack_with_list_envelope - ) + self.pack_data = self.pack_with_list_envelope_and_meta if self.serialize_meta else self.pack_with_list_envelope #: adjust enabled methods for method_name in self.disabled_methods: self.enabled_methods.remove(method_name) @@ -368,42 +287,27 @@ def _after_initialize(self): self._expose_routes() def route( - self, - paths: Optional[Union[str, List[str]]] = None, - name: Optional[str] = None, - **kwargs + self, paths: Optional[Union[str, List[str]]] = None, name: Optional[str] = None, **kwargs ) -> RESTRoutingCtx: rv = super().route(paths, name, **kwargs) return RESTRoutingCtx(self, rv) def _expose_routes(self): - path_base_trail = ( - self._path_base.endswith('/') and self._path_base or - f'{self._path_base}/' - ) + path_base_trail = self._path_base.endswith("/") and self._path_base or f"{self._path_base}/" self._methods_map = { - 'index': (self._path_base, 'get'), - 'read': (self._path_rid, 'get'), - 'create': (self._path_base, 'post'), - 'update': (self._path_rid, ['put', 'patch']), - 'delete': (self._path_rid, 'delete'), - 'group': (f'{path_base_trail}group/', 'get'), - 'stats': (f'{path_base_trail}stats', 'get'), - 'sample': (f'{path_base_trail}sample', 'get') + "index": (self._path_base, "get"), + "read": (self._path_rid, "get"), + "create": (self._path_base, "post"), + "update": (self._path_rid, ["put", "patch"]), + "delete": (self._path_rid, "delete"), + "group": (f"{path_base_trail}group/", "get"), + "stats": (f"{path_base_trail}stats", "get"), + "sample": (f"{path_base_trail}sample", "get"), } self._functions_map = { - **{ - key: f'_{key}' - for key in self._all_methods - {'create', 'update', 'delete'} - }, - **{ - key: f'_{key}_without_save' if not self.use_save else f'_{key}' - for key in {'create', 'update'} - }, - **{ - key: f'_{key}_without_destroy' if not self.use_destroy else f'_{key}' - for key in {'delete'} - } + **{key: f"_{key}" for key in self._all_methods - {"create", "update", "delete"}}, + **{key: f"_{key}_without_save" if not self.use_save else f"_{key}" for key in {"create", "update"}}, + **{key: f"_{key}_without_destroy" if not self.use_destroy else f"_{key}" for key in {"delete"}}, } for key in self.enabled_methods: path, methods = self._methods_map[key] @@ -424,62 +328,49 @@ def get_pagination(self): except Exception: page = 1 try: - page_size = int( - request.query_params[self._pagination.pagesize_param] or 20) - assert ( - self._pagination.min_pagesize <= page_size <= - self._pagination.max_pagesize) + page_size = int(request.query_params[self._pagination.pagesize_param] or 20) + assert 
self._pagination.min_pagesize <= page_size <= self._pagination.max_pagesize except Exception: page_size = self._pagination.default_pagesize return page, page_size def get_sort(self, default=None, allowed_fields=None): default = default or self.default_sort - pfields = ( - ( - isinstance(request.query_params.sort_by, str) and - request.query_params.sort_by - ) or default - ).split(',') + pfields = ((isinstance(request.query_params.sort_by, str) and request.query_params.sort_by) or default).split( + "," + ) rv = [] allowed_fields = allowed_fields or self._sortable_dict for pfield in pfields: asc = True - if pfield.startswith('-'): + if pfield.startswith("-"): pfield = pfield[1:] asc = False field = allowed_fields.get(pfield) if not field: continue rv.append(field if asc else ~field) - return reduce( - lambda a, b: operator.or_(a, b) if a and b else None, - rv - ) if rv else allowed_fields.get(default) + return reduce(lambda a, b: operator.or_(a, b) if a and b else None, rv) if rv else allowed_fields.get(default) def build_error_400(self, errors=None): if errors: - return {'errors': errors} - return {'errors': {'request': 'bad request'}} + return {"errors": errors} + return {"errors": {"request": "bad request"}} def build_error_404(self): - return {'errors': {'id': 'record not found'}} + return {"errors": {"id": "record not found"}} def build_error_422(self, errors=None, to_dict=True): if errors: if to_dict and hasattr(errors, "as_dict"): errors = errors.as_dict() - return {'errors': errors} - return {'errors': {'request': 'unprocessable entity'}} + return {"errors": errors} + return {"errors": {"request": "unprocessable entity"}} def _build_meta(self, dbset, pagination, **kwargs): - count = kwargs.get('count', dbset.count()) + count = kwargs.get("count", dbset.count()) page, page_size = pagination - return { - 'object': 'list', - 'has_more': count > (page * page_size), - 'total_objects': count - } + return {"object": "list", "has_more": count > (page * page_size), "total_objects": count} def serialize(self, data, **extras): return _serialize(data, self.serializer, **extras) @@ -487,13 +378,11 @@ def serialize(self, data, **extras): def serialize_with_list_envelope(self, data, dbset, pagination, **extras): return {self.list_envelope: self.serialize(data, **extras)} - def serialize_with_list_envelope_and_meta( - self, data, dbset, pagination, **extras - ): - mextras = extras.pop('meta_extras', {}) + def serialize_with_list_envelope_and_meta(self, data, dbset, pagination, **extras): + mextras = extras.pop("meta_extras", {}) return { self.list_envelope: self.serialize(data, **extras), - self.meta_envelope: self.build_meta(dbset, pagination, **mextras) + self.meta_envelope: self.build_meta(dbset, pagination, **mextras), } def serialize_with_single_envelope(self, data, **extras): @@ -506,11 +395,8 @@ def pack_with_list_envelope_and_meta(self, envelope, data, **extras): count = len(data) return { envelope: data, - self.meta_envelope: self.build_meta( - sdict(count=lambda c=count: c), - (1, count) - ), - **extras + self.meta_envelope: self.build_meta(sdict(count=lambda c=count: c), (1, count)), + **extras, } async def parse_params(self, *params): @@ -611,43 +497,27 @@ async def _delete_without_destroy(self, dbset, rid): #: additional routes async def _group(self, dbset, field): count_field = self.model.table._id.count() - sort = self.get_sort( - default='count', - allowed_fields={'count': count_field} - ) + sort = self.get_sort(default="count", allowed_fields={"count": count_field}) data = [ - { - 
'value': row[self.model.table][field.name], - 'count': row[count_field] - } for row in dbset.select( - field, count_field, groupby=field, orderby=sort - ) + {"value": row[self.model.table][field.name], "count": row[count_field]} + for row in dbset.select(field, count_field, groupby=field, orderby=sort) ] return self.pack_data(self.groups_envelope, data) async def _stats(self, dbset, fields): grouped_fields, select_fields, rv = {}, [], {} for field in fields: - grouped_fields[field.name] = { - 'min': field.min(), - 'max': field.max(), - 'avg': field.avg() - } + grouped_fields[field.name] = {"min": field.min(), "max": field.max(), "avg": field.avg()} select_fields.extend(grouped_fields[field.name].values()) row = dbset.select(*select_fields).first() for key, attrs in grouped_fields.items(): - rv[key] = { - attr_key: row[field] for attr_key, field in attrs.items() - } + rv[key] = {attr_key: row[field] for attr_key, field in attrs.items()} return rv async def _sample(self, dbset): _, page_size = self.get_pagination() - rows = dbset.select(paginate=(1, page_size), orderby='') - return self.serialize_many( - rows, dbset, (1, page_size), - meta_extras={'count': len(rows)} - ) + rows = dbset.select(paginate=(1, page_size), orderby="") + return self.serialize_many(rows, dbset, (1, page_size), meta_extras={"count": len(rows)}) #: properties @property @@ -657,9 +527,7 @@ def allowed_sorts(self) -> List[str]: @allowed_sorts.setter def allowed_sorts(self, val: List[str]): self._sortable_fields = val - self._sortable_dict = { - field: self.model.table[field] for field in self._sortable_fields - } + self._sortable_dict = {field: self.model.table[field] for field in self._sortable_fields} @property def query_allowed_fields(self) -> List[str]: @@ -689,38 +557,23 @@ def stats_allowed_fields(self, val: List[str]): self._stats_field_pipe.set_accepted() #: decorators - def get_dbset( - self, - f: Callable[[RESTModule], DBSet] - ) -> Callable[[RESTModule], DBSet]: + def get_dbset(self, f: Callable[[RESTModule], DBSet]) -> Callable[[RESTModule], DBSet]: self._fetcher_method = f return f - def get_row( - self, - f: Callable[[DBSet], Optional[Row]] - ) -> Callable[[DBSet], Optional[Row]]: + def get_row(self, f: Callable[[DBSet], Optional[Row]]) -> Callable[[DBSet], Optional[Row]]: self._select_method = f return f - def before_create( - self, - f: Callable[[sdict], None] - ) -> Callable[[sdict], None]: + def before_create(self, f: Callable[[sdict], None]) -> Callable[[sdict], None]: self._before_create_callbacks.append(f) return f - def before_update( - self, - f: Callable[[int, sdict], None] - ) -> Callable[[int, sdict], None]: + def before_update(self, f: Callable[[int, sdict], None]) -> Callable[[int, sdict], None]: self._before_update_callbacks.append(f) return f - def after_parse_params( - self, - f: Callable[[sdict], None] - ) -> Callable[[sdict], None]: + def after_parse_params(self, f: Callable[[sdict], None]) -> Callable[[sdict], None]: self._after_params_callbacks.append(f) return f @@ -738,34 +591,23 @@ def after_delete(self, f: Callable[[int], None]) -> Callable[[int], None]: def index(self, pipeline=[]): pipeline = self.index_pipeline + pipeline - return self.route( - self._path_base, pipeline=pipeline, methods='get', name='index' - ) + return self.route(self._path_base, pipeline=pipeline, methods="get", name="index") def read(self, pipeline=[]): pipeline = self.read_pipeline + pipeline - return self.route( - self._path_rid, pipeline=pipeline, methods='get', name='read' - ) + return 
self.route(self._path_rid, pipeline=pipeline, methods="get", name="read") def create(self, pipeline=[]): pipeline = self.create_pipeline + pipeline - return self.route( - self._path_base, pipeline=pipeline, methods='post', name='create' - ) + return self.route(self._path_base, pipeline=pipeline, methods="post", name="create") def update(self, pipeline=[]): pipeline = self.update_pipeline + pipeline - return self.route( - self._path_rid, pipeline=pipeline, methods=['put', 'patch'], - name='update' - ) + return self.route(self._path_rid, pipeline=pipeline, methods=["put", "patch"], name="update") def delete(self, pipeline=[]): pipeline = self.delete_pipeline + pipeline - return self.route( - self._path_rid, pipeline=pipeline, methods='delete', name='delete' - ) + return self.route(self._path_rid, pipeline=pipeline, methods="delete", name="delete") def on_400(self, f): self.error_400 = f @@ -793,9 +635,7 @@ def allowed_sorts(self) -> List[str]: def allowed_sorts(self, val: List[str]): for module in self.modules: module._sortable_fields = val - module._sortable_dict = { - field: module.model.table[field] for field in module._sortable_fields - } + module._sortable_dict = {field: module.model.table[field] for field in module._sortable_fields} @property def query_allowed_fields(self) -> List[str]: @@ -827,42 +667,27 @@ def stats_allowed_fields(self, val: List[str]): module._statsable_fields = val module._stats_field_pipe.set_accepted() - def get_dbset( - self, - f: Callable[[RESTModule], DBSet] - ) -> Callable[[RESTModule], DBSet]: + def get_dbset(self, f: Callable[[RESTModule], DBSet]) -> Callable[[RESTModule], DBSet]: for module in self.modules: module._fetcher_method = f return f - def get_row( - self, - f: Callable[[DBSet], Optional[Row]] - ) -> Callable[[DBSet], Optional[Row]]: + def get_row(self, f: Callable[[DBSet], Optional[Row]]) -> Callable[[DBSet], Optional[Row]]: for module in self.modules: module._select_method = f return f - def before_create( - self, - f: Callable[[sdict], None] - ) -> Callable[[sdict], None]: + def before_create(self, f: Callable[[sdict], None]) -> Callable[[sdict], None]: for module in self.modules: module._before_create_callbacks.append(f) return f - def before_update( - self, - f: Callable[[int, sdict], None] - ) -> Callable[[int, sdict], None]: + def before_update(self, f: Callable[[int, sdict], None]) -> Callable[[int, sdict], None]: for module in self.modules: module._before_update_callbacks.append(f) return f - def after_parse_params( - self, - f: Callable[[sdict], None] - ) -> Callable[[sdict], None]: + def after_parse_params(self, f: Callable[[sdict], None]) -> Callable[[sdict], None]: for module in self.modules: module._after_params_callbacks.append(f) return f @@ -886,42 +711,21 @@ def index(self, pipeline=[]): routes = [] for module in self.modules: route_pipeline = module.index_pipeline + pipeline - routes.append( - module.route( - module._path_base, - pipeline=route_pipeline, - methods='get', - name='index' - ) - ) + routes.append(module.route(module._path_base, pipeline=route_pipeline, methods="get", name="index")) return RoutingCtxGroup(routes) def read(self, pipeline=[]): routes = [] for module in self.modules: route_pipeline = module.read_pipeline + pipeline - routes.append( - module.route( - module._path_rid, - pipeline=route_pipeline, - methods='get', - name='read' - ) - ) + routes.append(module.route(module._path_rid, pipeline=route_pipeline, methods="get", name="read")) return RoutingCtxGroup(routes) def create(self, pipeline=[]): routes = [] 
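Since the grouped variants above fan every decorator out to all member modules, shared behavior only needs declaring once. A sketch, assuming `tasks_group` is a `RESTModulesGrouped` returned by a module group's `rest_module` (for instance one module mounted under two API versions):

```python
# Sketch: one declaration applied to every module in the group.
@tasks_group.get_dbset
def scoped(module):
    # every member module fetches from the same scoped set
    return module.model.where(lambda m: m.done == False)  # noqa: E712


@tasks_group.before_create
def normalize(attrs):
    # runs before each member module persists a new record
    attrs.title = attrs.title.strip()
```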
for module in self.modules: route_pipeline = module.create_pipeline + pipeline - routes.append( - module.route( - module._path_base, - pipeline=route_pipeline, - methods='post', - name='create' - ) - ) + routes.append(module.route(module._path_base, pipeline=route_pipeline, methods="post", name="create")) return RoutingCtxGroup(routes) def update(self, pipeline=[]): @@ -929,12 +733,7 @@ def update(self, pipeline=[]): for module in self.modules: route_pipeline = module.update_pipeline + pipeline routes.append( - module.route( - module._path_rid, - pipeline=route_pipeline, - methods=['put', 'patch'], - name='update' - ) + module.route(module._path_rid, pipeline=route_pipeline, methods=["put", "patch"], name="update") ) return RoutingCtxGroup(routes) @@ -942,14 +741,7 @@ def delete(self, pipeline=[]): routes = [] for module in self.modules: route_pipeline = module.delete_pipeline + pipeline - routes.append( - module.route( - module._path_rid, - pipeline=route_pipeline, - methods='delete', - name='delete' - ) - ) + routes.append(module.route(module._path_rid, pipeline=route_pipeline, methods="delete", name="delete")) return RoutingCtxGroup(routes) def on_400(self, f): diff --git a/emmett_rest/serializers.py b/emmett_rest/serializers.py index d98106d..73e71f0 100644 --- a/emmett_rest/serializers.py +++ b/emmett_rest/serializers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett_rest.serializers - ----------------------- +emmett_rest.serializers +----------------------- - Provides REST serialization tools +Provides REST serialization tools - :copyright: 2017 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2017 Giovanni Barillari +:license: BSD-3-Clause """ from typing import List, Optional @@ -27,7 +27,7 @@ def __init__(self, model): readable_map = {} for fieldname in self._model.table.fields: readable_map[fieldname] = self._model.table[fieldname].readable - if hasattr(self._model, 'rest_rw'): + if hasattr(self._model, "rest_rw"): self.attributes = [] for key, value in self._model.rest_rw.items(): if isinstance(value, tuple): @@ -44,7 +44,7 @@ def __init__(self, model): self.attributes.remove(el) _attrs_override_ = [] for key in dir(self): - if not key.startswith('_') and callable(getattr(self, key)): + if not key.startswith("_") and callable(getattr(self, key)): _attrs_override_.append(key) self._attrs_override_ = _attrs_override_ self._init() diff --git a/emmett_rest/typing.py b/emmett_rest/typing.py index 8eba2d3..328ec11 100644 --- a/emmett_rest/typing.py +++ b/emmett_rest/typing.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett_rest.typing - ------------------ +emmett_rest.typing +------------------ - Provides typing helpers +Provides typing helpers - :copyright: 2017 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2017 Giovanni Barillari +:license: BSD-3-Clause """ from typing import Type diff --git a/emmett_rest/wrappers.py b/emmett_rest/wrappers.py index 8162638..95b90c9 100644 --- a/emmett_rest/wrappers.py +++ b/emmett_rest/wrappers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett_rest.wrappers - -------------------- +emmett_rest.wrappers +-------------------- - Provides wrappers for the REST extension +Provides wrappers for the REST extension - :copyright: 2017 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2017 Giovanni Barillari +:license: BSD-3-Clause """ from functools import wraps @@ -43,7 +43,7 @@ def rest_module_from_app( url_prefix: Optional[str] = None, hostname: Optional[str] = None, module_class: Optional[Type[RESTModule]] 
= None, - **kwargs: Any + **kwargs: Any, ) -> RESTModule: module_class = module_class or ext.config.default_module_class return module_class.from_app( @@ -68,8 +68,9 @@ def rest_module_from_app( id_path=id_path, url_prefix=url_prefix, hostname=hostname, - opts=kwargs + opts=kwargs, ) + return rest_module_from_app @@ -97,7 +98,7 @@ def rest_module_from_module( url_prefix: Optional[str] = None, hostname: Optional[str] = None, module_class: Optional[Type[RESTModule]] = None, - **kwargs: Any + **kwargs: Any, ) -> RESTModule: module_class = module_class or ext.config.default_module_class return module_class.from_module( @@ -123,8 +124,9 @@ def rest_module_from_module( id_path=id_path, url_prefix=url_prefix, hostname=hostname, - opts=kwargs + opts=kwargs, ) + return rest_module_from_module @@ -152,7 +154,7 @@ def rest_module_from_modulegroup( url_prefix: Optional[str] = None, hostname: Optional[str] = None, module_class: Optional[Type[RESTModule]] = None, - **kwargs: Any + **kwargs: Any, ) -> RESTModulesGrouped: module_class = module_class or ext.config.default_module_class mods = [] @@ -180,10 +182,11 @@ def rest_module_from_modulegroup( id_path=id_path, url_prefix=url_prefix, hostname=hostname, - opts=kwargs + opts=kwargs, ) mods.append(mod) return RESTModulesGrouped(*mods) + return rest_module_from_modulegroup @@ -191,4 +194,5 @@ def wrap_method_on_obj(method, obj): @wraps(method) def wrapped(*args, **kwargs): return method(obj, *args, **kwargs) + return wrapped diff --git a/pyproject.toml b/pyproject.toml index 17571c9..bcce829 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,16 +1,18 @@ -[project] -name = "emmett-rest" +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.poetry] +[project] name = "emmett-rest" -version = "1.5.2" +version = "1.6.0" description = "REST extension for Emmett framework" -authors = ["Giovanni Barillari "] +readme = "README.md" license = "BSD-3-Clause" +requires-python = ">=3.8" -readme = "README.md" -homepage = "https://github.com/emmett-framework/rest" -repository = "https://github.com/emmett-framework/rest" +authors = [ + { name = "Giovanni Barillari", email = "g@baro.dev" } +] keywords = ["rest", "web", "emmett"] classifiers = [ @@ -25,32 +27,78 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Internet :: WWW/HTTP :: Dynamic Content", "Topic :: Software Development :: Libraries :: Python Modules" ] -packages = [ - {include = "emmett_rest/**/*.*", format = "sdist"}, - {include = "tests", format = "sdist"} +dependencies = [ + "emmett~=2.6", + "pydantic~=1.9", ] + +[project.urls] +Homepage = 'https://github.com/emmett-framework/rest' +Funding = 'https://github.com/sponsors/gi0baro' +Source = 'https://github.com/emmett-framework/rest' +Issues = 'https://github.com/emmett-framework/rest/issues' + +[tool.hatch.build.targets.sdist] include = [ - "CHANGES.md", - "LICENSE" + '/README.md', + '/CHANGES.md', + '/LICENSE', + '/emmett_rest', + '/tests', ] -[tool.poetry.dependencies] -python = "^3.8" -emmett = "^2.5" -pydantic = "^1.9.0" +[tool.ruff] +line-length = 120 -[tool.poetry.dev-dependencies] -pytest = "^6.2" -pytest-asyncio = "^0.15" -psycopg2-binary = "~2.9.5" +[tool.ruff.format] +quote-style = 'double' -[tool.poetry.urls] -"Issue Tracker" = "https://github.com/emmett-framework/rest/issues" +[tool.ruff.lint] +extend-select = [ + # E and F 
are enabled by default + 'B', # flake8-bugbear + 'C4', # flake8-comprehensions + 'C90', # mccabe + 'I', # isort + 'N', # pep8-naming + 'Q', # flake8-quotes + 'RUF100', # ruff (unused noqa) + 'S', # flake8-bandit + 'W', # pycodestyle +] +extend-ignore = [ + 'B006', # mutable function args are fine + 'B904', # rising without from is fine + 'E731', # assigning lambdas is fine + 'N815', # leave to us var naming + 'S101', # assert is fine + 'S110', # pass on exceptions is fine +] +mccabe = { max-complexity = 44 } -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" +[tool.ruff.lint.isort] +combine-as-imports = true +lines-after-imports = 2 +known-first-party = ['emmett_rest', 'tests'] + +[tool.ruff.lint.per-file-ignores] +'emmett_rest/__init__.py' = ['F401'] +'emmett_rest/openapi/__init__.py' = ['F401'] +'emmett_rest/queries/__init__.py' = ['F401'] + +[tool.pytest.ini_options] +asyncio_mode = 'auto' + +[tool.uv] +dev-dependencies = [ + "ruff~=0.5.0", + "pytest>=7.1", + "pytest-asyncio>=0.15", + "psycopg2-binary~=2.9", +] diff --git a/tests/conftest.py b/tests/conftest.py index 700231c..f073a81 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,39 +1,40 @@ # -*- coding: utf-8 -*- import os -import pytest +import pytest from emmett import App, sdict from emmett.orm import Database from emmett.orm.migrations.utils import generate_runtime_migration from emmett.parsers import Parsers from emmett.serializers import Serializers + from emmett_rest import REST -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def json_dump(): - return Serializers.get_for('json') + return Serializers.get_for("json") -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def json_load(): - return Parsers.get_for('json') + return Parsers.get_for("json") -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def db_config(): config = sdict() - config.adapter = 'postgres:psycopg2' - config.host = os.environ.get('POSTGRES_HOST', 'localhost') - config.port = int(os.environ.get('POSTGRES_PORT', 5432)) - config.user = os.environ.get('POSTGRES_USER', 'postgres') - config.password = os.environ.get('POSTGRES_PASSWORD', 'postgres') - config.database = os.environ.get('POSTGRES_DB', 'test') + config.adapter = "postgres:psycopg2" + config.host = os.environ.get("POSTGRES_HOST", "localhost") + config.port = int(os.environ.get("POSTGRES_PORT", 5432)) + config.user = os.environ.get("POSTGRES_USER", "postgres") + config.password = os.environ.get("POSTGRES_PASSWORD", "postgres") + config.database = os.environ.get("POSTGRES_DB", "test") return config -@pytest.fixture(scope='session') +@pytest.fixture(scope="function") def app(db_config): rv = App(__name__) rv.config.db = db_config @@ -47,16 +48,17 @@ def _db_teardown_generator(db, migration): def teardown(): with db.connection(): migration.down() + return teardown -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def raw_db(request, app): rv = Database(app) return rv -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def migration_db(request, app): def generator(*models): rv = Database(app) @@ -66,4 +68,5 @@ def generator(*models): migration.up() request.addfinalizer(_db_teardown_generator(rv, migration)) return rv + return generator diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 4fdd14e..7c0db74 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- -import pytest - from datetime import 
datetime + +import pytest from emmett import sdict from emmett.orm import Field, Model @@ -14,158 +14,106 @@ class Sample(Model): datetime = Field.datetime() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(migration_db): return migration_db(Sample) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def rest_app(app, db): app.pipeline = [db.pipe] - app.rest_module( - __name__, 'sample', Sample, - url_prefix='sample' - ) - app.rest_module( - __name__, 'sample_row', Sample, - url_prefix='sample_row', use_save=True, use_destroy=True - ) + app.rest_module(__name__, "sample", Sample, url_prefix="sample") + app.rest_module(__name__, "sample_row", Sample, url_prefix="sample_row", use_save=True, use_destroy=True) return app -@pytest.fixture(scope='function', autouse=True) +@pytest.fixture(scope="function", autouse=True) def db_sample(db): with db.connection(): - Sample.create( - str='foo', - int=1, - float=3.14, - datetime=datetime(1955, 11, 12) - ) + Sample.create(str="foo", int=1, float=3.14, datetime=datetime(1955, 11, 12)) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def client(rest_app): return rest_app.test_client() def test_modules(rest_app): - mod1 = rest_app._modules['sample'] - mod2 = rest_app._modules['sample_row'] + mod1 = rest_app._modules["sample"] + mod2 = rest_app._modules["sample_row"] - assert mod1._functions_map['create'] == '_create_without_save' - assert mod1._functions_map['update'] == '_update_without_save' - assert mod1._functions_map['delete'] == '_delete_without_destroy' + assert mod1._functions_map["create"] == "_create_without_save" + assert mod1._functions_map["update"] == "_update_without_save" + assert mod1._functions_map["delete"] == "_delete_without_destroy" - assert mod2._functions_map['create'] == '_create' - assert mod2._functions_map['update'] == '_update' - assert mod2._functions_map['delete'] == '_delete' + assert mod2._functions_map["create"] == "_create" + assert mod2._functions_map["update"] == "_update" + assert mod2._functions_map["delete"] == "_delete" def test_index(client, json_load): - req = client.get('/sample') + req = client.get("/sample") assert req.status == 200 data = json_load(req.data) - assert {'data', 'meta'} == set(data.keys()) - assert {'id', 'str', 'int', 'float', 'datetime'} == set( - data['data'][0].keys()) - assert data['meta']['total_objects'] == 1 - assert not data['meta']['has_more'] + assert {"data", "meta"} == set(data.keys()) + assert {"id", "str", "int", "float", "datetime"} == set(data["data"][0].keys()) + assert data["meta"]["total_objects"] == 1 + assert not data["meta"]["has_more"] def test_get(client, json_load, db): with db.connection(): row = Sample.first() - req = client.get(f'/sample/{row.id}') + req = client.get(f"/sample/{row.id}") assert req.status == 200 data = json_load(req.data) - assert {'id', 'str', 'int', 'float', 'datetime'} == set( - data.keys()) + assert {"id", "str", "int", "float", "datetime"} == set(data.keys()) @pytest.mark.parametrize("base_path", ["/sample", "/sample_row"]) def test_create(client, json_load, json_dump, base_path): - body = sdict( - str='bar', - int=2, - float=1.1, - datetime=datetime(2000, 1, 1) - ) - req = client.post( - base_path, - data=json_dump(body), - headers=[('content-type', 'application/json')] - ) + body = sdict(str="bar", int=2, float=1.1, datetime=datetime(2000, 1, 1)) + req = client.post(base_path, data=json_dump(body), headers=[("content-type", "application/json")]) assert req.status == 201 data = 
json_load(req.data) - assert data['id'] - assert data['str'] == 'bar' + assert data["id"] + assert data["str"] == "bar" #: validation tests - body = sdict( - str='bar', - int='foo', - float=1.1, - datetime=datetime(2000, 1, 1) - ) - req = client.post( - base_path, - data=json_dump(body), - headers=[('content-type', 'application/json')] - ) + body = sdict(str="bar", int="foo", float=1.1, datetime=datetime(2000, 1, 1)) + req = client.post(base_path, data=json_dump(body), headers=[("content-type", "application/json")]) assert req.status == 422 data = json_load(req.data) - assert data['errors']['int'] + assert data["errors"]["int"] @pytest.mark.parametrize("base_path", ["/sample", "/sample_row"]) def test_update(client, json_load, json_dump, base_path): - body = sdict( - str='bar', - int=2, - float=1.1, - datetime=datetime(2000, 1, 1) - ) - req = client.post( - base_path, - data=json_dump(body), - headers=[('content-type', 'application/json')] - ) + body = sdict(str="bar", int=2, float=1.1, datetime=datetime(2000, 1, 1)) + req = client.post(base_path, data=json_dump(body), headers=[("content-type", "application/json")]) data = json_load(req.data) - rid = data['id'] - - change = sdict( - str='baz' - ) - req = client.put( - f'{base_path}/{rid}', - data=json_dump(change), - headers=[('content-type', 'application/json')] - ) + rid = data["id"] + + change = sdict(str="baz") + req = client.put(f"{base_path}/{rid}", data=json_dump(change), headers=[("content-type", "application/json")]) assert req.status == 200 data = json_load(req.data) - assert data['str'] == 'baz' + assert data["str"] == "baz" #: validation tests - change = sdict( - int='baz' - ) - req = client.put( - f'{base_path}/{rid}', - data=json_dump(change), - headers=[('content-type', 'application/json')] - ) + change = sdict(int="baz") + req = client.put(f"{base_path}/{rid}", data=json_dump(change), headers=[("content-type", "application/json")]) assert req.status == 422 data = json_load(req.data) - assert data['errors']['int'] + assert data["errors"]["int"] @pytest.mark.parametrize("base_path", ["/sample", "/sample_row"]) @@ -173,10 +121,7 @@ def test_delete(client, db, base_path): with db.connection(): row = Sample.first() - req = client.delete( - f'{base_path}/{row.id}', - headers=[('content-type', 'application/json')] - ) + req = client.delete(f"{base_path}/{row.id}", headers=[("content-type", "application/json")]) assert req.status == 200 with db.connection(): diff --git a/tests/test_endpoints_additional.py b/tests/test_endpoints_additional.py index ada8678..9004b88 100644 --- a/tests/test_endpoints_additional.py +++ b/tests/test_endpoints_additional.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import pytest - from emmett.orm import Field, Model @@ -11,66 +10,63 @@ class Sample(Model): float = Field.float(default=0.0) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(migration_db): return migration_db(Sample) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def rest_app(app, db): app.pipeline = [db.pipe] - mod = app.rest_module( - __name__, 'sample', Sample, url_prefix='sample', - enabled_methods=['group', 'stats', 'sample'] - ) - mod.grouping_allowed_fields = ['str'] - mod.stats_allowed_fields = ['int', 'float'] + mod = app.rest_module(__name__, "sample", Sample, url_prefix="sample", enabled_methods=["group", "stats", "sample"]) + mod.grouping_allowed_fields = ["str"] + mod.stats_allowed_fields = ["int", "float"] return app -@pytest.fixture(scope='function', autouse=True) 
+@pytest.fixture(scope="function", autouse=True) def db_sample(db): with db.connection(): - Sample.create(str='foo') - Sample.create(str='foo', int=5, float=5.0) - Sample.create(str='bar', int=10, float=10.0) + Sample.create(str="foo") + Sample.create(str="foo", int=5, float=5.0) + Sample.create(str="bar", int=10, float=10.0) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def client(rest_app): return rest_app.test_client() def test_grouping(client, json_load): - req = client.get('/sample/group/str', query_string={'sort_by': '-count'}) + req = client.get("/sample/group/str", query_string={"sort_by": "-count"}) assert req.status == 200 data = json_load(req.data) - assert data['meta']['total_objects'] == 2 + assert data["meta"]["total_objects"] == 2 - assert data['data'][0]['value'] == 'foo' - assert data['data'][0]['count'] == 2 - assert data['data'][1]['value'] == 'bar' - assert data['data'][1]['count'] == 1 + assert data["data"][0]["value"] == "foo" + assert data["data"][0]["count"] == 2 + assert data["data"][1]["value"] == "bar" + assert data["data"][1]["count"] == 1 def test_stats(client, json_load): - req = client.get('/sample/stats', query_string={'fields': 'int,float'}) + req = client.get("/sample/stats", query_string={"fields": "int,float"}) assert req.status == 200 data = json_load(req.data) - assert data['int']['min'] == 0 - assert data['int']['max'] == 10 - assert data['int']['avg'] == 5 - assert data['float']['min'] == 0.0 - assert data['float']['max'] == 10.0 - assert data['float']['avg'] == 5.0 + assert data["int"]["min"] == 0 + assert data["int"]["max"] == 10 + assert data["int"]["avg"] == 5 + assert data["float"]["min"] == 0.0 + assert data["float"]["max"] == 10.0 + assert data["float"]["avg"] == 5.0 def test_sample(client, json_load): - req = client.get('/sample/sample') + req = client.get("/sample/sample") assert req.status == 200 data = json_load(req.data) - assert data['meta']['total_objects'] == 3 - assert not data['meta']['has_more'] + assert data["meta"]["total_objects"] == 3 + assert not data["meta"]["has_more"] diff --git a/tests/test_envelopes.py b/tests/test_envelopes.py index c3aa285..e0230ea 100644 --- a/tests/test_envelopes.py +++ b/tests/test_envelopes.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import pytest - from emmett.orm import Field, Model @@ -9,107 +8,103 @@ class Sample(Model): str = Field() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(migration_db): return migration_db(Sample) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def rest_app(app, db): app.pipeline = [db.pipe] return app -@pytest.fixture(scope='function', autouse=True) +@pytest.fixture(scope="function", autouse=True) def db_sample(db): with db.connection(): - Sample.create(str='foo') + Sample.create(str="foo") -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def client_default(rest_app): - rest_app.rest_module( - __name__, 'sample', Sample, url_prefix='sample' - ) + rest_app.rest_module(__name__, "sample", Sample, url_prefix="sample") return rest_app.test_client() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def client_envelopes(rest_app): rest_app.rest_module( - __name__, 'sample', Sample, url_prefix='sample', - single_envelope='sample', list_envelope='samples', - use_envelope_on_parse=True + __name__, + "sample", + Sample, + url_prefix="sample", + single_envelope="sample", + list_envelope="samples", + use_envelope_on_parse=True, ) return rest_app.test_client() def 
test_default_index(client_default, json_load): - req = client_default.get('/sample') + req = client_default.get("/sample") assert req.status == 200 data = json_load(req.data) - assert {'data', 'meta'} == set(data.keys()) + assert {"data", "meta"} == set(data.keys()) def test_default_get(client_default, json_load, db): with db.connection(): row = Sample.first() - req = client_default.get(f'/sample/{row.id}') + req = client_default.get(f"/sample/{row.id}") assert req.status == 200 data = json_load(req.data) - assert {'id', 'str'} == set(data.keys()) + assert {"id", "str"} == set(data.keys()) def test_envelopes_index(client_envelopes, json_load): - req = client_envelopes.get('/sample') + req = client_envelopes.get("/sample") assert req.status == 200 data = json_load(req.data) - assert {'samples', 'meta'} == set(data.keys()) + assert {"samples", "meta"} == set(data.keys()) def test_envelopes_get(client_envelopes, json_load, db): with db.connection(): row = Sample.first() - req = client_envelopes.get(f'/sample/{row.id}') + req = client_envelopes.get(f"/sample/{row.id}") assert req.status == 200 data = json_load(req.data) - assert {'sample'} == set(data.keys()) + assert {"sample"} == set(data.keys()) def test_envelopes_create(client_envelopes, json_load, json_dump): req = client_envelopes.post( - '/sample', - data=json_dump({'sample': {'str': 'foo'}}), - headers=[('content-type', 'application/json')] + "/sample", data=json_dump({"sample": {"str": "foo"}}), headers=[("content-type", "application/json")] ) assert req.status == 201 data = json_load(req.data) - assert {'sample'} == set(data.keys()) - assert data['sample']['id'] + assert {"sample"} == set(data.keys()) + assert data["sample"]["id"] def test_envelopes_update(client_envelopes, json_load, json_dump): req = client_envelopes.post( - '/sample', - data=json_dump({'sample': {'str': 'foo'}}), - headers=[('content-type', 'application/json')] + "/sample", data=json_dump({"sample": {"str": "foo"}}), headers=[("content-type", "application/json")] ) data = json_load(req.data) - rid = data['sample']['id'] + rid = data["sample"]["id"] req = client_envelopes.put( - f'/sample/{rid}', - data=json_dump({'sample': {'str': 'baz'}}), - headers=[('content-type', 'application/json')] + f"/sample/{rid}", data=json_dump({"sample": {"str": "baz"}}), headers=[("content-type", "application/json")] ) assert req.status == 200 data = json_load(req.data) - assert data['sample']['str'] == 'baz' + assert data["sample"]["str"] == "baz" diff --git a/tests/test_meta.py b/tests/test_meta.py index a4a1b27..72c4e58 100644 --- a/tests/test_meta.py +++ b/tests/test_meta.py @@ -1,8 +1,8 @@ # -*- coding: utf-8 -*- import math -import pytest +import pytest from emmett.orm import Field, Model @@ -10,45 +10,38 @@ class Sample(Model): str = Field() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(migration_db): return migration_db(Sample) -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def rest_app(app, db): app.pipeline = [db.pipe] return app -@pytest.fixture(scope='function', autouse=True) +@pytest.fixture(scope="function", autouse=True) def db_sample(db): with db.connection(): - Sample.create(str='foo') + Sample.create(str="foo") -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def client_meta(rest_app): - rest_app.rest_module( - __name__, 'sample', Sample, url_prefix='sample' - ) + rest_app.rest_module(__name__, "sample", Sample, url_prefix="sample") return rest_app.test_client() 
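A condensed sketch of the payload shapes the envelope tests above assert, for a module configured with `single_envelope="sample"`, `list_envelope="samples"` and `use_envelope_on_parse=True` (field values illustrative):

```python
import json

# Listings: rows under the list envelope, pagination info under the meta envelope.
list_body = {
    "samples": [{"id": 1, "str": "foo"}],
    "meta": {"object": "list", "has_more": False, "total_objects": 1},
}
# Single records are wrapped in the single envelope.
single_body = {"sample": {"id": 1, "str": "foo"}}
# With use_envelope_on_parse=True the parser also unwraps request bodies.
create_payload = json.dumps({"sample": {"str": "bar"}})
```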
-@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def client_nometa(rest_app): - rest_app.rest_module( - __name__, 'sample', Sample, url_prefix='sample', - serialize_meta=False - ) + rest_app.rest_module(__name__, "sample", Sample, url_prefix="sample", serialize_meta=False) return rest_app.test_client() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def client_metacustom(rest_app): - mod = rest_app.rest_module( - __name__, 'sample', Sample, url_prefix='sample' - ) + mod = rest_app.rest_module(__name__, "sample", Sample, url_prefix="sample") @mod.meta_builder def _meta(dbset, pagination, **kwargs): @@ -56,43 +49,43 @@ def _meta(dbset, pagination, **kwargs): page, page_size = pagination total_pages = math.ceil(count / page_size) return { - 'page': page, - 'page_prev': page - 1 if page > 1 else None, - 'page_next': page + 1 if page < total_pages else None, - 'total_pages': total_pages, - 'total_objects': count + "page": page, + "page_prev": page - 1 if page > 1 else None, + "page_next": page + 1 if page < total_pages else None, + "total_pages": total_pages, + "total_objects": count, } return rest_app.test_client() def test_meta_index(client_meta, json_load): - req = client_meta.get('/sample') + req = client_meta.get("/sample") assert req.status == 200 data = json_load(req.data) - assert {'data', 'meta'} == set(data.keys()) - assert data['meta']['object'] == 'list' - assert data['meta']['total_objects'] == 1 - assert not data['meta']['has_more'] + assert {"data", "meta"} == set(data.keys()) + assert data["meta"]["object"] == "list" + assert data["meta"]["total_objects"] == 1 + assert not data["meta"]["has_more"] def test_nometa_index(client_nometa, json_load): - req = client_nometa.get('/sample') + req = client_nometa.get("/sample") assert req.status == 200 data = json_load(req.data) - assert {'data'} == set(data.keys()) + assert {"data"} == set(data.keys()) def test_metacustom_index(client_metacustom, json_load): - req = client_metacustom.get('/sample') + req = client_metacustom.get("/sample") assert req.status == 200 data = json_load(req.data) - assert {'data', 'meta'} == set(data.keys()) - assert data['meta']['total_objects'] == 1 - assert data['meta']['total_pages'] == 1 - assert data['meta']['page'] == 1 - assert data['meta']['page_prev'] is None - assert data['meta']['page_next'] is None + assert {"data", "meta"} == set(data.keys()) + assert data["meta"]["total_objects"] == 1 + assert data["meta"]["total_pages"] == 1 + assert data["meta"]["page"] == 1 + assert data["meta"]["page_prev"] is None + assert data["meta"]["page_next"] is None diff --git a/tests/test_queries.py b/tests/test_queries.py index 23b3e16..94630d5 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- import pytest - -from pydal.objects import Query -from emmett import sdict, current, now +from emmett import current, now, sdict from emmett.orm import Field, Model, geo +from pydal.objects import Query + from emmett_rest.queries import JSONQueryPipe from emmett_rest.queries.parser import parse_conditions @@ -18,7 +18,7 @@ class Sample(Model): geopoly = Field.geometry("POLYGON") -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(raw_db): raw_db.define_models(Sample) return raw_db @@ -38,291 +38,195 @@ def query_component_equal(c1, c2): def queries_equal(q1, q2): - ctx = [ - {'op': q1.op, 'elements': [q1.first, q1.second]}, - {'op': q2.op, 'elements': [q2.first, q2.second]} - ] - if ctx[0]['op'] != 
+    ctx = [{"op": q1.op, "elements": [q1.first, q1.second]}, {"op": q2.op, "elements": [q2.first, q2.second]}]
+    if ctx[0]["op"] != ctx[1]["op"]:
         return False
     equality_count = 0
-    for element1 in ctx[0]['elements']:
-        for element2 in ctx[1]['elements']:
+    for element1 in ctx[0]["elements"]:
+        for element2 in ctx[1]["elements"]:
             if query_component_equal(element1, element2):
                 equality_count += 1
     return equality_count == 2


 def test_parse_fields(db):
-    qdict = {
-        'str': 'bar'
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'str'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(lambda m: m.str == 'bar').query
-    )
+    qdict = {"str": "bar"}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"str"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: m.str == "bar").query)

-    qdict = {
-        'str': {'$regex': 'bar'}
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'str'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(
-            lambda m: m.str.contains('bar', case_sensitive=True)
-        ).query
-    )
+    qdict = {"str": {"$regex": "bar"}}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"str"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: m.str.contains("bar", case_sensitive=True)).query)

-    qdict = {
-        'str': {'$iregex': 'bar'}
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'str'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(lambda m: m.str.contains('bar')).query
-    )
+    qdict = {"str": {"$iregex": "bar"}}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"str"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: m.str.contains("bar")).query)

-    qdict = {
-        'int': 2
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'int'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(lambda m: m.int == 2).query
-    )
+    qdict = {"int": 2}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"int"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: m.int == 2).query)

-    qdict = {
-        'int': {'$gte': 0, '$lt': 2}
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'int'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(lambda m: (m.int >= 0) & (m.int < 2)).query
-    )
+    qdict = {"int": {"$gte": 0, "$lt": 2}}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"int"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: (m.int >= 0) & (m.int < 2)).query)

-    qdict = {
-        'float': 2.3
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'float'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(lambda m: m.float == 2.3).query
-    )
+    qdict = {"float": 2.3}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"float"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: m.float == 2.3).query)

-    qdict = {
-        'float': {'$gte': 2, '$lt': 5.5}
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'float'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(lambda m: (m.float >= 2) & (m.float < 5.5)).query
-    )
+    qdict = {"float": {"$gte": 2, "$lt": 5.5}}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"float"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: (m.float >= 2) & (m.float < 5.5)).query)

     dt1, dt2 = now(), now().add(days=1)
-    qdict = {
-        'datetime': dt1
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'datetime'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(lambda m: m.datetime == dt1).query
-    )
+    qdict = {"datetime": dt1}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"datetime"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: m.datetime == dt1).query)

-    qdict = {
-        'datetime': {'$gte': dt1, '$lt': dt2}
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'datetime'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(
-            lambda m: (m.datetime >= dt1) & (m.datetime < dt2)
-        ).query
-    )
+    qdict = {"datetime": {"$gte": dt1, "$lt": dt2}}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"datetime"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: (m.datetime >= dt1) & (m.datetime < dt2)).query)

-    qdict = {
-        'geopoly': {'$geo.contains': {"type": "point", "coordinates": [1, 2]}}
-    }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'geopoly'})
-    assert queries_equal(
-        parsed.query,
-        Sample.all().where(
-            lambda m: m.geopoly.st_contains(geo.Point(1, 2))
-        ).query
-    )
+    qdict = {"geopoly": {"$geo.contains": {"type": "point", "coordinates": [1, 2]}}}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"geopoly"})
+    assert queries_equal(parsed.query, Sample.all().where(lambda m: m.geopoly.st_contains(geo.Point(1, 2))).query)

     qdict = {
-        'geopoint': {'$geo.equals': {"type": "point", "coordinates": [1, 2]}},
-        'geopoly': {'$geo.equals': {
-            "type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-        }}
+        "geopoint": {"$geo.equals": {"type": "point", "coordinates": [1, 2]}},
+        "geopoly": {"$geo.equals": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]}},
     }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'geopoint', 'geopoly'})
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"geopoint", "geopoly"})
     assert queries_equal(
         parsed.query,
-        Sample.all().where(
+        Sample.all()
+        .where(
             lambda m: (
-                m.geopoint.st_equals(geo.Point(1, 2)) &
-                m.geopoly.st_equals(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
+                m.geopoint.st_equals(geo.Point(1, 2)) & m.geopoly.st_equals(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
             )
-        ).query
+        )
+        .query,
     )

     qdict = {
-        'geopoint': {'$geo.intersects': {
-            "type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-        }},
-        'geopoly': {'$geo.intersects': {
-            "type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-        }}
+        "geopoint": {"$geo.intersects": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]}},
+        "geopoly": {"$geo.intersects": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]}},
     }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'geopoint', 'geopoly'})
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"geopoint", "geopoly"})
     assert queries_equal(
         parsed.query,
-        Sample.all().where(
+        Sample.all()
+        .where(
             lambda m: (
-                m.geopoint.st_intersects(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2))) &
-                m.geopoly.st_intersects(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
+                m.geopoint.st_intersects(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
+                & m.geopoly.st_intersects(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
             )
-        ).query
+        )
+        .query,
     )

     qdict = {
-        'geopoint': {'$geo.overlaps': {
-            "type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-        }},
-        'geopoly': {'$geo.overlaps': {
-            "type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-        }}
+        "geopoint": {"$geo.overlaps": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]}},
+        "geopoly": {"$geo.overlaps": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]}},
     }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'geopoint', 'geopoly'})
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"geopoint", "geopoly"})
     assert queries_equal(
         parsed.query,
-        Sample.all().where(
+        Sample.all()
+        .where(
             lambda m: (
-                m.geopoint.st_overlaps(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2))) &
-                m.geopoly.st_overlaps(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
+                m.geopoint.st_overlaps(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
+                & m.geopoly.st_overlaps(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
             )
-        ).query
+        )
+        .query,
     )

     qdict = {
-        'geopoint': {'$geo.touches': {
-            "type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-        }},
-        'geopoly': {'$geo.touches': {
-            "type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-        }}
+        "geopoint": {"$geo.touches": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]}},
+        "geopoly": {"$geo.touches": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]}},
     }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'geopoint', 'geopoly'})
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"geopoint", "geopoly"})
     assert queries_equal(
         parsed.query,
-        Sample.all().where(
+        Sample.all()
+        .where(
             lambda m: (
-                m.geopoint.st_touches(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2))) &
-                m.geopoly.st_touches(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
+                m.geopoint.st_touches(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
+                & m.geopoly.st_touches(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
             )
-        ).query
+        )
+        .query,
     )

     qdict = {
-        'geopoint': {'$geo.within': {
-            "type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-        }},
-        'geopoly': {'$geo.within': {
-            "type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-        }}
+        "geopoint": {"$geo.within": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]}},
+        "geopoly": {"$geo.within": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]}},
     }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'geopoint', 'geopoly'})
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"geopoint", "geopoly"})
     assert queries_equal(
         parsed.query,
-        Sample.all().where(
+        Sample.all()
+        .where(
             lambda m: (
-                m.geopoint.st_within(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2))) &
-                m.geopoly.st_within(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
+                m.geopoint.st_within(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
+                & m.geopoly.st_within(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)))
             )
-        ).query
+        )
+        .query,
     )

     qdict = {
-        'geopoint': {
-            '$geo.dwithin': {
-                "geometry": {
-                    "type": "polygon",
-                    "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-                },
-                "distance": 3.2
+        "geopoint": {
+            "$geo.dwithin": {
+                "geometry": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]},
+                "distance": 3.2,
             }
         },
-        'geopoly': {
-            '$geo.dwithin': {
-                "geometry": {
-                    "type": "polygon",
-                    "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]
-                },
-                "distance": 4
+        "geopoly": {
+            "$geo.dwithin": {
+                "geometry": {"type": "polygon", "coordinates": [[1, 2], [2, 2], [2, 1], [1, 2]]},
+                "distance": 4,
             }
-        }
+        },
     }
-    parsed = parse_conditions(Sample, Sample.all(), qdict, {'geopoint', 'geopoly'})
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"geopoint", "geopoly"})
     assert queries_equal(
         parsed.query,
-        Sample.all().where(
+        Sample.all()
+        .where(
             lambda m: (
-                m.geopoint.st_dwithin(
-                    geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)), 3.2
-                ) &
-                m.geopoly.st_dwithin(
-                    geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)), 4
-                )
+                m.geopoint.st_dwithin(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)), 3.2)
+                & m.geopoly.st_dwithin(geo.Polygon((1, 2), (2, 2), (2, 1), (1, 2)), 4)
             )
-        ).query
+        )
+        .query,
     )


 def test_parse_combined(db):
     dt1, dt2 = now(), now().add(days=1)
     qdict = {
-        'str': 'bar',
-        'int': {'$gt': 2},
-        '$not': {'int': {'$in': [4, 5]}},
-        '$or': [
-            {'float': 3.2},
-            {'datetime': {'$gte': dt1, '$lt': dt2}}
-        ]
+        "str": "bar",
+        "int": {"$gt": 2},
+        "$not": {"int": {"$in": [4, 5]}},
+        "$or": [{"float": 3.2}, {"datetime": {"$gte": dt1, "$lt": dt2}}],
     }
-    parsed = parse_conditions(
-        Sample, Sample.all(), qdict, {'str', 'int', 'float', 'datetime'})
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"str", "int", "float", "datetime"})
     assert queries_equal(
         parsed.query,
-        Sample.all().where(
-            lambda m:
-            (m.str == 'bar') &
-            (m.int > 2) & (
-                ~m.int.belongs([4, 5]) & (
-                    (m.float == 3.2) | (
-                        (m.datetime >= dt1) & (m.datetime < dt2)
-                    )
-                )
-            )
-        ).query
+        Sample.all()
+        .where(
+            lambda m: (m.str == "bar")
+            & (m.int > 2)
+            & (~m.int.belongs([4, 5]) & ((m.float == 3.2) | ((m.datetime >= dt1) & (m.datetime < dt2))))
+        )
+        .query,
     )

-    qdict = {
-        '$or': [
-            {'float': 3.2},
-            {'datetime': {'$gte': dt1, '$lt': dt2}}
-        ]
-    }
-    parsed = parse_conditions(
-        Sample, Sample.all(), qdict, {'str', 'int', 'float', 'datetime'})
+    qdict = {"$or": [{"float": 3.2}, {"datetime": {"$gte": dt1, "$lt": dt2}}]}
+    parsed = parse_conditions(Sample, Sample.all(), qdict, {"str", "int", "float", "datetime"})
     assert queries_equal(
-        parsed.query,
-        Sample.all().where(
-            lambda m:
-            (m.float == 3.2) | (
-                (m.datetime >= dt1) & (m.datetime < dt2)
-            )
-        ).query
+        parsed.query, Sample.all().where(lambda m: (m.float == 3.2) | ((m.datetime >= dt1) & (m.datetime < dt2))).query
     )
@@ -332,28 +236,10 @@ async def _fake_pipe(**kwargs):
     return kwargs


 @pytest.mark.asyncio
 async def test_pipes(db, json_dump):
-    fake_mod = sdict(
-        _queryable_fields=['str', 'int'],
-        model=Sample,
-        ext=sdict(
-            config=sdict(
-                query_param='where'
-            )
-        )
-    )
+    fake_mod = sdict(_queryable_fields=["str", "int"], model=Sample, ext=sdict(config=sdict(query_param="where")))
     pipe = JSONQueryPipe(fake_mod)
-    qdict = {
-        '$or': [
-            {'str': 'bar'},
-            {'int': {'$gt': 0}}
-        ]
-    }
-    current.request = sdict(
-        query_params=sdict(
-            where=json_dump(qdict)))
+    qdict = {"$or": [{"str": "bar"}, {"int": {"$gt": 0}}]}
+    current.request = sdict(query_params=sdict(where=json_dump(qdict)))
     res = await pipe.pipe_request(_fake_pipe, dbset=Sample.all())
-    assert queries_equal(
-        res['dbset'].query,
-        Sample.all().where(lambda m: (m.str == 'bar') | (m.int > 0)).query
-    )
+    assert queries_equal(res["dbset"].query, Sample.all().where(lambda m: (m.str == "bar") | (m.int > 0)).query)
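
A closing note on the grammar exercised above: `parse_conditions` is the piece that turns
these JSON operator documents into ORM queries, and `JSONQueryPipe` applies it to the
module's query parameter at request time. A minimal usage sketch, assuming a `Sample`
model defined as in tests/test_queries.py; calling `.select()` on the returned set is
ordinary pyDAL/Emmett behavior, shown here as an assumption rather than something this
patch touches:

    from emmett_rest.queries.parser import parse_conditions

    # rows where str == "bar" OR int > 0, restricted to the whitelisted fields
    qdict = {"$or": [{"str": "bar"}, {"int": {"$gt": 0}}]}
    dbset = parse_conditions(Sample, Sample.all(), qdict, {"str", "int"})
    rows = dbset.select()

Over HTTP the same document travels JSON-encoded in the configured query parameter
("where" in the test_pipes fixture), e.g. GET /sample?where={"$or": [...]}.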