diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..70de72bc --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,30 @@ +name: lint + +on: + pull_request: + types: [opened, synchronize] + branches: + - master + +env: + PYTHON_VERSION: 3.12 + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ env.PYTHON_VERSION }} + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + - name: Install uv + uses: astral-sh/setup-uv@v3 + - name: Install dependencies + run: | + uv sync --dev + - name: Lint + run: | + source .venv/bin/activate + make lint diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 499971c1..e377885e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,6 +7,9 @@ on: tags-ignore: - "**" pull_request: + types: [opened, synchronize] + branches: + - master jobs: Linux: diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..14fd8afd --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +.DEFAULT_GOAL := all +pysources = emmett tests + +.PHONY: format +format: + ruff check --fix $(pysources) + ruff format $(pysources) + +.PHONY: lint +lint: + ruff check $(pysources) + ruff format --check $(pysources) + +.PHONY: test +test: + pytest -v tests + +.PHONY: all +all: format lint test diff --git a/emmett/__init__.py b/emmett/__init__.py index 26cb35a2..0ed4bb2f 100644 --- a/emmett/__init__.py +++ b/emmett/__init__.py @@ -9,5 +9,5 @@ from .http import redirect from .locals import T, now, request, response, session, websocket from .orm import Field -from .pipeline import Pipe, Injector +from .pipeline import Injector, Pipe from .routing.urls import url diff --git a/emmett/__main__.py b/emmett/__main__.py index 6dedb24f..5e4c7d2b 100644 --- a/emmett/__main__.py +++ b/emmett/__main__.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.__main__ - ---------------- +emmett.__main__ +---------------- - Alias for Emmett CLI. +Alias for Emmett CLI. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from emmett.cli import main + main(as_module=True) diff --git a/emmett/_internal.py b/emmett/_internal.py index f075b060..1b1aae5a 100644 --- a/emmett/_internal.py +++ b/emmett/_internal.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett._internal - ---------------- +emmett._internal +---------------- - Provides internally used helpers and objects. +Provides internally used helpers and objects. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Several parts of this code comes from Flask and Werkzeug. - :copyright: (c) 2014 by Armin Ronacher. +Several parts of this code comes from Flask and Werkzeug. +:copyright: (c) 2014 by Armin Ronacher. - :license: BSD-3-Clause +:license: BSD-3-Clause """ from __future__ import annotations @@ -23,18 +23,13 @@ #: monkey patches def _pendulum_to_datetime(obj): return datetime.datetime( - obj.year, obj.month, obj.day, - obj.hour, obj.minute, obj.second, obj.microsecond, - tzinfo=obj.tzinfo + obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo ) def _pendulum_to_naive_datetime(obj): - obj = obj.in_timezone('UTC') - return datetime.datetime( - obj.year, obj.month, obj.day, - obj.hour, obj.minute, obj.second, obj.microsecond - ) + obj = obj.in_timezone("UTC") + return datetime.datetime(obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond) def _pendulum_json(obj): diff --git a/emmett/_reloader.py b/emmett/_reloader.py index d5a862f5..4dcf06a0 100644 --- a/emmett/_reloader.py +++ b/emmett/_reloader.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett._reloader - ---------------- +emmett._reloader +---------------- - Provides auto-reloading utilities. +Provides auto-reloading utilities. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Adapted from werkzeug code (http://werkzeug.pocoo.org) - :copyright: (c) 2015 by Armin Ronacher. +Adapted from werkzeug code (http://werkzeug.pocoo.org) +:copyright: (c) 2015 by Armin Ronacher. - :license: BSD-3-Clause +:license: BSD-3-Clause """ import multiprocessing @@ -19,12 +19,10 @@ import subprocess import sys import time - from itertools import chain from typing import Optional import click - from emmett_core._internal import locate_app from emmett_core.server import run as _server_run @@ -37,7 +35,7 @@ def _iter_module_files(): for module in list(sys.modules.values()): if module is None: continue - filename = getattr(module, '__file__', None) + filename = getattr(module, "__file__", None) if filename: old = None while not os.path.isfile(filename): @@ -58,15 +56,9 @@ def _get_args_for_reloading(): """ rv = [sys.executable] py_script = sys.argv[0] - if ( - os.name == 'nt' and not os.path.exists(py_script) and - os.path.exists(py_script + '.exe') - ): - py_script += '.exe' - if ( - os.path.splitext(rv[0])[1] == '.exe' and - os.path.splitext(py_script)[1] == '.exe' - ): + if os.name == "nt" and not os.path.exists(py_script) and os.path.exists(py_script + ".exe"): + py_script += ".exe" + if os.path.splitext(rv[0])[1] == ".exe" and os.path.splitext(py_script)[1] == ".exe": rv.pop(0) rv.append(py_script) rv.extend(sys.argv[1:]) @@ -82,8 +74,7 @@ class ReloaderLoop(object): _sleep = staticmethod(time.sleep) def __init__(self, extra_files=None, interval=1): - self.extra_files = set( - os.path.abspath(x) for x in extra_files or ()) + self.extra_files = {os.path.abspath(x) for x in extra_files or ()} self.interval = interval def run(self): @@ -94,10 +85,10 @@ def restart_with_reloader(self): but running the reloader thread. """ while 1: - click.secho('> Restarting (%s mode)' % self.name, fg='yellow') + click.secho("> Restarting (%s mode)" % self.name, fg="yellow") args = _get_args_for_reloading() new_environ = os.environ.copy() - new_environ['EMMETT_RUN_MAIN'] = 'true' + new_environ["EMMETT_RUN_MAIN"] = "true" # a weird bug on windows. sometimes unicode strings end up in the # environment and subprocess.call does not like this, encode them @@ -107,20 +98,20 @@ def restart_with_reloader(self): # if isinstance(value, unicode): # new_environ[key] = value.encode('iso-8859-1') - exit_code = subprocess.call(args, env=new_environ) + exit_code = subprocess.call(args, env=new_environ) # noqa: S603 if exit_code != 3: return exit_code def trigger_reload(self, process, filename): filename = os.path.abspath(filename) - click.secho('> Detected change in %r, reloading' % filename, fg='cyan') + click.secho("> Detected change in %r, reloading" % filename, fg="cyan") os.kill(process.pid, signal.SIGTERM) process.join() sys.exit(3) class StatReloaderLoop(ReloaderLoop): - name = 'stat' + name = "stat" def run(self, process): mtimes = {} @@ -141,18 +132,18 @@ def run(self, process): reloader_loops = { - 'stat': StatReloaderLoop, + "stat": StatReloaderLoop, } -reloader_loops['auto'] = reloader_loops['stat'] +reloader_loops["auto"] = reloader_loops["stat"] def run_with_reloader( interface, app_target, - host='127.0.0.1', + host="127.0.0.1", port=8000, - loop='auto', + loop="auto", log_level=None, log_access=False, threads=1, @@ -161,14 +152,15 @@ def run_with_reloader( ssl_keyfile: Optional[str] = None, extra_files=None, interval=1, - reloader_type='auto' + reloader_type="auto", ): reloader = reloader_loops[reloader_type](extra_files, interval) signal.signal(signal.SIGTERM, lambda *args: sys.exit(0)) try: - if os.environ.get('EMMETT_RUN_MAIN') == 'true': + if os.environ.get("EMMETT_RUN_MAIN") == "true": from .app import App + # FIXME: find a better way to have app files in stat checker locate_app(App, *app_target) @@ -185,7 +177,7 @@ def run_with_reloader( "threading_mode": threading_mode, "ssl_certfile": ssl_certfile, "ssl_keyfile": ssl_keyfile, - } + }, ) process.start() reloader.run(process) diff --git a/emmett/_shortcuts.py b/emmett/_shortcuts.py index f92129bc..570e78df 100644 --- a/emmett/_shortcuts.py +++ b/emmett/_shortcuts.py @@ -1,34 +1,34 @@ # -*- coding: utf-8 -*- """ - emmett._shortcuts - ----------------- +emmett._shortcuts +----------------- - Some shortcuts +Some shortcuts - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import hashlib - from uuid import uuid4 -hashlib_md5 = lambda s: hashlib.md5(bytes(s, 'utf8')) -hashlib_sha1 = lambda s: hashlib.sha1(bytes(s, 'utf8')) + +hashlib_md5 = lambda s: hashlib.md5(bytes(s, "utf8")) +hashlib_sha1 = lambda s: hashlib.sha1(bytes(s, "utf8")) uuid = lambda: str(uuid4()) -def to_bytes(obj, charset='utf8', errors='strict'): +def to_bytes(obj, charset="utf8", errors="strict"): if obj is None: return None if isinstance(obj, (bytes, bytearray, memoryview)): return bytes(obj) if isinstance(obj, str): return obj.encode(charset, errors) - raise TypeError('Expected bytes') + raise TypeError("Expected bytes") -def to_unicode(obj, charset='utf8', errors='strict'): +def to_unicode(obj, charset="utf8", errors="strict"): if obj is None: return None if not isinstance(obj, bytes): diff --git a/emmett/app.py b/emmett/app.py index b9c8f614..888755a5 100644 --- a/emmett/app.py +++ b/emmett/app.py @@ -1,22 +1,20 @@ # -*- coding: utf-8 -*- """ - emmett.app - ---------- +emmett.app +---------- - Provides the central application object. +Provides the central application object. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations import os - from typing import Any, Dict, List, Optional, Type, Union import click - from emmett_core._internal import create_missing_app_folders, get_root_path from emmett_core.app import App as _App, AppModule as _AppModule, AppModuleGroup as _AppModuleGroup, Config as _Config from emmett_core.protocols.asgi.handlers import HTTPHandler as ASGIHTTPHandler, WSHandler as ASGIWSHandler @@ -30,8 +28,8 @@ from .html import asis from .language.helpers import Tstr from .language.translator import Translator -from .pipeline import Pipe, Injector -from .routing.router import HTTPRouter, WebsocketRouter, RoutingCtx, RoutingCtxGroup +from .pipeline import Injector, Pipe +from .routing.router import HTTPRouter, RoutingCtx, RoutingCtxGroup, WebsocketRouter from .routing.urls import url from .templating.templater import Templater from .testing import EmmettTestClient @@ -44,8 +42,8 @@ class Config(_Config): def __init__(self, app: App): super().__init__(app) self._templates_auto_reload = app.debug or False - self._templates_encoding = 'utf8' - self._templates_escape = 'common' + self._templates_encoding = "utf8" + self._templates_escape = "common" self._templates_indent = False @property @@ -85,7 +83,6 @@ def templates_adjust_indent(self, value: bool): self._app.templater._set_indent(value) - class AppModule(_AppModule): @classmethod def from_app( @@ -103,7 +100,7 @@ def from_app( root_path: Optional[str], pipeline: List[Pipe], injectors: List[Injector], - opts: Dict[str, Any] = {} + opts: Dict[str, Any] = {}, ): return cls( app, @@ -119,7 +116,7 @@ def from_app( root_path=root_path, pipeline=pipeline, injectors=injectors, - **opts + **opts, ) @classmethod @@ -136,17 +133,14 @@ def from_module( hostname: Optional[str], cache: Optional[RouteCacheRule], root_path: Optional[str], - opts: Dict[str, Any] = {} + opts: Dict[str, Any] = {}, ): - if '.' in name: - raise RuntimeError( - "Nested app modules' names should not contains dots" - ) - name = appmod.name + '.' + name - if url_prefix and not url_prefix.startswith('/'): - url_prefix = '/' + url_prefix - module_url_prefix = (appmod.url_prefix + (url_prefix or '')) \ - if appmod.url_prefix else url_prefix + if "." in name: + raise RuntimeError("Nested app modules' names should not contains dots") + name = appmod.name + "." + name + if url_prefix and not url_prefix.startswith("/"): + url_prefix = "/" + url_prefix + module_url_prefix = (appmod.url_prefix + (url_prefix or "")) if appmod.url_prefix else url_prefix hostname = hostname or appmod.hostname cache = cache or appmod.cache return cls( @@ -163,7 +157,7 @@ def from_module( root_path=root_path, pipeline=appmod.pipeline, injectors=appmod.injectors, - **opts + **opts, ) @classmethod @@ -180,7 +174,7 @@ def from_module_group( hostname: Optional[str], cache: Optional[RouteCacheRule], root_path: Optional[str], - opts: Dict[str, Any] = {} + opts: Dict[str, Any] = {}, ) -> AppModulesGrouped: mods = [] for module in appmodgroup.modules: @@ -196,7 +190,7 @@ def from_module_group( hostname=hostname, cache=cache, root_path=root_path, - opts=opts + opts=opts, ) mods.append(mod) return AppModulesGrouped(*mods) @@ -214,7 +208,7 @@ def module( cache: Optional[RouteCacheRule] = None, root_path: Optional[str] = None, module_class: Optional[Type[AppModule]] = None, - **kwargs: Any + **kwargs: Any, ) -> AppModule: module_class = module_class or self.__class__ return module_class.from_module( @@ -229,7 +223,7 @@ def module( hostname=hostname, cache=cache, root_path=root_path, - opts=kwargs + opts=kwargs, ) def __init__( @@ -247,7 +241,7 @@ def __init__( root_path: Optional[str] = None, pipeline: Optional[List[Pipe]] = None, injectors: Optional[List[Injector]] = None, - **kwargs: Any + **kwargs: Any, ): super().__init__( app=app, @@ -260,7 +254,7 @@ def __init__( cache=cache, root_path=root_path, pipeline=pipeline, - **kwargs + **kwargs, ) #: - `template_folder` is referred to application `template_path` # - `template_path` is referred to module root_directory unless absolute @@ -284,22 +278,20 @@ def route( paths: Optional[Union[str, List[str]]] = None, name: Optional[str] = None, template: Optional[str] = None, - **kwargs + **kwargs, ) -> RoutingCtx: if name is not None and "." in name: - raise RuntimeError( - "App modules' route names should not contains dots" - ) + raise RuntimeError("App modules' route names should not contains dots") name = self.name + "." + (name or "") - pipeline = kwargs.get('pipeline', []) - injectors = kwargs.get('injectors', []) + pipeline = kwargs.get("pipeline", []) + injectors = kwargs.get("injectors", []) if self.pipeline: pipeline = self.pipeline + pipeline - kwargs['pipeline'] = pipeline + kwargs["pipeline"] = pipeline if self.injectors: injectors = self.injectors + injectors - kwargs['injectors'] = injectors - kwargs['cache'] = kwargs.get('cache', self.cache) + kwargs["injectors"] = injectors + kwargs["cache"] = kwargs.get("cache", self.cache) return self.app.route( paths=paths, name=name, @@ -308,17 +300,12 @@ def route( template_folder=self.template_folder, template_path=self.template_path, hostname=self.hostname, - **kwargs + **kwargs, ) class App(_App): - __slots__ = [ - 'template_default_extension', - 'template_path', - 'templater', - 'translator' - ] + __slots__ = ["template_default_extension", "template_path", "templater", "translator"] debug = None @@ -332,24 +319,30 @@ def __init__( import_name: str, root_path: Optional[str] = None, url_prefix: Optional[str] = None, - template_folder: str = 'templates', - config_folder: str = 'config' + template_folder: str = "templates", + config_folder: str = "config", ): - super().__init__(import_name=import_name, root_path=root_path, url_prefix=url_prefix, config_folder=config_folder, template_folder=template_folder) + super().__init__( + import_name=import_name, + root_path=root_path, + url_prefix=url_prefix, + config_folder=config_folder, + template_folder=template_folder, + ) self.cli = click.Group(self.import_name) self.translator = Translator( - os.path.join(self.root_path, 'languages'), - default_language=self.language_default or 'en', + os.path.join(self.root_path, "languages"), + default_language=self.language_default or "en", watch_changes=self.debug, - str_class=Tstr + str_class=Tstr, ) - self.template_default_extension = '.html' + self.template_default_extension = ".html" self.templater: Templater = Templater( path=self.template_path, encoding=self.config.templates_encoding, escape=self.config.templates_escape, adjust_indent=self.config.templates_adjust_indent, - reload=self.config.templates_auto_reload + reload=self.config.templates_auto_reload, ) def _configure_paths(self, root_path, opts): @@ -359,21 +352,21 @@ def _configure_paths(self, root_path, opts): self.static_path = os.path.join(self.root_path, "static") self.template_path = os.path.join(self.root_path, opts["template_folder"]) self.config_path = os.path.join(self.root_path, opts["config_folder"]) - create_missing_app_folders(self, ['languages', 'logs', 'static']) + create_missing_app_folders(self, ["languages", "logs", "static"]) def _init_routers(self, url_prefix): self._router_http = HTTPRouter(self, current, url_prefix=url_prefix) self._router_ws = WebsocketRouter(self, current, url_prefix=url_prefix) def _init_handlers(self): - self._asgi_handlers['http'] = ASGIHTTPHandler(self, current) - self._asgi_handlers['ws'] = ASGIWSHandler(self, current) - self._rsgi_handlers['http'] = RSGIHTTPHandler(self, current) - self._rsgi_handlers['ws'] = RSGIWSHandler(self, current) + self._asgi_handlers["http"] = ASGIHTTPHandler(self, current) + self._asgi_handlers["ws"] = ASGIWSHandler(self, current) + self._rsgi_handlers["http"] = RSGIHTTPHandler(self, current) + self._rsgi_handlers["ws"] = RSGIWSHandler(self, current) def _configure_asgi_handlers(self): - self._asgi_handlers['http']._configure_methods() - self._rsgi_handlers['http']._configure_methods() + self._asgi_handlers["http"]._configure_methods() + self._rsgi_handlers["http"]._configure_methods() def _register_with_ctx(self): current.app = self @@ -385,7 +378,7 @@ def language_default(self) -> Optional[str]: @language_default.setter def language_default(self, value: str): self._language_default = value - self.translator._update_config(self._language_default or 'en') + self.translator._update_config(self._language_default or "en") @property def injectors(self) -> List[Injector]: @@ -409,10 +402,10 @@ def route( template_folder: Optional[str] = None, template_path: Optional[str] = None, cache: Optional[RouteCacheRule] = None, - output: str = 'auto' + output: str = "auto", ) -> RoutingCtx: if callable(paths): - raise SyntaxError('Use @route(), not @route.') + raise SyntaxError("Use @route(), not @route.") return self._router_http( paths=paths, name=name, @@ -426,7 +419,7 @@ def route( template_folder=template_folder, template_path=template_path, cache=cache, - output=output + output=output, ) @property @@ -438,10 +431,7 @@ def command_group(self): return self.cli.group def render_template(self, filename: str) -> str: - ctx = { - 'current': current, 'url': url, 'asis': asis, - 'load_component': load_component - } + ctx = {"current": current, "url": url, "asis": asis, "load_component": load_component} return self.templater.render(filename, ctx) def config_from_yaml(self, filename: str, namespace: Optional[str] = None): @@ -471,7 +461,7 @@ def module( pipeline: Optional[List[Pipe]] = None, injectors: Optional[List[Injector]] = None, module_class: Optional[Type[AppModule]] = None, - **kwargs: Any + **kwargs: Any, ) -> AppModule: module_class = module_class or self.modules_class return module_class.from_app( @@ -488,7 +478,7 @@ def module( root_path=root_path, pipeline=pipeline or [], injectors=injectors or [], - opts=kwargs + opts=kwargs, ) def module_group(self, *modules: AppModule) -> AppModuleGroup: @@ -509,7 +499,7 @@ def module( cache: Optional[RouteCacheRule] = None, root_path: Optional[str] = None, module_class: Optional[Type[AppModule]] = None, - **kwargs: Any + **kwargs: Any, ) -> AppModulesGrouped: module_class = module_class or AppModule return module_class.from_module_group( @@ -524,7 +514,7 @@ def module( hostname=hostname, cache=cache, root_path=root_path, - opts=kwargs + opts=kwargs, ) def route( @@ -532,12 +522,9 @@ def route( paths: Optional[Union[str, List[str]]] = None, name: Optional[str] = None, template: Optional[str] = None, - **kwargs + **kwargs, ) -> RoutingCtxGroup: - return RoutingCtxGroup([ - mod.route(paths=paths, name=name, template=template, **kwargs) - for mod in self.modules - ]) + return RoutingCtxGroup([mod.route(paths=paths, name=name, template=template, **kwargs) for mod in self.modules]) class AppModulesGrouped(AppModuleGroup): diff --git a/emmett/asgi/handlers.py b/emmett/asgi/handlers.py index 857a86c5..0be20684 100644 --- a/emmett/asgi/handlers.py +++ b/emmett/asgi/handlers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.asgi.handlers - -------------------- +emmett.asgi.handlers +-------------------- - Provides ASGI handlers. +Provides ASGI handlers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -15,13 +15,13 @@ from importlib import resources from typing import Awaitable, Callable -from emmett_core.http.response import HTTPResponse, HTTPBytesResponse +from emmett_core.http.response import HTTPBytesResponse, HTTPResponse from emmett_core.protocols.asgi.handlers import HTTPHandler as _HTTPHandler, WSHandler as _WSHandler from emmett_core.protocols.asgi.typing import Receive, Scope, Send from emmett_core.utils import cachedprop from ..ctx import current -from ..debug import smart_traceback, debug_handler +from ..debug import debug_handler, smart_traceback from ..libs.contenttype import contenttype from ..wrappers.response import Response from .wrappers import Request, Websocket @@ -34,9 +34,7 @@ class HTTPHandler(_HTTPHandler): @cachedprop def error_handler(self) -> Callable[[], Awaitable[str]]: - return ( - self._debug_handler if self.app.debug else self.exception_handler - ) + return self._debug_handler if self.app.debug else self.exception_handler async def _static_content(self, content: bytes, content_type: str) -> HTTPBytesResponse: content_len = str(len(content)) @@ -44,44 +42,32 @@ async def _static_content(self, content: bytes, content_type: str) -> HTTPBytesR 200, content, headers={ - 'content-type': content_type, - 'content-length': content_len, - 'last-modified': self._internal_assets_md[1], - 'etag': md5( - f"{self._internal_assets_md[0]}_{content_len}".encode("utf8") - ).hexdigest() - } + "content-type": content_type, + "content-length": content_len, + "last-modified": self._internal_assets_md[1], + "etag": md5(f"{self._internal_assets_md[0]}_{content_len}".encode("utf8")).hexdigest(), + }, ) - def _static_handler( - self, - scope: Scope, - receive: Receive, - send: Send - ) -> Awaitable[HTTPResponse]: - path = scope['emt.path'] + def _static_handler(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[HTTPResponse]: + path = scope["emt.path"] #: handle internal assets - if path.startswith('/__emmett__'): + if path.startswith("/__emmett__"): file_name = path[12:] if not file_name or file_name.endswith(".html"): return self._http_response(404) pkg = None - if '/' in file_name: - pkg, file_name = file_name.split('/', 1) + if "/" in file_name: + pkg, file_name = file_name.split("/", 1) try: - file_contents = resources.read_binary( - f'emmett.assets.{pkg}' if pkg else 'emmett.assets', - file_name - ) + file_contents = resources.read_binary(f"emmett.assets.{pkg}" if pkg else "emmett.assets", file_name) except FileNotFoundError: return self._http_response(404) return self._static_content(file_contents, contenttype(file_name)) return super()._static_handler(scope, receive, send) async def _debug_handler(self) -> str: - current.response.headers._data['content-type'] = ( - 'text/html; charset=utf-8' - ) + current.response.headers._data["content-type"] = "text/html; charset=utf-8" return debug_handler(smart_traceback(self.app)) diff --git a/emmett/asgi/wrappers.py b/emmett/asgi/wrappers.py index 23f9745e..cf12204d 100644 --- a/emmett/asgi/wrappers.py +++ b/emmett/asgi/wrappers.py @@ -1,16 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.asgi.wrappers - -------------------- +emmett.asgi.wrappers +-------------------- - Provides ASGI request and websocket wrappers +Provides ASGI request and websocket wrappers - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import pendulum - from emmett_core.protocols.asgi.wrappers import Request as _Request, Websocket as Websocket from emmett_core.utils import cachedprop diff --git a/emmett/cache.py b/emmett/cache.py index 39025dd5..e9bbebe0 100644 --- a/emmett/cache.py +++ b/emmett/cache.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.cache - ------------ +emmett.cache +------------ - Provides a caching system. +Provides a caching system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -16,10 +16,9 @@ import tempfile import threading import time - from typing import Any, List, Optional -from emmett_core.cache import Cache as Cache, RamCache as RamCache +from emmett_core.cache import Cache as Cache from emmett_core.cache.handlers import CacheHandler, RamCache as RamCache, RedisCache as RedisCache from ._shortcuts import hashlib_sha1 @@ -29,15 +28,10 @@ class DiskCache(CacheHandler): lock = threading.RLock() - _fs_transaction_suffix = '.__mt_cache' + _fs_transaction_suffix = ".__mt_cache" _fs_mode = 0o600 - def __init__( - self, - cache_dir: str = 'cache', - threshold: int = 500, - default_expire: int = 300 - ): + def __init__(self, cache_dir: str = "cache", threshold: int = 500, default_expire: int = 300): super().__init__(default_expire=default_expire) self._threshold = threshold self._path = os.path.join(current.app.root_path, cache_dir) @@ -70,7 +64,7 @@ def _prune(self): try: for i, fpath in enumerate(entries): remove = False - f = LockedFile(fpath, 'rb') + f = LockedFile(fpath, "rb") exp = pickle.load(f.file) f.close() remove = exp <= now or i % 3 == 0 @@ -84,7 +78,7 @@ def get(self, key: str) -> Any: try: with self.lock: now = time.time() - f = LockedFile(filename, 'rb') + f = LockedFile(filename, "rb") exp = pickle.load(f.file) if exp < now: f.close() @@ -104,10 +98,9 @@ def set(self, key: str, value: Any, **kwargs): if os.path.exists(filepath): self._del_file(filepath) try: - fd, tmp = tempfile.mkstemp( - suffix=self._fs_transaction_suffix, dir=self._path) - with os.fdopen(fd, 'wb') as f: - pickle.dump(kwargs['expiration'], f, 1) + fd, tmp = tempfile.mkstemp(suffix=self._fs_transaction_suffix, dir=self._path) + with os.fdopen(fd, "wb") as f: + pickle.dump(kwargs["expiration"], f, 1) pickle.dump(value, f, pickle.HIGHEST_PROTOCOL) os.rename(tmp, filename) os.chmod(filename, self._fs_mode) diff --git a/emmett/cli.py b/emmett/cli.py index 25ed29ac..0a64d7b6 100644 --- a/emmett/cli.py +++ b/emmett/cli.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.cli - --------- +emmett.cli +--------- - Provide command line tools for Emmett applications. +Provide command line tools for Emmett applications. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Flask (http://flask.pocoo.org) - :copyright: (c) 2014 by Armin Ronacher. +Based on the code of Flask (http://flask.pocoo.org) +:copyright: (c) 2014 by Armin Ronacher. - :license: BSD-3-Clause +:license: BSD-3-Clause """ import code @@ -20,8 +20,7 @@ import types import click - -from emmett_core._internal import locate_app, get_app_module +from emmett_core._internal import get_app_module, locate_app from emmett_core.log import LOG_LEVELS from emmett_core.server import run as sgi_run @@ -68,9 +67,7 @@ def find_db(module, var_name=None): from .orm import Database - matches = [ - v for k, v in module.__dict__.items() if isinstance(v, Database) - ] + matches = [v for k, v in module.__dict__.items() if isinstance(v, Database)] return matches @@ -131,6 +128,7 @@ def load_app(self): return self._loaded_app from .app import App + import_name, app_name = self._get_import_name() app = locate_app(App, import_name, app_name) if import_name else None @@ -170,30 +168,19 @@ def set_app_value(ctx, param, value): ctx.ensure_object(ScriptInfo).app_import_path = value -app_option = click.Option( - ['-a', '--app'], - help='The application to run', - callback=set_app_value, - is_eager=True -) +app_option = click.Option(["-a", "--app"], help="The application to run", callback=set_app_value, is_eager=True) class EmmettGroup(click.Group): - def __init__( - self, - add_default_commands=True, - add_app_option=True, - add_debug_option=True, - **extra - ): - params = list(extra.pop('params', None) or ()) + def __init__(self, add_default_commands=True, add_app_option=True, add_debug_option=True, **extra): + params = list(extra.pop("params", None) or ()) if add_app_option: params.append(app_option) - #if add_debug_option: + # if add_debug_option: # params.append(debug_option) click.Group.__init__(self, params=params, **extra) - #self.create_app = create_app + # self.create_app = create_app if add_default_commands: self.add_command(develop_command) @@ -231,57 +218,47 @@ def get_command(self, ctx, name): pass def main(self, *args, **kwargs): - obj = kwargs.get('obj') + obj = kwargs.get("obj") if obj is None: obj = ScriptInfo() - kwargs['obj'] = obj + kwargs["obj"] = obj return super().main(*args, **kwargs) -@click.command('develop', short_help='Runs a development server.') -@click.option( - '--host', '-h', default='127.0.0.1', help='The interface to bind to.') -@click.option( - '--port', '-p', type=int, default=8000, help='The port to bind to.') -@click.option( - '--interface', type=click.Choice(['rsgi', 'asgi']), default='rsgi', - help='Application interface.') -@click.option( - '--loop', type=click.Choice(['auto', 'asyncio', 'uvloop']), default='auto', - help='Event loop implementation.') +@click.command("develop", short_help="Runs a development server.") +@click.option("--host", "-h", default="127.0.0.1", help="The interface to bind to.") +@click.option("--port", "-p", type=int, default=8000, help="The port to bind to.") +@click.option("--interface", type=click.Choice(["rsgi", "asgi"]), default="rsgi", help="Application interface.") @click.option( - '--ssl-certfile', type=str, default=None, help='SSL certificate file') -@click.option( - '--ssl-keyfile', type=str, default=None, help='SSL key file') -@click.option( - '--reloader/--no-reloader', is_flag=True, default=True, - help='Runs with reloader.') + "--loop", type=click.Choice(["auto", "asyncio", "uvloop"]), default="auto", help="Event loop implementation." +) +@click.option("--ssl-certfile", type=str, default=None, help="SSL certificate file") +@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file") +@click.option("--reloader/--no-reloader", is_flag=True, default=True, help="Runs with reloader.") @pass_script_info -def develop_command( - info, host, port, interface, loop, ssl_certfile, ssl_keyfile, reloader -): - os.environ["EMMETT_RUN_ENV"] = 'true' +def develop_command(info, host, port, interface, loop, ssl_certfile, ssl_keyfile, reloader): + os.environ["EMMETT_RUN_ENV"] = "true" app_target = info._get_import_name() - if os.environ.get('EMMETT_RUN_MAIN') != 'true': + if os.environ.get("EMMETT_RUN_MAIN") != "true": click.echo( - ' '.join([ - "> Starting Emmett development server on app", - click.style(app_target[0], fg="cyan", bold=True) - ]) + " ".join(["> Starting Emmett development server on app", click.style(app_target[0], fg="cyan", bold=True)]) ) click.echo( - ' '.join([ - click.style("> Emmett application", fg="green"), - click.style(app_target[0], fg="cyan", bold=True), - click.style("running on", fg="green"), - click.style(f"http://{host}:{port}", fg="cyan"), - click.style("(press CTRL+C to quit)", fg="green") - ]) + " ".join( + [ + click.style("> Emmett application", fg="green"), + click.style(app_target[0], fg="cyan", bold=True), + click.style("running on", fg="green"), + click.style(f"http://{host}:{port}", fg="cyan"), + click.style("(press CTRL+C to quit)", fg="green"), + ] + ) ) if reloader: from ._reloader import run_with_reloader + runner = run_with_reloader else: runner = sgi_run @@ -292,7 +269,7 @@ def develop_command( host=host, port=port, loop=loop, - log_level='debug', + log_level="debug", log_access=True, threading_mode="workers", ssl_certfile=ssl_certfile, @@ -300,54 +277,46 @@ def develop_command( ) -@click.command('serve', short_help='Serve the app.') -@click.option( - '--host', '-h', default='0.0.0.0', help='The interface to bind to.') -@click.option( - '--port', '-p', type=int, default=8000, help='The port to bind to.') -@click.option( - "--workers", '-w', type=int, default=1, - help="Number of worker processes. Defaults to 1.") -@click.option( - "--threads", type=int, default=1, help="Number of worker threads.") -@click.option( - "--threading-mode", type=click.Choice(['runtime', 'workers']), default='workers', - help="Server threading mode.") -@click.option( - '--interface', type=click.Choice(['rsgi', 'asgi']), default='rsgi', - help='Application interface.') -@click.option( - '--http', type=click.Choice(['auto', '1', '2']), default='auto', - help='HTTP version.') -@click.option( - '--ws/--no-ws', is_flag=True, default=True, - help='Enable websockets support.') -@click.option( - '--loop', type=click.Choice(['auto', 'asyncio', 'uvloop']), default='auto', - help='Event loop implementation.') -@click.option( - '--opt/--no-opt', is_flag=True, default=False, - help='Enable loop optimizations.') -@click.option( - '--log-level', type=click.Choice(LOG_LEVELS.keys()), default='info', - help='Logging level.') -@click.option( - '--access-log/--no-access-log', is_flag=True, default=False, - help='Enable access log.') -@click.option( - '--backlog', type=int, default=2048, - help='Maximum number of connections to hold in backlog') +@click.command("serve", short_help="Serve the app.") +@click.option("--host", "-h", default="0.0.0.0", help="The interface to bind to.") +@click.option("--port", "-p", type=int, default=8000, help="The port to bind to.") +@click.option("--workers", "-w", type=int, default=1, help="Number of worker processes. Defaults to 1.") +@click.option("--threads", type=int, default=1, help="Number of worker threads.") @click.option( - '--backpressure', type=int, - help='Maximum number of requests to process concurrently (per worker)') -@click.option( - '--ssl-certfile', type=str, default=None, help='SSL certificate file') + "--threading-mode", type=click.Choice(["runtime", "workers"]), default="workers", help="Server threading mode." +) +@click.option("--interface", type=click.Choice(["rsgi", "asgi"]), default="rsgi", help="Application interface.") +@click.option("--http", type=click.Choice(["auto", "1", "2"]), default="auto", help="HTTP version.") +@click.option("--ws/--no-ws", is_flag=True, default=True, help="Enable websockets support.") @click.option( - '--ssl-keyfile', type=str, default=None, help='SSL key file') + "--loop", type=click.Choice(["auto", "asyncio", "uvloop"]), default="auto", help="Event loop implementation." +) +@click.option("--opt/--no-opt", is_flag=True, default=False, help="Enable loop optimizations.") +@click.option("--log-level", type=click.Choice(LOG_LEVELS.keys()), default="info", help="Logging level.") +@click.option("--access-log/--no-access-log", is_flag=True, default=False, help="Enable access log.") +@click.option("--backlog", type=int, default=2048, help="Maximum number of connections to hold in backlog") +@click.option("--backpressure", type=int, help="Maximum number of requests to process concurrently (per worker)") +@click.option("--ssl-certfile", type=str, default=None, help="SSL certificate file") +@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file") @pass_script_info def serve_command( - info, host, port, workers, threads, threading_mode, interface, http, ws, loop, opt, - log_level, access_log, backlog, backpressure, ssl_certfile, ssl_keyfile + info, + host, + port, + workers, + threads, + threading_mode, + interface, + http, + ws, + loop, + opt, + log_level, + access_log, + backlog, + backpressure, + ssl_certfile, + ssl_keyfile, ): app_target = info._get_import_name() sgi_run( @@ -371,31 +340,22 @@ def serve_command( ) -@click.command('shell', short_help='Runs a shell in the app context.') +@click.command("shell", short_help="Runs a shell in the app context.") @pass_script_info def shell_command(info): - os.environ['EMMETT_CLI_ENV'] = 'true' + os.environ["EMMETT_CLI_ENV"] = "true" ctx = info.load_appctx() app = info.load_app() - banner = 'Python %s on %s\nEmmett %s shell on app: %s' % ( - sys.version, - sys.platform, - fw_version, - app.import_name - ) + banner = "Python %s on %s\nEmmett %s shell on app: %s" % (sys.version, sys.platform, fw_version, app.import_name) code.interact(banner=banner, local=app.make_shell_context(ctx)) -@click.command('routes', short_help='Display the app routing table.') +@click.command("routes", short_help="Display the app routing table.") @pass_script_info def routes_command(info): app = info.load_app() click.echo( - "".join([ - "> Routing table for Emmett application ", - click.style(app.import_name, fg="cyan", bold=True), - ":" - ]) + "".join(["> Routing table for Emmett application ", click.style(app.import_name, fg="cyan", bold=True), ":"]) ) for route in app._router_http._routes_str.values(): click.echo(route) @@ -410,114 +370,100 @@ def set_db_value(ctx, param, value): ctx.ensure_object(ScriptInfo).db_var_name = value -@cli.group('migrations', short_help='Runs migration operations.') -@click.option( - '--db', help='The db instance to use', callback=set_db_value, is_eager=True -) +@cli.group("migrations", short_help="Runs migration operations.") +@click.option("--db", help="The db instance to use", callback=set_db_value, is_eager=True) def migrations_cli(db): pass -@migrations_cli.command('status', short_help='Shows current database revision.') -@click.option('--verbose', '-v', default=False, is_flag=True) +@migrations_cli.command("status", short_help="Shows current database revision.") +@click.option("--verbose", "-v", default=False, is_flag=True) @pass_script_info def migrations_status(info, verbose): from .orm.migrations.commands import status + app = info.load_app() dbs = info.load_db() status(app, dbs, verbose) -@migrations_cli.command('history', short_help="Shows migrations history.") -@click.option('--range', '-r', default=None) -@click.option('--verbose', '-v', default=False, is_flag=True) +@migrations_cli.command("history", short_help="Shows migrations history.") +@click.option("--range", "-r", default=None) +@click.option("--verbose", "-v", default=False, is_flag=True) @pass_script_info def migrations_history(info, range, verbose): from .orm.migrations.commands import history + app = info.load_app() dbs = info.load_db() history(app, dbs, range, verbose) -@migrations_cli.command( - 'generate', short_help='Generates a new migration from application models.' -) -@click.option( - '--message', '-m', default='Generated migration', - help='The description for the new migration.' -) -@click.option('-head', default='head', help='The migration to generate from') +@migrations_cli.command("generate", short_help="Generates a new migration from application models.") +@click.option("--message", "-m", default="Generated migration", help="The description for the new migration.") +@click.option("-head", default="head", help="The migration to generate from") @pass_script_info def migrations_generate(info, message, head): from .orm.migrations.commands import generate + app = info.load_app() dbs = info.load_db() generate(app, dbs, message, head) -@migrations_cli.command('new', short_help='Generates a new empty migration.') -@click.option( - '--message', '-m', default='New migration', - help='The description for the new migration.' -) -@click.option('-head', default='head', help='The migration to generate from') +@migrations_cli.command("new", short_help="Generates a new empty migration.") +@click.option("--message", "-m", default="New migration", help="The description for the new migration.") +@click.option("-head", default="head", help="The migration to generate from") @pass_script_info def migrations_new(info, message, head): from .orm.migrations.commands import new + app = info.load_app() dbs = info.load_db() new(app, dbs, message, head) -@migrations_cli.command( - 'up', short_help='Upgrades the database to the selected migration.' -) -@click.option('--revision', '-r', default='head', help='The migration to upgrade to.') +@migrations_cli.command("up", short_help="Upgrades the database to the selected migration.") +@click.option("--revision", "-r", default="head", help="The migration to upgrade to.") @click.option( - '--dry-run', + "--dry-run", default=False, is_flag=True, - help='Only print SQL instructions, without actually applying the migration.' + help="Only print SQL instructions, without actually applying the migration.", ) @pass_script_info def migrations_up(info, revision, dry_run): from .orm.migrations.commands import up + app = info.load_app() dbs = info.load_db() up(app, dbs, revision, dry_run) -@migrations_cli.command( - 'down', short_help='Downgrades the database to the selected migration.' -) -@click.option('--revision', '-r', required=True, help='The migration to downgrade to.') +@migrations_cli.command("down", short_help="Downgrades the database to the selected migration.") +@click.option("--revision", "-r", required=True, help="The migration to downgrade to.") @click.option( - '--dry-run', + "--dry-run", default=False, is_flag=True, - help='Only print SQL instructions, without actually applying the migration.' + help="Only print SQL instructions, without actually applying the migration.", ) @pass_script_info def migrations_down(info, revision, dry_run): from .orm.migrations.commands import down + app = info.load_app() dbs = info.load_db() down(app, dbs, revision, dry_run) -@migrations_cli.command( - 'set', short_help='Overrides database revision with selected migration.' -) -@click.option('--revision', '-r', default='head', help='The migration to set.') -@click.option( - '--auto-confirm', - default=False, - is_flag=True, - help='Skip asking confirmation.' -) +@migrations_cli.command("set", short_help="Overrides database revision with selected migration.") +@click.option("--revision", "-r", default="head", help="The migration to set.") +@click.option("--auto-confirm", default=False, is_flag=True, help="Skip asking confirmation.") @pass_script_info def migrations_set(info, revision, auto_confirm): from .orm.migrations.commands import set_revision + app = info.load_app() dbs = info.load_db() set_revision(app, dbs, revision, auto_confirm) @@ -527,5 +473,5 @@ def main(as_module=False): cli.main(prog_name="python -m emmett" if as_module else None) -if __name__ == '__main__': +if __name__ == "__main__": main(as_module=True) diff --git a/emmett/ctx.py b/emmett/ctx.py index be52cf13..45dd62d9 100644 --- a/emmett/ctx.py +++ b/emmett/ctx.py @@ -1,25 +1,24 @@ # -*- coding: utf-8 -*- """ - emmett.ctx - ---------- +emmett.ctx +---------- - Provides the current object. - Used by application to deal with request related objects. +Provides the current object. +Used by application to deal with request related objects. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from datetime import datetime import pendulum - from emmett_core.ctx import ( - Current as _Current, Context as _Context, + Current as _Current, RequestContext as _RequestContext, WSContext as _WsContext, - _ctxv + _ctxv, ) from emmett_core.utils import cachedprop @@ -37,9 +36,7 @@ class RequestContext(_RequestContext): @cachedprop def language(self): - return self.request.accept_language.best_match( - list(self.app.translator._langmap) - ) + return self.request.accept_language.best_match(list(self.app.translator._langmap)) class WSContext(_WsContext): @@ -51,9 +48,7 @@ def now(self): @cachedprop def language(self): - return self.websocket.accept_language.best_match( - list(self.app.translator._langmap) - ) + return self.websocket.accept_language.best_match(list(self.app.translator._langmap)) class Current(_Current): diff --git a/emmett/datastructures.py b/emmett/datastructures.py index 11693c65..a40a0836 100644 --- a/emmett/datastructures.py +++ b/emmett/datastructures.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.datastructures - --------------------- +emmett.datastructures +--------------------- - Provide some useful data structures. +Provide some useful data structures. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from emmett_core.datastructures import sdict as sdict @@ -55,7 +55,7 @@ def __add__(self, other): return self.union(other) def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self._list) + return "%s(%r)" % (self.__class__.__name__, self._list) __str__ = __repr__ @@ -124,6 +124,4 @@ def _unique_list(seq, hashfunc=None): seen_add = seen.add if not hashfunc: return [x for x in seq if x not in seen and not seen_add(x)] - return [ - x for x in seq if hashfunc(x) not in seen and not seen_add(hashfunc(x)) - ] + return [x for x in seq if hashfunc(x) not in seen and not seen_add(hashfunc(x))] diff --git a/emmett/debug.py b/emmett/debug.py index 47c2124d..278d8502 100644 --- a/emmett/debug.py +++ b/emmett/debug.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.debug - ------------ +emmett.debug +------------ - Provides debugging utilities. +Provides debugging utilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import inspect @@ -27,10 +27,8 @@ def __init__(self, app, exc_type, exc_value, tb): self.exc_value = exc_value if not isinstance(exc_type, str): exception_type = exc_type.__name__ - if exc_type.__module__ not in ( - '__builtin__', 'builtins', 'exceptions' - ): - exception_type = exc_type.__module__ + '.' + exception_type + if exc_type.__module__ not in ("__builtin__", "builtins", "exceptions"): + exception_type = exc_type.__module__ + "." + exception_type else: exception_type = exc_type self.exception_type = exception_type @@ -44,39 +42,31 @@ def __init__(self, app, exc_type, exc_value, tb): def exception(self): """String representation of the exception.""" buf = traceback.format_exception_only(self.exc_type, self.exc_value) - return ''.join(buf).strip() + return "".join(buf).strip() def generate_plaintext_traceback(self): """Like the plaintext attribute but returns a generator""" - yield u'Traceback (most recent call last):' + yield "Traceback (most recent call last):" for frame in self.frames: - yield u' File "%s", line %s, in %s' % ( - frame.filename, - frame.lineno, - frame.function_name - ) - yield u' ' + frame.current_line.strip() + yield ' File "%s", line %s, in %s' % (frame.filename, frame.lineno, frame.function_name) + yield " " + frame.current_line.strip() yield self.exception def generate_plain_tb_app(self): - yield u'Traceback (most recent call last):' + yield "Traceback (most recent call last):" for frame in self.frames: if frame.is_in_app: - yield u' File "%s", line %s, in %s' % ( - frame.filename, - frame.lineno, - frame.function_name - ) - yield u' ' + frame.current_line.strip() + yield ' File "%s", line %s, in %s' % (frame.filename, frame.lineno, frame.function_name) + yield " " + frame.current_line.strip() yield self.exception @property def full_tb(self): - return u'\n'.join(self.generate_plaintext_traceback()) + return "\n".join(self.generate_plaintext_traceback()) @property def app_tb(self): - return u'\n'.join(self.generate_plain_tb_app()) + return "\n".join(self.generate_plain_tb_app()) class Frame: @@ -90,50 +80,47 @@ def __init__(self, app, exc_type, exc_value, tb): self.globals = tb.tb_frame.f_globals fn = inspect.getsourcefile(tb) or inspect.getfile(tb) - if fn[-4:] in ('.pyo', '.pyc'): + if fn[-4:] in (".pyo", ".pyc"): fn = fn[:-1] # if it's a file on the file system resolve the real filename. if os.path.isfile(fn): fn = os.path.realpath(fn) self.filename = fn - self.module = self.globals.get('__name__') + self.module = self.globals.get("__name__") self.code = tb.tb_frame.f_code @property def is_in_fw(self): fw_path = os.path.dirname(__file__) - return self.filename[0:len(fw_path)] == fw_path + return self.filename[0 : len(fw_path)] == fw_path @property def is_in_app(self): - return self.filename[0:len(self.app.root_path)] == self.app.root_path + return self.filename[0 : len(self.app.root_path)] == self.app.root_path @property def rendered_filename(self): if self.is_in_app: - return self.filename[len(self.app.root_path) + 1:] + return self.filename[len(self.app.root_path) + 1 :] if self.is_in_fw: - return ''.join([ - "emmett.", - self.filename[ - len(os.path.dirname(__file__)) + 1: - ].replace("/", ".").split(".py")[0] - ]) + return "".join( + ["emmett.", self.filename[len(os.path.dirname(__file__)) + 1 :].replace("/", ".").split(".py")[0]] + ) return self.filename @cachedprop def sourcelines(self): try: - with open(self.filename, 'rb') as file: - source = file.read().decode('utf8') + with open(self.filename, "rb") as file: + source = file.read().decode("utf8") except IOError: - source = '' + source = "" return source.splitlines() @property def sourceblock(self): lmax = self.lineno + 4 - return u'\n'.join(self.sourcelines[self.first_line_no - 1:lmax]) + return "\n".join(self.sourcelines[self.first_line_no - 1 : lmax]) @property def first_line_no(self): @@ -151,22 +138,20 @@ def current_line(self): try: return self.sourcelines[self.lineno - 1] except IndexError: - return u'' + return "" @cachedprop def render_locals(self): - rv = dict() + rv = {} for k, v in self.locals.items(): try: rv[k] = str(v) except Exception: - rv[k] = '' + rv[k] = "" return rv -debug_templater = Renoir( - path=os.path.join(os.path.dirname(__file__), 'assets', 'debug') -) +debug_templater = Renoir(path=os.path.join(os.path.dirname(__file__), "assets", "debug")) def smart_traceback(app): @@ -175,4 +160,4 @@ def smart_traceback(app): def debug_handler(tb): - return debug_templater.render('view.html', {'tb': tb}) + return debug_templater.render("view.html", {"tb": tb}) diff --git a/emmett/extensions.py b/emmett/extensions.py index 56a45bdb..8f9377cc 100644 --- a/emmett/extensions.py +++ b/emmett/extensions.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.extensions - ----------------- +emmett.extensions +----------------- - Provides base classes to create extensions. +Provides base classes to create extensions. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations diff --git a/emmett/forms.py b/emmett/forms.py index 09495037..fd0b87aa 100644 --- a/emmett/forms.py +++ b/emmett/forms.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.forms - ------------ +emmett.forms +------------ - Provides classes to create and style forms. +Provides classes to create and style forms. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -19,12 +19,13 @@ from .ctx import current from .datastructures import sdict -from .html import HtmlTag, tag, cat, asis +from .html import HtmlTag, asis, cat, tag from .orm import Field, Model from .orm.objects import Row, Table from .security import CSRFStorage from .validators import isEmptyOr + __all__ = ["Form", "ModelForm"] @@ -47,7 +48,7 @@ def __init__( _action: str = "", _enctype: str = "multipart/form-data", _method: str = "POST", - **kwargs: Any + **kwargs: Any, ): self._action = _action self._csrf = csrf @@ -62,9 +63,7 @@ def __init__( self.onvalidation = onvalidation self.writable_fields = writable_fields if not issubclass(self._formstyle, FormStyle): - raise RuntimeError( - "{!r} is an invalid Emmett form style".format(formstyle) - ) + raise RuntimeError("{!r} is an invalid Emmett form style".format(formstyle)) self._preprocess(**kwargs) def _preprocess(self, **kwargs): @@ -76,7 +75,7 @@ def _preprocess(self, **kwargs): "id_prefix": self._id_prefix, "hidden": {}, "submit": self._submit_text, - "upload": self._upload + "upload": self._upload, } self.attributes.update(kwargs) #: init the form @@ -91,9 +90,7 @@ def _preprocess(self, **kwargs): @property def csrf(self) -> bool: - return self._csrf is True or ( - self._csrf == "auto" and self._submit_method == "POST" - ) + return self._csrf is True or (self._csrf == "auto" and self._submit_method == "POST") def _load_csrf(self): if not self.csrf: @@ -176,16 +173,9 @@ async def _process(self, write_defaults=True): if self.csrf and not self.accepted: self.formkey = current.session._csrf.gen_token() # reset default values in form - if ( - write_defaults and ( - not self.processed or (self.accepted and not self.keepvalues) - ) - ): + if write_defaults and (not self.processed or (self.accepted and not self.keepvalues)): for field in self.fields: - self.input_params[field.name] = ( - field.default() if callable(field.default) else - field.default - ) + self.input_params[field.name] = field.default() if callable(field.default) else field.default return self def _render(self): @@ -230,12 +220,9 @@ def custom(self): custom.begin = asis(f"
") # add hidden stuffs to get process working hidden = cat() - hidden.append( - tag.input(_name="_csrf_token", _type="hidden", _value=self.formkey) - ) + hidden.append(tag.input(_name="_csrf_token", _type="hidden", _value=self.formkey)) for key, value in self.attributes["hidden"].items(): - hidden.append(tag.input(_name=key, _type="hidden", _value=value) - ) + hidden.append(tag.input(_name=key, _type="hidden", _value=value)) # provides end attribute custom.end = asis(f"{hidden.__html__()}
") return custom @@ -258,7 +245,7 @@ def __init__( _action: str = "", _enctype: str = "multipart/form-data", _method: str = "POST", - **kwargs: Any + **kwargs: Any, ): fields = fields or {} #: get fields from kwargs @@ -293,7 +280,7 @@ def __init__( upload=upload, _action=_action, _enctype=_enctype, - _method=_method + _method=_method, ) @@ -314,26 +301,23 @@ def __init__( _action: str = "", _enctype: str = "multipart/form-data", _method: str = "POST", - **attributes + **attributes, ): self.model = model._instance_() self.table: Table = self.model.table - self.record = record or ( - self.model.get(record_id) if record_id else - self.model.new() - ) + self.record = record or (self.model.get(record_id) if record_id else self.model.new()) #: build fields for form fields_list_all = [] fields_list_writable = [] if fields is not None: #: developer has selected specific fields if not isinstance(fields, dict): - fields = {'writable': fields, 'readable': fields} + fields = {"writable": fields, "readable": fields} for field in self.table: - if field.name not in fields['readable']: + if field.name not in fields["readable"]: continue fields_list_all.append(field) - if field.name in fields['writable']: + if field.name in fields["writable"]: fields_list_writable.append(field) else: #: use table fields @@ -360,7 +344,7 @@ def __init__( upload=upload, _action=_action, _enctype=_enctype, - _method=_method + _method=_method, ) def _get_id_value(self): @@ -369,16 +353,13 @@ def _get_id_value(self): return self.record[self.table._id.name] def _validate_input(self): - record, fields = self.record.clone(), { - field.name: self._get_input_val(field) - for field in self.writable_fields - } + record, fields = self.record.clone(), {field.name: self._get_input_val(field) for field in self.writable_fields} for field in filter(lambda f: f.type == "upload", self.writable_fields): val = fields[field.name] if ( - (val == b"" or val is None) and - not self.input_params.get(field.name + "__del", False) and - self.record[field.name] + (val == b"" or val is None) + and not self.input_params.get(field.name + "__del", False) + and self.record[field.name] ): fields.pop(field.name) record.update(fields) @@ -414,9 +395,7 @@ async def _process(self, **kwargs): continue else: source_file, original_filename = upload.stream, upload.filename - newfilename = field.store( - source_file, original_filename, field.uploadfolder - ) + newfilename = field.store(source_file, original_filename, field.uploadfolder) if isinstance(field.uploadfield, str): self.params[field.uploadfield] = source_file.read() self.params[field.name] = newfilename @@ -429,15 +408,11 @@ async def _process(self, **kwargs): #: cleanup inputs if not self.processed or (self.accepted and not self.keepvalues): for field in self.fields: - self.input_params[field.name] = field.formatter( - self.record[field.name] - ) + self.input_params[field.name] = field.formatter(self.record[field.name]) elif self.processed and not self.accepted and self.record._concrete: for field in self.writable_fields: if field.type == "upload" and field.name not in self.params: - self.input_params[field.name] = field.formatter( - self.record[field.name] - ) + self.input_params[field.name] = field.formatter(self.record[field.name]) return self @@ -457,17 +432,12 @@ def widget_string(attr, field, value, _class="string", _id=None): _name=field.name, _value=value if value is not None else "", _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod def widget_text(attr, field, value, _class="text", _id=None): - return tag.textarea( - value or "", - _name=field.name, - _class=_class, - _id=_id or field.name - ) + return tag.textarea(value or "", _name=field.name, _class=_class, _id=_id or field.name) @staticmethod def widget_int(attr, field, value, _class="int", _id=None): @@ -488,7 +458,7 @@ def widget_date(attr, field, value, _class="date", _id=None): _name=field.name, _value=value if value is not None else "", _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -498,7 +468,7 @@ def widget_time(attr, field, value, _class="time", _id=None): _name=field.name, _value=value if value is not None else "", _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -508,18 +478,12 @@ def widget_datetime(attr, field, value, _class="datetime", _id=None): _name=field.name, _value=value if value is not None else "", _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod def widget_password(attr, field, value, _class="password", _id=None): - return tag.input( - _type="password", - _name=field.name, - _value=value or "", - _class=_class, - _id=_id or field.name - ) + return tag.input(_type="password", _name=field.name, _value=value or "", _class=_class, _id=_id or field.name) @staticmethod def widget_bool(attr, field, value, _class="bool", _id=None): @@ -528,7 +492,7 @@ def widget_bool(attr, field, value, _class="bool", _id=None): _name=field.name, _checked="checked" if value else None, _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -538,16 +502,12 @@ def selected(k): options, multiple = FormStyle._field_options(field) if multiple: - return FormStyle.widget_multiple( - attr, field, value, options, _class=_class, _id=_id - ) + return FormStyle.widget_multiple(attr, field, value, options, _class=_class, _id=_id) return tag.select( - *[ - tag.option(n, _value=k, _selected=selected(k)) for k, n in options - ], + *[tag.option(n, _value=k, _selected=selected(k)) for k, n in options], _name=field.name, _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -557,13 +517,11 @@ def selected(k): values = values or [] return tag.select( - *[ - tag.option(n, _value=k, _selected=selected(k)) for k, n in options - ], + *[tag.option(n, _value=k, _selected=selected(k)) for k, n in options], _name=field.name, _class=_class, _multiple="multiple", - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -605,14 +563,10 @@ def _coerce_value(value): _class="checkbox", _id=_id + "__del", _name=field.name + "__del", - _style="display: inline;" - ), - tag.label( - "delete", - _for=_id + "__del", - _style="margin: 4px" + _style="display: inline;", ), - _style="white-space: nowrap;" + tag.label("delete", _for=_id + "__del", _style="margin: 4px"), + _style="white-space: nowrap;", ) ) return tag.div(*elements, _class="upload_wrap") @@ -628,19 +582,22 @@ def widget_jsonb(attr, field, value, _id=None): @staticmethod def widget_radio(field, value): options, _ = FormStyle._field_options(field) - return cat(*[ - tag.div( - tag.input( - _id=f"{field.name}_{k}", - _name=field.name, - _value=k, - _type="radio", - _checked=("checked" if str(k) == str(value) else None) - ), - tag.label(n, _for=f"{field.name}_{k}"), - _class="option_wrap" - ) for k, n in options - ]) + return cat( + *[ + tag.div( + tag.input( + _id=f"{field.name}_{k}", + _name=field.name, + _value=k, + _type="radio", + _checked=("checked" if str(k) == str(value) else None), + ), + tag.label(n, _for=f"{field.name}_{k}"), + _class="option_wrap", + ) + for k, n in options + ] + ) def __init__(self, attributes): self.attr = attributes @@ -668,16 +625,12 @@ def _get_widget(self, field, value): elif wtype.startswith("decimal"): wtype = "float" try: - widget = getattr(self, "widget_" + wtype)( - self.attr, field, value, _id=widget_id - ) + widget = getattr(self, "widget_" + wtype)(self.attr, field, value, _id=widget_id) if not field.writable: self._disable_widget(widget) return widget, False except AttributeError: - raise RuntimeError( - f"Missing form widget for field {field.name} of type {wtype}" - ) + raise RuntimeError(f"Missing form widget for field {field.name} of type {wtype}") def _disable_widget(self, widget): widget.attributes["_disabled"] = "disabled" @@ -744,7 +697,8 @@ def add_form_on_model(cls): @wraps(cls) def wrapped(model, *args, **kwargs): return cls(model, *args, **kwargs) + return wrapped -setattr(Model, "form", classmethod(add_form_on_model(ModelForm))) +Model.form = classmethod(add_form_on_model(ModelForm)) diff --git a/emmett/helpers.py b/emmett/helpers.py index 21b6566b..27d021ec 100644 --- a/emmett/helpers.py +++ b/emmett/helpers.py @@ -1,47 +1,43 @@ # -*- coding: utf-8 -*- """ - emmett.helpers - -------------- +emmett.helpers +-------------- - Provides helping methods for applications. +Provides helping methods for applications. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import os import re - from typing import Any, List, Optional, Tuple, Union -from pydal.exceptions import NotAuthorizedException, NotFoundException from emmett_core.http.helpers import abort as _abort from emmett_core.http.response import HTTPFileResponse, HTTPIOResponse +from pydal.exceptions import NotAuthorizedException, NotFoundException from .ctx import current from .html import HtmlTag, tag -_re_dbstream = re.compile(r'(?P.*?)\.(?P.*?)\..*') +_re_dbstream = re.compile(r"(?P
.*?)\.(?P.*?)\..*") -def abort(code: int, body: str = ''): + +def abort(code: int, body: str = ""): _abort(current, code, body) def stream_file(path: str): full_path = os.path.join(current.app.root_path, path) - raise HTTPFileResponse( - full_path, - headers=current.response.headers, - cookies=current.response.cookies - ) + raise HTTPFileResponse(full_path, headers=current.response.headers, cookies=current.response.cookies) def stream_dbfile(db: Any, name: str): items = _re_dbstream.match(name) if not items: abort(404) - table_name, field_name = items.group('table'), items.group('field') + table_name, field_name = items.group("table"), items.group("field") try: field = db[table_name][field_name] except AttributeError: @@ -55,19 +51,11 @@ def stream_dbfile(db: Any, name: str): except IOError: abort(404) if isinstance(path_or_stream, str): - raise HTTPFileResponse( - path_or_stream, - headers=current.response.headers, - cookies=current.response.cookies - ) - raise HTTPIOResponse( - path_or_stream, - headers=current.response.headers, - cookies=current.response.cookies - ) - - -def flash(message: str, category: str = 'message'): + raise HTTPFileResponse(path_or_stream, headers=current.response.headers, cookies=current.response.cookies) + raise HTTPIOResponse(path_or_stream, headers=current.response.headers, cookies=current.response.cookies) + + +def flash(message: str, category: str = "message"): #: Flashes a message to the next request. if current.session._flashes is None: current.session._flashes = [] @@ -75,8 +63,7 @@ def flash(message: str, category: str = 'message'): def get_flashed_messages( - with_categories: bool = False, - category_filter: Union[str, List[str]] = [] + with_categories: bool = False, category_filter: Union[str, List[str]] = [] ) -> Union[List[str], Tuple[str, str]]: #: Pulls flashed messages from the session and returns them. # By default just the messages are returned, but when `with_categories` @@ -97,13 +84,9 @@ def get_flashed_messages( return flashes -def load_component( - url: str, - target: Optional[str] = None, - content: str = 'loading...' -) -> HtmlTag: +def load_component(url: str, target: Optional[str] = None, content: str = "loading...") -> HtmlTag: attr = {} if target: - attr['_id'] = target - attr['_data-emt_remote'] = url + attr["_id"] = target + attr["_data-emt_remote"] = url return tag.div(content, **attr) diff --git a/emmett/html.py b/emmett/html.py index 25e11c97..193f6c44 100644 --- a/emmett/html.py +++ b/emmett/html.py @@ -1,21 +1,21 @@ # -*- coding: utf-8 -*- """ - emmett.html - ----------- +emmett.html +----------- - Provides html generation classes. +Provides html generation classes. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import html import re import threading - from functools import reduce -__all__ = ['tag', 'cat', 'asis'] + +__all__ = ["tag", "cat", "asis"] class TagStack(threading.local): @@ -37,16 +37,16 @@ def __bool__(self): class HtmlTag: rules = { - 'ul': ['li'], - 'ol': ['li'], - 'table': ['tr', 'thead', 'tbody'], - 'thead': ['tr'], - 'tbody': ['tr'], - 'tr': ['td', 'th'], - 'select': ['option', 'optgroup'], - 'optgroup': ['optionp'] + "ul": ["li"], + "ol": ["li"], + "table": ["tr", "thead", "tbody"], + "thead": ["tr"], + "tbody": ["tr"], + "tr": ["td", "th"], + "select": ["option", "optgroup"], + "optgroup": ["optionp"], } - _self_closed = {'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta'} + _self_closed = {"br", "col", "embed", "hr", "img", "input", "link", "meta"} def __init__(self, name): self.name = name @@ -65,9 +65,7 @@ def __exit__(self, type, value, traceback): @staticmethod def wrap(component, rules): - if rules and ( - not isinstance(component, HtmlTag) or component.name not in rules - ): + if rules and (not isinstance(component, HtmlTag) or component.name not in rules): return HtmlTag(rules[0])(component) return component @@ -112,78 +110,69 @@ def __add__(self, other): return cat(self, other) def add_class(self, name): - """ add a class to _class attribute """ - c = self['_class'] + """add a class to _class attribute""" + c = self["_class"] classes = (set(c.split()) if c else set()) | set(name.split()) - self['_class'] = ' '.join(classes) if classes else None + self["_class"] = " ".join(classes) if classes else None return self def remove_class(self, name): - """ remove a class from _class attribute """ - c = self['_class'] + """remove a class from _class attribute""" + c = self["_class"] classes = (set(c.split()) if c else set()) - set(name.split()) - self['_class'] = ' '.join(classes) if classes else None + self["_class"] = " ".join(classes) if classes else None return self - regex_tag = re.compile(r'^([\w\-\:]+)') - regex_id = re.compile(r'#([\w\-]+)') - regex_class = re.compile(r'\.([\w\-]+)') - regex_attr = re.compile(r'\[([\w\-\:]+)=(.*?)\]') + regex_tag = re.compile(r"^([\w\-\:]+)") + regex_id = re.compile(r"#([\w\-]+)") + regex_class = re.compile(r"\.([\w\-]+)") + regex_attr = re.compile(r"\[([\w\-\:]+)=(.*?)\]") def find(self, expr): union = lambda a, b: a.union(b) - if ',' in expr: - tags = reduce( - union, - [self.find(x.strip()) for x in expr.split(',')], - set()) - elif ' ' in expr: + if "," in expr: + tags = reduce(union, [self.find(x.strip()) for x in expr.split(",")], set()) + elif " " in expr: tags = [self] for k, item in enumerate(expr.split()): if k > 0: - children = [ - set([c for c in tag if isinstance(c, HtmlTag)]) - for tag in tags] + children = [{c for c in tag if isinstance(c, HtmlTag)} for tag in tags] tags = reduce(union, children) tags = reduce(union, [tag.find(item) for tag in tags], set()) else: - tags = reduce( - union, - [c.find(expr) for c in self if isinstance(c, HtmlTag)], - set()) + tags = reduce(union, [c.find(expr) for c in self if isinstance(c, HtmlTag)], set()) tag = HtmlTag.regex_tag.match(expr) id = HtmlTag.regex_id.match(expr) _class = HtmlTag.regex_class.match(expr) attr = HtmlTag.regex_attr.match(expr) if ( - (tag is None or self.name == tag.group(1)) and - (id is None or self['_id'] == id.group(1)) and - (_class is None or _class.group(1) in - (self['_class'] or '').split()) and - (attr is None or self['_' + attr.group(1)] == attr.group(2)) + (tag is None or self.name == tag.group(1)) + and (id is None or self["_id"] == id.group(1)) + and (_class is None or _class.group(1) in (self["_class"] or "").split()) + and (attr is None or self["_" + attr.group(1)] == attr.group(2)) ): tags.add(self) return tags def _build_html_attributes(self): - return ' '.join( + return " ".join( '%s="%s"' % (k[1:], k[1:] if v is True else htmlescape(v)) for (k, v) in sorted(self.attributes.items()) - if k.startswith('_') and v is not None) + if k.startswith("_") and v is not None + ) def __html__(self): name = self.name attrs = self._build_html_attributes() - data = self.attributes.get('data', {}) - data_attrs = ' '.join( - 'data-%s="%s"' % (k, htmlescape(v)) for k, v in data.items()) + data = self.attributes.get("data", {}) + data_attrs = " ".join('data-%s="%s"' % (k, htmlescape(v)) for k, v in data.items()) if data_attrs: - attrs = attrs + ' ' + data_attrs - attrs = ' ' + attrs if attrs else '' + attrs = attrs + " " + data_attrs + attrs = " " + attrs if attrs else "" if name in self._self_closed: - return '<%s%s />' % (name, attrs) - components = ''.join(htmlescape(v) for v in self.components) - return '<%s%s>%s' % (name, attrs, components, name) + return "<%s%s />" % (name, attrs) + components = "".join(htmlescape(v) for v in self.components) + return "<%s%s>%s" % (name, attrs, components, name) def __json__(self): return str(self) @@ -199,11 +188,11 @@ def __getitem__(self, name): class cat(HtmlTag): def __init__(self, *components): - self.components = [c for c in components] + self.components = list(components) self.attributes = {} def __html__(self): - return ''.join(htmlescape(v) for v in self.components) + return "".join(htmlescape(v) for v in self.components) class asis(HtmlTag): @@ -221,7 +210,7 @@ def _to_str(obj): def htmlescape(obj): - if hasattr(obj, '__html__'): + if hasattr(obj, "__html__"): return obj.__html__() return html.escape(_to_str(obj), True).replace("'", "'") diff --git a/emmett/http.py b/emmett/http.py index 1bac0d8d..2a6f1d62 100644 --- a/emmett/http.py +++ b/emmett/http.py @@ -1,72 +1,72 @@ # -*- coding: utf-8 -*- """ - emmett.http - ----------- +emmett.http +----------- - Provides the HTTP interfaces. +Provides the HTTP interfaces. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations from emmett_core.http.helpers import redirect as _redirect from emmett_core.http.response import ( - HTTPResponse as HTTPResponse, - HTTPStringResponse as HTTP, + HTTPAsyncIterResponse as HTTPAsyncIter, HTTPBytesResponse as HTTPBytes, - HTTPIOResponse as HTTPIO, HTTPFileResponse as HTTPFile, - HTTPAsyncIterResponse as HTTPAsyncIter, - HTTPIterResponse as HTTPIter + HTTPIOResponse as HTTPIO, + HTTPIterResponse as HTTPIter, + HTTPResponse as HTTPResponse, + HTTPStringResponse as HTTP, ) from .ctx import current status_codes = { - 100: '100 CONTINUE', - 101: '101 SWITCHING PROTOCOLS', - 200: '200 OK', - 201: '201 CREATED', - 202: '202 ACCEPTED', - 203: '203 NON-AUTHORITATIVE INFORMATION', - 204: '204 NO CONTENT', - 205: '205 RESET CONTENT', - 206: '206 PARTIAL CONTENT', - 207: '207 MULTI-STATUS', - 300: '300 MULTIPLE CHOICES', - 301: '301 MOVED PERMANENTLY', - 302: '302 FOUND', - 303: '303 SEE OTHER', - 304: '304 NOT MODIFIED', - 305: '305 USE PROXY', - 307: '307 TEMPORARY REDIRECT', - 400: '400 BAD REQUEST', - 401: '401 UNAUTHORIZED', - 403: '403 FORBIDDEN', - 404: '404 NOT FOUND', - 405: '405 METHOD NOT ALLOWED', - 406: '406 NOT ACCEPTABLE', - 407: '407 PROXY AUTHENTICATION REQUIRED', - 408: '408 REQUEST TIMEOUT', - 409: '409 CONFLICT', - 410: '410 GONE', - 411: '411 LENGTH REQUIRED', - 412: '412 PRECONDITION FAILED', - 413: '413 REQUEST ENTITY TOO LARGE', - 414: '414 REQUEST-URI TOO LONG', - 415: '415 UNSUPPORTED MEDIA TYPE', - 416: '416 REQUESTED RANGE NOT SATISFIABLE', - 417: '417 EXPECTATION FAILED', - 422: '422 UNPROCESSABLE ENTITY', - 500: '500 INTERNAL SERVER ERROR', - 501: '501 NOT IMPLEMENTED', - 502: '502 BAD GATEWAY', - 503: '503 SERVICE UNAVAILABLE', - 504: '504 GATEWAY TIMEOUT', - 505: '505 HTTP VERSION NOT SUPPORTED', + 100: "100 CONTINUE", + 101: "101 SWITCHING PROTOCOLS", + 200: "200 OK", + 201: "201 CREATED", + 202: "202 ACCEPTED", + 203: "203 NON-AUTHORITATIVE INFORMATION", + 204: "204 NO CONTENT", + 205: "205 RESET CONTENT", + 206: "206 PARTIAL CONTENT", + 207: "207 MULTI-STATUS", + 300: "300 MULTIPLE CHOICES", + 301: "301 MOVED PERMANENTLY", + 302: "302 FOUND", + 303: "303 SEE OTHER", + 304: "304 NOT MODIFIED", + 305: "305 USE PROXY", + 307: "307 TEMPORARY REDIRECT", + 400: "400 BAD REQUEST", + 401: "401 UNAUTHORIZED", + 403: "403 FORBIDDEN", + 404: "404 NOT FOUND", + 405: "405 METHOD NOT ALLOWED", + 406: "406 NOT ACCEPTABLE", + 407: "407 PROXY AUTHENTICATION REQUIRED", + 408: "408 REQUEST TIMEOUT", + 409: "409 CONFLICT", + 410: "410 GONE", + 411: "411 LENGTH REQUIRED", + 412: "412 PRECONDITION FAILED", + 413: "413 REQUEST ENTITY TOO LARGE", + 414: "414 REQUEST-URI TOO LONG", + 415: "415 UNSUPPORTED MEDIA TYPE", + 416: "416 REQUESTED RANGE NOT SATISFIABLE", + 417: "417 EXPECTATION FAILED", + 422: "422 UNPROCESSABLE ENTITY", + 500: "500 INTERNAL SERVER ERROR", + 501: "501 NOT IMPLEMENTED", + 502: "502 BAD GATEWAY", + 503: "503 SERVICE UNAVAILABLE", + 504: "504 GATEWAY TIMEOUT", + 505: "505 HTTP VERSION NOT SUPPORTED", } diff --git a/emmett/language/helpers.py b/emmett/language/helpers.py index f18cd73e..7db1ab3c 100644 --- a/emmett/language/helpers.py +++ b/emmett/language/helpers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.language.helpers - ----------------------- +emmett.language.helpers +----------------------- - Translation helpers. +Translation helpers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import re @@ -19,18 +19,13 @@ class Tstr(_Tstr): __slots__ = [] def __getstate__(self): - return { - 'text': self.text, - 'lang': self.lang, - 'args': self.args, - 'kwargs': self.kwargs - } + return {"text": self.text, "lang": self.lang, "args": self.args, "kwargs": self.kwargs} def __setstate__(self, state): - self.text = state['text'] - self.lang = state['lang'] - self.args = state['args'] - self.kwargs = state['kwargs'] + self.text = state["text"] + self.lang = state["lang"] + self.args = state["args"] + self.kwargs = state["kwargs"] def __getattr__(self, name): return getattr(str(self), name) @@ -40,9 +35,10 @@ def __json__(self): class LanguageAccept(Accept): - regex_locale_delim = re.compile(r'[_-]') + regex_locale_delim = re.compile(r"[_-]") def _value_matches(self, value, item): def _normalize(language): return self.regex_locale_delim.split(language.lower())[0] - return item == '*' or _normalize(value) == _normalize(item) + + return item == "*" or _normalize(value) == _normalize(item) diff --git a/emmett/language/translator.py b/emmett/language/translator.py index 4ef6f5be..4fd602f2 100644 --- a/emmett/language/translator.py +++ b/emmett/language/translator.py @@ -1,17 +1,18 @@ # -*- coding: utf-8 -*- """ - emmett.language.translator - -------------------------- +emmett.language.translator +-------------------------- - Severus translator implementation for Emmett. +Severus translator implementation for Emmett. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations from typing import Optional + from severus.ctx import set_context from severus.translator import Translator as _Translator @@ -32,6 +33,4 @@ def _update_config(self, default_language: str): self._build_languages() def _get_best_language(self, lang: Optional[str] = None) -> str: - return self._langmap.get( - lang or current.language, self._default_language - ) + return self._langmap.get(lang or current.language, self._default_language) diff --git a/emmett/libs/contenttype.py b/emmett/libs/contenttype.py index 5f9a3b65..d30a939e 100644 --- a/emmett/libs/contenttype.py +++ b/emmett/libs/contenttype.py @@ -1,717 +1,717 @@ # -*- coding: utf-8 -*- """ - Content-type helper based on file extension. +Content-type helper based on file extension. - Original code from web2py - by Massimo Di Pierro +Original code from web2py +by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ -__all__ = ['contenttype'] +__all__ = ["contenttype"] CONTENT_TYPE = { - '.load': 'text/html', - '.123': 'application/vnd.lotus-1-2-3', - '.3ds': 'image/x-3ds', - '.3g2': 'video/3gpp', - '.3ga': 'video/3gpp', - '.3gp': 'video/3gpp', - '.3gpp': 'video/3gpp', - '.602': 'application/x-t602', - '.669': 'audio/x-mod', - '.7z': 'application/x-7z-compressed', - '.a': 'application/x-archive', - '.aac': 'audio/mp4', - '.abw': 'application/x-abiword', - '.abw.crashed': 'application/x-abiword', - '.abw.gz': 'application/x-abiword', - '.ac3': 'audio/ac3', - '.ace': 'application/x-ace', - '.adb': 'text/x-adasrc', - '.ads': 'text/x-adasrc', - '.afm': 'application/x-font-afm', - '.ag': 'image/x-applix-graphics', - '.ai': 'application/illustrator', - '.aif': 'audio/x-aiff', - '.aifc': 'audio/x-aiff', - '.aiff': 'audio/x-aiff', - '.al': 'application/x-perl', - '.alz': 'application/x-alz', - '.amr': 'audio/amr', - '.ani': 'application/x-navi-animation', - '.anim[1-9j]': 'video/x-anim', - '.anx': 'application/annodex', - '.ape': 'audio/x-ape', - '.arj': 'application/x-arj', - '.arw': 'image/x-sony-arw', - '.as': 'application/x-applix-spreadsheet', - '.asc': 'text/plain', - '.asf': 'video/x-ms-asf', - '.asp': 'application/x-asp', - '.ass': 'text/x-ssa', - '.asx': 'audio/x-ms-asx', - '.atom': 'application/atom+xml', - '.au': 'audio/basic', - '.avi': 'video/x-msvideo', - '.aw': 'application/x-applix-word', - '.awb': 'audio/amr-wb', - '.awk': 'application/x-awk', - '.axa': 'audio/annodex', - '.axv': 'video/annodex', - '.bak': 'application/x-trash', - '.bcpio': 'application/x-bcpio', - '.bdf': 'application/x-font-bdf', - '.bib': 'text/x-bibtex', - '.bin': 'application/octet-stream', - '.blend': 'application/x-blender', - '.blender': 'application/x-blender', - '.bmp': 'image/bmp', - '.bz': 'application/x-bzip', - '.bz2': 'application/x-bzip', - '.c': 'text/x-csrc', - '.c++': 'text/x-c++src', - '.cab': 'application/vnd.ms-cab-compressed', - '.cb7': 'application/x-cb7', - '.cbr': 'application/x-cbr', - '.cbt': 'application/x-cbt', - '.cbz': 'application/x-cbz', - '.cc': 'text/x-c++src', - '.cdf': 'application/x-netcdf', - '.cdr': 'application/vnd.corel-draw', - '.cer': 'application/x-x509-ca-cert', - '.cert': 'application/x-x509-ca-cert', - '.cgm': 'image/cgm', - '.chm': 'application/x-chm', - '.chrt': 'application/x-kchart', - '.class': 'application/x-java', - '.cls': 'text/x-tex', - '.cmake': 'text/x-cmake', - '.cpio': 'application/x-cpio', - '.cpio.gz': 'application/x-cpio-compressed', - '.cpp': 'text/x-c++src', - '.cr2': 'image/x-canon-cr2', - '.crt': 'application/x-x509-ca-cert', - '.crw': 'image/x-canon-crw', - '.cs': 'text/x-csharp', - '.csh': 'application/x-csh', - '.css': 'text/css', - '.cssl': 'text/css', - '.csv': 'text/csv', - '.cue': 'application/x-cue', - '.cur': 'image/x-win-bitmap', - '.cxx': 'text/x-c++src', - '.d': 'text/x-dsrc', - '.dar': 'application/x-dar', - '.dbf': 'application/x-dbf', - '.dc': 'application/x-dc-rom', - '.dcl': 'text/x-dcl', - '.dcm': 'application/dicom', - '.dcr': 'image/x-kodak-dcr', - '.dds': 'image/x-dds', - '.deb': 'application/x-deb', - '.der': 'application/x-x509-ca-cert', - '.desktop': 'application/x-desktop', - '.dia': 'application/x-dia-diagram', - '.diff': 'text/x-patch', - '.divx': 'video/x-msvideo', - '.djv': 'image/vnd.djvu', - '.djvu': 'image/vnd.djvu', - '.dng': 'image/x-adobe-dng', - '.doc': 'application/msword', - '.docbook': 'application/docbook+xml', - '.docm': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - '.dot': 'text/vnd.graphviz', - '.dsl': 'text/x-dsl', - '.dtd': 'application/xml-dtd', - '.dtx': 'text/x-tex', - '.dv': 'video/dv', - '.dvi': 'application/x-dvi', - '.dvi.bz2': 'application/x-bzdvi', - '.dvi.gz': 'application/x-gzdvi', - '.dwg': 'image/vnd.dwg', - '.dxf': 'image/vnd.dxf', - '.e': 'text/x-eiffel', - '.egon': 'application/x-egon', - '.eif': 'text/x-eiffel', - '.el': 'text/x-emacs-lisp', - '.emf': 'image/x-emf', - '.emp': 'application/vnd.emusic-emusic_package', - '.ent': 'application/xml-external-parsed-entity', - '.eps': 'image/x-eps', - '.eps.bz2': 'image/x-bzeps', - '.eps.gz': 'image/x-gzeps', - '.epsf': 'image/x-eps', - '.epsf.bz2': 'image/x-bzeps', - '.epsf.gz': 'image/x-gzeps', - '.epsi': 'image/x-eps', - '.epsi.bz2': 'image/x-bzeps', - '.epsi.gz': 'image/x-gzeps', - '.epub': 'application/epub+zip', - '.erl': 'text/x-erlang', - '.es': 'application/ecmascript', - '.etheme': 'application/x-e-theme', - '.etx': 'text/x-setext', - '.exe': 'application/x-ms-dos-executable', - '.exr': 'image/x-exr', - '.ez': 'application/andrew-inset', - '.f': 'text/x-fortran', - '.f90': 'text/x-fortran', - '.f95': 'text/x-fortran', - '.fb2': 'application/x-fictionbook+xml', - '.fig': 'image/x-xfig', - '.fits': 'image/fits', - '.fl': 'application/x-fluid', - '.flac': 'audio/x-flac', - '.flc': 'video/x-flic', - '.fli': 'video/x-flic', - '.flv': 'video/x-flv', - '.flw': 'application/x-kivio', - '.fo': 'text/x-xslfo', - '.for': 'text/x-fortran', - '.g3': 'image/fax-g3', - '.gb': 'application/x-gameboy-rom', - '.gba': 'application/x-gba-rom', - '.gcrd': 'text/directory', - '.ged': 'application/x-gedcom', - '.gedcom': 'application/x-gedcom', - '.gen': 'application/x-genesis-rom', - '.gf': 'application/x-tex-gf', - '.gg': 'application/x-sms-rom', - '.gif': 'image/gif', - '.glade': 'application/x-glade', - '.gmo': 'application/x-gettext-translation', - '.gnc': 'application/x-gnucash', - '.gnd': 'application/gnunet-directory', - '.gnucash': 'application/x-gnucash', - '.gnumeric': 'application/x-gnumeric', - '.gnuplot': 'application/x-gnuplot', - '.gp': 'application/x-gnuplot', - '.gpg': 'application/pgp-encrypted', - '.gplt': 'application/x-gnuplot', - '.gra': 'application/x-graphite', - '.gsf': 'application/x-font-type1', - '.gsm': 'audio/x-gsm', - '.gtar': 'application/x-tar', - '.gv': 'text/vnd.graphviz', - '.gvp': 'text/x-google-video-pointer', - '.gz': 'application/x-gzip', - '.h': 'text/x-chdr', - '.h++': 'text/x-c++hdr', - '.hdf': 'application/x-hdf', - '.hh': 'text/x-c++hdr', - '.hp': 'text/x-c++hdr', - '.hpgl': 'application/vnd.hp-hpgl', - '.hpp': 'text/x-c++hdr', - '.hs': 'text/x-haskell', - '.htm': 'text/html', - '.html': 'text/html', - '.hwp': 'application/x-hwp', - '.hwt': 'application/x-hwt', - '.hxx': 'text/x-c++hdr', - '.ica': 'application/x-ica', - '.icb': 'image/x-tga', - '.icns': 'image/x-icns', - '.ico': 'image/vnd.microsoft.icon', - '.ics': 'text/calendar', - '.idl': 'text/x-idl', - '.ief': 'image/ief', - '.iff': 'image/x-iff', - '.ilbm': 'image/x-ilbm', - '.ime': 'text/x-imelody', - '.imy': 'text/x-imelody', - '.ins': 'text/x-tex', - '.iptables': 'text/x-iptables', - '.iso': 'application/x-cd-image', - '.iso9660': 'application/x-cd-image', - '.it': 'audio/x-it', - '.j2k': 'image/jp2', - '.jad': 'text/vnd.sun.j2me.app-descriptor', - '.jar': 'application/x-java-archive', - '.java': 'text/x-java', - '.jng': 'image/x-jng', - '.jnlp': 'application/x-java-jnlp-file', - '.jp2': 'image/jp2', - '.jpc': 'image/jp2', - '.jpe': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.jpf': 'image/jp2', - '.jpg': 'image/jpeg', - '.jpr': 'application/x-jbuilder-project', - '.jpx': 'image/jp2', - '.js': 'application/javascript', - '.json': 'application/json', - '.jsonp': 'application/jsonp', - '.k25': 'image/x-kodak-k25', - '.kar': 'audio/midi', - '.karbon': 'application/x-karbon', - '.kdc': 'image/x-kodak-kdc', - '.kdelnk': 'application/x-desktop', - '.kexi': 'application/x-kexiproject-sqlite3', - '.kexic': 'application/x-kexi-connectiondata', - '.kexis': 'application/x-kexiproject-shortcut', - '.kfo': 'application/x-kformula', - '.kil': 'application/x-killustrator', - '.kino': 'application/smil', - '.kml': 'application/vnd.google-earth.kml+xml', - '.kmz': 'application/vnd.google-earth.kmz', - '.kon': 'application/x-kontour', - '.kpm': 'application/x-kpovmodeler', - '.kpr': 'application/x-kpresenter', - '.kpt': 'application/x-kpresenter', - '.kra': 'application/x-krita', - '.ksp': 'application/x-kspread', - '.kud': 'application/x-kugar', - '.kwd': 'application/x-kword', - '.kwt': 'application/x-kword', - '.la': 'application/x-shared-library-la', - '.latex': 'text/x-tex', - '.ldif': 'text/x-ldif', - '.lha': 'application/x-lha', - '.lhs': 'text/x-literate-haskell', - '.lhz': 'application/x-lhz', - '.log': 'text/x-log', - '.ltx': 'text/x-tex', - '.lua': 'text/x-lua', - '.lwo': 'image/x-lwo', - '.lwob': 'image/x-lwo', - '.lws': 'image/x-lws', - '.ly': 'text/x-lilypond', - '.lyx': 'application/x-lyx', - '.lz': 'application/x-lzip', - '.lzh': 'application/x-lha', - '.lzma': 'application/x-lzma', - '.lzo': 'application/x-lzop', - '.m': 'text/x-matlab', - '.m15': 'audio/x-mod', - '.m2t': 'video/mpeg', - '.m3u': 'audio/x-mpegurl', - '.m3u8': 'audio/x-mpegurl', - '.m4': 'application/x-m4', - '.m4a': 'audio/mp4', - '.m4b': 'audio/x-m4b', - '.m4v': 'video/mp4', - '.mab': 'application/x-markaby', - '.man': 'application/x-troff-man', - '.mbox': 'application/mbox', - '.md': 'application/x-genesis-rom', - '.mdb': 'application/vnd.ms-access', - '.mdi': 'image/vnd.ms-modi', - '.me': 'text/x-troff-me', - '.med': 'audio/x-mod', - '.metalink': 'application/metalink+xml', - '.mgp': 'application/x-magicpoint', - '.mid': 'audio/midi', - '.midi': 'audio/midi', - '.mif': 'application/x-mif', - '.minipsf': 'audio/x-minipsf', - '.mka': 'audio/x-matroska', - '.mkv': 'video/x-matroska', - '.ml': 'text/x-ocaml', - '.mli': 'text/x-ocaml', - '.mm': 'text/x-troff-mm', - '.mmf': 'application/x-smaf', - '.mml': 'text/mathml', - '.mng': 'video/x-mng', - '.mo': 'application/x-gettext-translation', - '.mo3': 'audio/x-mo3', - '.moc': 'text/x-moc', - '.mod': 'audio/x-mod', - '.mof': 'text/x-mof', - '.moov': 'video/quicktime', - '.mov': 'video/quicktime', - '.movie': 'video/x-sgi-movie', - '.mp+': 'audio/x-musepack', - '.mp2': 'video/mpeg', - '.mp3': 'audio/mpeg', - '.mp4': 'video/mp4', - '.mpc': 'audio/x-musepack', - '.mpe': 'video/mpeg', - '.mpeg': 'video/mpeg', - '.mpg': 'video/mpeg', - '.mpga': 'audio/mpeg', - '.mpp': 'audio/x-musepack', - '.mrl': 'text/x-mrml', - '.mrml': 'text/x-mrml', - '.mrw': 'image/x-minolta-mrw', - '.ms': 'text/x-troff-ms', - '.msi': 'application/x-msi', - '.msod': 'image/x-msod', - '.msx': 'application/x-msx-rom', - '.mtm': 'audio/x-mod', - '.mup': 'text/x-mup', - '.mxf': 'application/mxf', - '.n64': 'application/x-n64-rom', - '.nb': 'application/mathematica', - '.nc': 'application/x-netcdf', - '.nds': 'application/x-nintendo-ds-rom', - '.nef': 'image/x-nikon-nef', - '.nes': 'application/x-nes-rom', - '.nfo': 'text/x-nfo', - '.not': 'text/x-mup', - '.nsc': 'application/x-netshow-channel', - '.nsv': 'video/x-nsv', - '.o': 'application/x-object', - '.obj': 'application/x-tgif', - '.ocl': 'text/x-ocl', - '.oda': 'application/oda', - '.odb': 'application/vnd.oasis.opendocument.database', - '.odc': 'application/vnd.oasis.opendocument.chart', - '.odf': 'application/vnd.oasis.opendocument.formula', - '.odg': 'application/vnd.oasis.opendocument.graphics', - '.odi': 'application/vnd.oasis.opendocument.image', - '.odm': 'application/vnd.oasis.opendocument.text-master', - '.odp': 'application/vnd.oasis.opendocument.presentation', - '.ods': 'application/vnd.oasis.opendocument.spreadsheet', - '.odt': 'application/vnd.oasis.opendocument.text', - '.oga': 'audio/ogg', - '.ogg': 'video/x-theora+ogg', - '.ogm': 'video/x-ogm+ogg', - '.ogv': 'video/ogg', - '.ogx': 'application/ogg', - '.old': 'application/x-trash', - '.oleo': 'application/x-oleo', - '.opml': 'text/x-opml+xml', - '.ora': 'image/openraster', - '.orf': 'image/x-olympus-orf', - '.otc': 'application/vnd.oasis.opendocument.chart-template', - '.otf': 'application/x-font-otf', - '.otg': 'application/vnd.oasis.opendocument.graphics-template', - '.oth': 'application/vnd.oasis.opendocument.text-web', - '.otp': 'application/vnd.oasis.opendocument.presentation-template', - '.ots': 'application/vnd.oasis.opendocument.spreadsheet-template', - '.ott': 'application/vnd.oasis.opendocument.text-template', - '.owl': 'application/rdf+xml', - '.oxt': 'application/vnd.openofficeorg.extension', - '.p': 'text/x-pascal', - '.p10': 'application/pkcs10', - '.p12': 'application/x-pkcs12', - '.p7b': 'application/x-pkcs7-certificates', - '.p7s': 'application/pkcs7-signature', - '.pack': 'application/x-java-pack200', - '.pak': 'application/x-pak', - '.par2': 'application/x-par2', - '.pas': 'text/x-pascal', - '.patch': 'text/x-patch', - '.pbm': 'image/x-portable-bitmap', - '.pcd': 'image/x-photo-cd', - '.pcf': 'application/x-cisco-vpn-settings', - '.pcf.gz': 'application/x-font-pcf', - '.pcf.z': 'application/x-font-pcf', - '.pcl': 'application/vnd.hp-pcl', - '.pcx': 'image/x-pcx', - '.pdb': 'chemical/x-pdb', - '.pdc': 'application/x-aportisdoc', - '.pdf': 'application/pdf', - '.pdf.bz2': 'application/x-bzpdf', - '.pdf.gz': 'application/x-gzpdf', - '.pef': 'image/x-pentax-pef', - '.pem': 'application/x-x509-ca-cert', - '.perl': 'application/x-perl', - '.pfa': 'application/x-font-type1', - '.pfb': 'application/x-font-type1', - '.pfx': 'application/x-pkcs12', - '.pgm': 'image/x-portable-graymap', - '.pgn': 'application/x-chess-pgn', - '.pgp': 'application/pgp-encrypted', - '.php': 'application/x-php', - '.php3': 'application/x-php', - '.php4': 'application/x-php', - '.pict': 'image/x-pict', - '.pict1': 'image/x-pict', - '.pict2': 'image/x-pict', - '.pickle': 'application/python-pickle', - '.pk': 'application/x-tex-pk', - '.pkipath': 'application/pkix-pkipath', - '.pkr': 'application/pgp-keys', - '.pl': 'application/x-perl', - '.pla': 'audio/x-iriver-pla', - '.pln': 'application/x-planperfect', - '.pls': 'audio/x-scpls', - '.pm': 'application/x-perl', - '.png': 'image/png', - '.pnm': 'image/x-portable-anymap', - '.pntg': 'image/x-macpaint', - '.po': 'text/x-gettext-translation', - '.por': 'application/x-spss-por', - '.pot': 'text/x-gettext-translation-template', - '.ppm': 'image/x-portable-pixmap', - '.pps': 'application/vnd.ms-powerpoint', - '.ppt': 'application/vnd.ms-powerpoint', - '.pptm': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - '.ppz': 'application/vnd.ms-powerpoint', - '.prc': 'application/x-palm-database', - '.ps': 'application/postscript', - '.ps.bz2': 'application/x-bzpostscript', - '.ps.gz': 'application/x-gzpostscript', - '.psd': 'image/vnd.adobe.photoshop', - '.psf': 'audio/x-psf', - '.psf.gz': 'application/x-gz-font-linux-psf', - '.psflib': 'audio/x-psflib', - '.psid': 'audio/prs.sid', - '.psw': 'application/x-pocket-word', - '.pw': 'application/x-pw', - '.py': 'text/x-python', - '.pyc': 'application/x-python-bytecode', - '.pyo': 'application/x-python-bytecode', - '.qif': 'image/x-quicktime', - '.qt': 'video/quicktime', - '.qtif': 'image/x-quicktime', - '.qtl': 'application/x-quicktime-media-link', - '.qtvr': 'video/quicktime', - '.ra': 'audio/vnd.rn-realaudio', - '.raf': 'image/x-fuji-raf', - '.ram': 'application/ram', - '.rar': 'application/x-rar', - '.ras': 'image/x-cmu-raster', - '.raw': 'image/x-panasonic-raw', - '.rax': 'audio/vnd.rn-realaudio', - '.rb': 'application/x-ruby', - '.rdf': 'application/rdf+xml', - '.rdfs': 'application/rdf+xml', - '.reg': 'text/x-ms-regedit', - '.rej': 'application/x-reject', - '.rgb': 'image/x-rgb', - '.rle': 'image/rle', - '.rm': 'application/vnd.rn-realmedia', - '.rmj': 'application/vnd.rn-realmedia', - '.rmm': 'application/vnd.rn-realmedia', - '.rms': 'application/vnd.rn-realmedia', - '.rmvb': 'application/vnd.rn-realmedia', - '.rmx': 'application/vnd.rn-realmedia', - '.roff': 'text/troff', - '.rp': 'image/vnd.rn-realpix', - '.rpm': 'application/x-rpm', - '.rss': 'application/rss+xml', - '.rt': 'text/vnd.rn-realtext', - '.rtf': 'application/rtf', - '.rtx': 'text/richtext', - '.rv': 'video/vnd.rn-realvideo', - '.rvx': 'video/vnd.rn-realvideo', - '.s3m': 'audio/x-s3m', - '.sam': 'application/x-amipro', - '.sami': 'application/x-sami', - '.sav': 'application/x-spss-sav', - '.scm': 'text/x-scheme', - '.sda': 'application/vnd.stardivision.draw', - '.sdc': 'application/vnd.stardivision.calc', - '.sdd': 'application/vnd.stardivision.impress', - '.sdp': 'application/sdp', - '.sds': 'application/vnd.stardivision.chart', - '.sdw': 'application/vnd.stardivision.writer', - '.sgf': 'application/x-go-sgf', - '.sgi': 'image/x-sgi', - '.sgl': 'application/vnd.stardivision.writer', - '.sgm': 'text/sgml', - '.sgml': 'text/sgml', - '.sh': 'application/x-shellscript', - '.shar': 'application/x-shar', - '.shn': 'application/x-shorten', - '.siag': 'application/x-siag', - '.sid': 'audio/prs.sid', - '.sik': 'application/x-trash', - '.sis': 'application/vnd.symbian.install', - '.sisx': 'x-epoc/x-sisx-app', - '.sit': 'application/x-stuffit', - '.siv': 'application/sieve', - '.sk': 'image/x-skencil', - '.sk1': 'image/x-skencil', - '.skr': 'application/pgp-keys', - '.slk': 'text/spreadsheet', - '.smaf': 'application/x-smaf', - '.smc': 'application/x-snes-rom', - '.smd': 'application/vnd.stardivision.mail', - '.smf': 'application/vnd.stardivision.math', - '.smi': 'application/x-sami', - '.smil': 'application/smil', - '.sml': 'application/smil', - '.sms': 'application/x-sms-rom', - '.snd': 'audio/basic', - '.so': 'application/x-sharedlib', - '.spc': 'application/x-pkcs7-certificates', - '.spd': 'application/x-font-speedo', - '.spec': 'text/x-rpm-spec', - '.spl': 'application/x-shockwave-flash', - '.spx': 'audio/x-speex', - '.sql': 'text/x-sql', - '.sr2': 'image/x-sony-sr2', - '.src': 'application/x-wais-source', - '.srf': 'image/x-sony-srf', - '.srt': 'application/x-subrip', - '.ssa': 'text/x-ssa', - '.stc': 'application/vnd.sun.xml.calc.template', - '.std': 'application/vnd.sun.xml.draw.template', - '.sti': 'application/vnd.sun.xml.impress.template', - '.stm': 'audio/x-stm', - '.stw': 'application/vnd.sun.xml.writer.template', - '.sty': 'text/x-tex', - '.sub': 'text/x-subviewer', - '.sun': 'image/x-sun-raster', - '.sv4cpio': 'application/x-sv4cpio', - '.sv4crc': 'application/x-sv4crc', - '.svg': 'image/svg+xml', - '.svgz': 'image/svg+xml-compressed', - '.swf': 'application/x-shockwave-flash', - '.sxc': 'application/vnd.sun.xml.calc', - '.sxd': 'application/vnd.sun.xml.draw', - '.sxg': 'application/vnd.sun.xml.writer.global', - '.sxi': 'application/vnd.sun.xml.impress', - '.sxm': 'application/vnd.sun.xml.math', - '.sxw': 'application/vnd.sun.xml.writer', - '.sylk': 'text/spreadsheet', - '.t': 'text/troff', - '.t2t': 'text/x-txt2tags', - '.tar': 'application/x-tar', - '.tar.bz': 'application/x-bzip-compressed-tar', - '.tar.bz2': 'application/x-bzip-compressed-tar', - '.tar.gz': 'application/x-compressed-tar', - '.tar.lzma': 'application/x-lzma-compressed-tar', - '.tar.lzo': 'application/x-tzo', - '.tar.xz': 'application/x-xz-compressed-tar', - '.tar.z': 'application/x-tarz', - '.tbz': 'application/x-bzip-compressed-tar', - '.tbz2': 'application/x-bzip-compressed-tar', - '.tcl': 'text/x-tcl', - '.tex': 'text/x-tex', - '.texi': 'text/x-texinfo', - '.texinfo': 'text/x-texinfo', - '.tga': 'image/x-tga', - '.tgz': 'application/x-compressed-tar', - '.theme': 'application/x-theme', - '.themepack': 'application/x-windows-themepack', - '.tif': 'image/tiff', - '.tiff': 'image/tiff', - '.tk': 'text/x-tcl', - '.tlz': 'application/x-lzma-compressed-tar', - '.tnef': 'application/vnd.ms-tnef', - '.tnf': 'application/vnd.ms-tnef', - '.toc': 'application/x-cdrdao-toc', - '.torrent': 'application/x-bittorrent', - '.tpic': 'image/x-tga', - '.tr': 'text/troff', - '.ts': 'application/x-linguist', - '.tsv': 'text/tab-separated-values', - '.tta': 'audio/x-tta', - '.ttc': 'application/x-font-ttf', - '.ttf': 'application/x-font-ttf', - '.ttx': 'application/x-font-ttx', - '.txt': 'text/plain', - '.txz': 'application/x-xz-compressed-tar', - '.tzo': 'application/x-tzo', - '.ufraw': 'application/x-ufraw', - '.ui': 'application/x-designer', - '.uil': 'text/x-uil', - '.ult': 'audio/x-mod', - '.uni': 'audio/x-mod', - '.uri': 'text/x-uri', - '.url': 'text/x-uri', - '.ustar': 'application/x-ustar', - '.vala': 'text/x-vala', - '.vapi': 'text/x-vala', - '.vcf': 'text/directory', - '.vcs': 'text/calendar', - '.vct': 'text/directory', - '.vda': 'image/x-tga', - '.vhd': 'text/x-vhdl', - '.vhdl': 'text/x-vhdl', - '.viv': 'video/vivo', - '.vivo': 'video/vivo', - '.vlc': 'audio/x-mpegurl', - '.vob': 'video/mpeg', - '.voc': 'audio/x-voc', - '.vor': 'application/vnd.stardivision.writer', - '.vst': 'image/x-tga', - '.wav': 'audio/x-wav', - '.wax': 'audio/x-ms-asx', - '.wb1': 'application/x-quattropro', - '.wb2': 'application/x-quattropro', - '.wb3': 'application/x-quattropro', - '.wbmp': 'image/vnd.wap.wbmp', - '.wcm': 'application/vnd.ms-works', - '.wdb': 'application/vnd.ms-works', - '.webm': 'video/webm', - '.wk1': 'application/vnd.lotus-1-2-3', - '.wk3': 'application/vnd.lotus-1-2-3', - '.wk4': 'application/vnd.lotus-1-2-3', - '.wks': 'application/vnd.ms-works', - '.wma': 'audio/x-ms-wma', - '.wmf': 'image/x-wmf', - '.wml': 'text/vnd.wap.wml', - '.wmls': 'text/vnd.wap.wmlscript', - '.wmv': 'video/x-ms-wmv', - '.wmx': 'audio/x-ms-asx', - '.wp': 'application/vnd.wordperfect', - '.wp4': 'application/vnd.wordperfect', - '.wp5': 'application/vnd.wordperfect', - '.wp6': 'application/vnd.wordperfect', - '.wpd': 'application/vnd.wordperfect', - '.wpg': 'application/x-wpg', - '.wpl': 'application/vnd.ms-wpl', - '.wpp': 'application/vnd.wordperfect', - '.wps': 'application/vnd.ms-works', - '.wri': 'application/x-mswrite', - '.wrl': 'model/vrml', - '.wv': 'audio/x-wavpack', - '.wvc': 'audio/x-wavpack-correction', - '.wvp': 'audio/x-wavpack', - '.wvx': 'audio/x-ms-asx', - '.x3f': 'image/x-sigma-x3f', - '.xac': 'application/x-gnucash', - '.xbel': 'application/x-xbel', - '.xbl': 'application/xml', - '.xbm': 'image/x-xbitmap', - '.xcf': 'image/x-xcf', - '.xcf.bz2': 'image/x-compressed-xcf', - '.xcf.gz': 'image/x-compressed-xcf', - '.xhtml': 'application/xhtml+xml', - '.xi': 'audio/x-xi', - '.xla': 'application/vnd.ms-excel', - '.xlc': 'application/vnd.ms-excel', - '.xld': 'application/vnd.ms-excel', - '.xlf': 'application/x-xliff', - '.xliff': 'application/x-xliff', - '.xll': 'application/vnd.ms-excel', - '.xlm': 'application/vnd.ms-excel', - '.xls': 'application/vnd.ms-excel', - '.xlsm': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - '.xlt': 'application/vnd.ms-excel', - '.xlw': 'application/vnd.ms-excel', - '.xm': 'audio/x-xm', - '.xmf': 'audio/x-xmf', - '.xmi': 'text/x-xmi', - '.xml': 'application/xml', - '.xpm': 'image/x-xpixmap', - '.xps': 'application/vnd.ms-xpsdocument', - '.xsl': 'application/xml', - '.xslfo': 'text/x-xslfo', - '.xslt': 'application/xml', - '.xspf': 'application/xspf+xml', - '.xul': 'application/vnd.mozilla.xul+xml', - '.xwd': 'image/x-xwindowdump', - '.xyz': 'chemical/x-pdb', - '.xz': 'application/x-xz', - '.w2p': 'application/w2p', - '.z': 'application/x-compress', - '.zabw': 'application/x-abiword', - '.zip': 'application/zip', - '.zoo': 'application/x-zoo', + ".load": "text/html", + ".123": "application/vnd.lotus-1-2-3", + ".3ds": "image/x-3ds", + ".3g2": "video/3gpp", + ".3ga": "video/3gpp", + ".3gp": "video/3gpp", + ".3gpp": "video/3gpp", + ".602": "application/x-t602", + ".669": "audio/x-mod", + ".7z": "application/x-7z-compressed", + ".a": "application/x-archive", + ".aac": "audio/mp4", + ".abw": "application/x-abiword", + ".abw.crashed": "application/x-abiword", + ".abw.gz": "application/x-abiword", + ".ac3": "audio/ac3", + ".ace": "application/x-ace", + ".adb": "text/x-adasrc", + ".ads": "text/x-adasrc", + ".afm": "application/x-font-afm", + ".ag": "image/x-applix-graphics", + ".ai": "application/illustrator", + ".aif": "audio/x-aiff", + ".aifc": "audio/x-aiff", + ".aiff": "audio/x-aiff", + ".al": "application/x-perl", + ".alz": "application/x-alz", + ".amr": "audio/amr", + ".ani": "application/x-navi-animation", + ".anim[1-9j]": "video/x-anim", + ".anx": "application/annodex", + ".ape": "audio/x-ape", + ".arj": "application/x-arj", + ".arw": "image/x-sony-arw", + ".as": "application/x-applix-spreadsheet", + ".asc": "text/plain", + ".asf": "video/x-ms-asf", + ".asp": "application/x-asp", + ".ass": "text/x-ssa", + ".asx": "audio/x-ms-asx", + ".atom": "application/atom+xml", + ".au": "audio/basic", + ".avi": "video/x-msvideo", + ".aw": "application/x-applix-word", + ".awb": "audio/amr-wb", + ".awk": "application/x-awk", + ".axa": "audio/annodex", + ".axv": "video/annodex", + ".bak": "application/x-trash", + ".bcpio": "application/x-bcpio", + ".bdf": "application/x-font-bdf", + ".bib": "text/x-bibtex", + ".bin": "application/octet-stream", + ".blend": "application/x-blender", + ".blender": "application/x-blender", + ".bmp": "image/bmp", + ".bz": "application/x-bzip", + ".bz2": "application/x-bzip", + ".c": "text/x-csrc", + ".c++": "text/x-c++src", + ".cab": "application/vnd.ms-cab-compressed", + ".cb7": "application/x-cb7", + ".cbr": "application/x-cbr", + ".cbt": "application/x-cbt", + ".cbz": "application/x-cbz", + ".cc": "text/x-c++src", + ".cdf": "application/x-netcdf", + ".cdr": "application/vnd.corel-draw", + ".cer": "application/x-x509-ca-cert", + ".cert": "application/x-x509-ca-cert", + ".cgm": "image/cgm", + ".chm": "application/x-chm", + ".chrt": "application/x-kchart", + ".class": "application/x-java", + ".cls": "text/x-tex", + ".cmake": "text/x-cmake", + ".cpio": "application/x-cpio", + ".cpio.gz": "application/x-cpio-compressed", + ".cpp": "text/x-c++src", + ".cr2": "image/x-canon-cr2", + ".crt": "application/x-x509-ca-cert", + ".crw": "image/x-canon-crw", + ".cs": "text/x-csharp", + ".csh": "application/x-csh", + ".css": "text/css", + ".cssl": "text/css", + ".csv": "text/csv", + ".cue": "application/x-cue", + ".cur": "image/x-win-bitmap", + ".cxx": "text/x-c++src", + ".d": "text/x-dsrc", + ".dar": "application/x-dar", + ".dbf": "application/x-dbf", + ".dc": "application/x-dc-rom", + ".dcl": "text/x-dcl", + ".dcm": "application/dicom", + ".dcr": "image/x-kodak-dcr", + ".dds": "image/x-dds", + ".deb": "application/x-deb", + ".der": "application/x-x509-ca-cert", + ".desktop": "application/x-desktop", + ".dia": "application/x-dia-diagram", + ".diff": "text/x-patch", + ".divx": "video/x-msvideo", + ".djv": "image/vnd.djvu", + ".djvu": "image/vnd.djvu", + ".dng": "image/x-adobe-dng", + ".doc": "application/msword", + ".docbook": "application/docbook+xml", + ".docm": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".dot": "text/vnd.graphviz", + ".dsl": "text/x-dsl", + ".dtd": "application/xml-dtd", + ".dtx": "text/x-tex", + ".dv": "video/dv", + ".dvi": "application/x-dvi", + ".dvi.bz2": "application/x-bzdvi", + ".dvi.gz": "application/x-gzdvi", + ".dwg": "image/vnd.dwg", + ".dxf": "image/vnd.dxf", + ".e": "text/x-eiffel", + ".egon": "application/x-egon", + ".eif": "text/x-eiffel", + ".el": "text/x-emacs-lisp", + ".emf": "image/x-emf", + ".emp": "application/vnd.emusic-emusic_package", + ".ent": "application/xml-external-parsed-entity", + ".eps": "image/x-eps", + ".eps.bz2": "image/x-bzeps", + ".eps.gz": "image/x-gzeps", + ".epsf": "image/x-eps", + ".epsf.bz2": "image/x-bzeps", + ".epsf.gz": "image/x-gzeps", + ".epsi": "image/x-eps", + ".epsi.bz2": "image/x-bzeps", + ".epsi.gz": "image/x-gzeps", + ".epub": "application/epub+zip", + ".erl": "text/x-erlang", + ".es": "application/ecmascript", + ".etheme": "application/x-e-theme", + ".etx": "text/x-setext", + ".exe": "application/x-ms-dos-executable", + ".exr": "image/x-exr", + ".ez": "application/andrew-inset", + ".f": "text/x-fortran", + ".f90": "text/x-fortran", + ".f95": "text/x-fortran", + ".fb2": "application/x-fictionbook+xml", + ".fig": "image/x-xfig", + ".fits": "image/fits", + ".fl": "application/x-fluid", + ".flac": "audio/x-flac", + ".flc": "video/x-flic", + ".fli": "video/x-flic", + ".flv": "video/x-flv", + ".flw": "application/x-kivio", + ".fo": "text/x-xslfo", + ".for": "text/x-fortran", + ".g3": "image/fax-g3", + ".gb": "application/x-gameboy-rom", + ".gba": "application/x-gba-rom", + ".gcrd": "text/directory", + ".ged": "application/x-gedcom", + ".gedcom": "application/x-gedcom", + ".gen": "application/x-genesis-rom", + ".gf": "application/x-tex-gf", + ".gg": "application/x-sms-rom", + ".gif": "image/gif", + ".glade": "application/x-glade", + ".gmo": "application/x-gettext-translation", + ".gnc": "application/x-gnucash", + ".gnd": "application/gnunet-directory", + ".gnucash": "application/x-gnucash", + ".gnumeric": "application/x-gnumeric", + ".gnuplot": "application/x-gnuplot", + ".gp": "application/x-gnuplot", + ".gpg": "application/pgp-encrypted", + ".gplt": "application/x-gnuplot", + ".gra": "application/x-graphite", + ".gsf": "application/x-font-type1", + ".gsm": "audio/x-gsm", + ".gtar": "application/x-tar", + ".gv": "text/vnd.graphviz", + ".gvp": "text/x-google-video-pointer", + ".gz": "application/x-gzip", + ".h": "text/x-chdr", + ".h++": "text/x-c++hdr", + ".hdf": "application/x-hdf", + ".hh": "text/x-c++hdr", + ".hp": "text/x-c++hdr", + ".hpgl": "application/vnd.hp-hpgl", + ".hpp": "text/x-c++hdr", + ".hs": "text/x-haskell", + ".htm": "text/html", + ".html": "text/html", + ".hwp": "application/x-hwp", + ".hwt": "application/x-hwt", + ".hxx": "text/x-c++hdr", + ".ica": "application/x-ica", + ".icb": "image/x-tga", + ".icns": "image/x-icns", + ".ico": "image/vnd.microsoft.icon", + ".ics": "text/calendar", + ".idl": "text/x-idl", + ".ief": "image/ief", + ".iff": "image/x-iff", + ".ilbm": "image/x-ilbm", + ".ime": "text/x-imelody", + ".imy": "text/x-imelody", + ".ins": "text/x-tex", + ".iptables": "text/x-iptables", + ".iso": "application/x-cd-image", + ".iso9660": "application/x-cd-image", + ".it": "audio/x-it", + ".j2k": "image/jp2", + ".jad": "text/vnd.sun.j2me.app-descriptor", + ".jar": "application/x-java-archive", + ".java": "text/x-java", + ".jng": "image/x-jng", + ".jnlp": "application/x-java-jnlp-file", + ".jp2": "image/jp2", + ".jpc": "image/jp2", + ".jpe": "image/jpeg", + ".jpeg": "image/jpeg", + ".jpf": "image/jp2", + ".jpg": "image/jpeg", + ".jpr": "application/x-jbuilder-project", + ".jpx": "image/jp2", + ".js": "application/javascript", + ".json": "application/json", + ".jsonp": "application/jsonp", + ".k25": "image/x-kodak-k25", + ".kar": "audio/midi", + ".karbon": "application/x-karbon", + ".kdc": "image/x-kodak-kdc", + ".kdelnk": "application/x-desktop", + ".kexi": "application/x-kexiproject-sqlite3", + ".kexic": "application/x-kexi-connectiondata", + ".kexis": "application/x-kexiproject-shortcut", + ".kfo": "application/x-kformula", + ".kil": "application/x-killustrator", + ".kino": "application/smil", + ".kml": "application/vnd.google-earth.kml+xml", + ".kmz": "application/vnd.google-earth.kmz", + ".kon": "application/x-kontour", + ".kpm": "application/x-kpovmodeler", + ".kpr": "application/x-kpresenter", + ".kpt": "application/x-kpresenter", + ".kra": "application/x-krita", + ".ksp": "application/x-kspread", + ".kud": "application/x-kugar", + ".kwd": "application/x-kword", + ".kwt": "application/x-kword", + ".la": "application/x-shared-library-la", + ".latex": "text/x-tex", + ".ldif": "text/x-ldif", + ".lha": "application/x-lha", + ".lhs": "text/x-literate-haskell", + ".lhz": "application/x-lhz", + ".log": "text/x-log", + ".ltx": "text/x-tex", + ".lua": "text/x-lua", + ".lwo": "image/x-lwo", + ".lwob": "image/x-lwo", + ".lws": "image/x-lws", + ".ly": "text/x-lilypond", + ".lyx": "application/x-lyx", + ".lz": "application/x-lzip", + ".lzh": "application/x-lha", + ".lzma": "application/x-lzma", + ".lzo": "application/x-lzop", + ".m": "text/x-matlab", + ".m15": "audio/x-mod", + ".m2t": "video/mpeg", + ".m3u": "audio/x-mpegurl", + ".m3u8": "audio/x-mpegurl", + ".m4": "application/x-m4", + ".m4a": "audio/mp4", + ".m4b": "audio/x-m4b", + ".m4v": "video/mp4", + ".mab": "application/x-markaby", + ".man": "application/x-troff-man", + ".mbox": "application/mbox", + ".md": "application/x-genesis-rom", + ".mdb": "application/vnd.ms-access", + ".mdi": "image/vnd.ms-modi", + ".me": "text/x-troff-me", + ".med": "audio/x-mod", + ".metalink": "application/metalink+xml", + ".mgp": "application/x-magicpoint", + ".mid": "audio/midi", + ".midi": "audio/midi", + ".mif": "application/x-mif", + ".minipsf": "audio/x-minipsf", + ".mka": "audio/x-matroska", + ".mkv": "video/x-matroska", + ".ml": "text/x-ocaml", + ".mli": "text/x-ocaml", + ".mm": "text/x-troff-mm", + ".mmf": "application/x-smaf", + ".mml": "text/mathml", + ".mng": "video/x-mng", + ".mo": "application/x-gettext-translation", + ".mo3": "audio/x-mo3", + ".moc": "text/x-moc", + ".mod": "audio/x-mod", + ".mof": "text/x-mof", + ".moov": "video/quicktime", + ".mov": "video/quicktime", + ".movie": "video/x-sgi-movie", + ".mp+": "audio/x-musepack", + ".mp2": "video/mpeg", + ".mp3": "audio/mpeg", + ".mp4": "video/mp4", + ".mpc": "audio/x-musepack", + ".mpe": "video/mpeg", + ".mpeg": "video/mpeg", + ".mpg": "video/mpeg", + ".mpga": "audio/mpeg", + ".mpp": "audio/x-musepack", + ".mrl": "text/x-mrml", + ".mrml": "text/x-mrml", + ".mrw": "image/x-minolta-mrw", + ".ms": "text/x-troff-ms", + ".msi": "application/x-msi", + ".msod": "image/x-msod", + ".msx": "application/x-msx-rom", + ".mtm": "audio/x-mod", + ".mup": "text/x-mup", + ".mxf": "application/mxf", + ".n64": "application/x-n64-rom", + ".nb": "application/mathematica", + ".nc": "application/x-netcdf", + ".nds": "application/x-nintendo-ds-rom", + ".nef": "image/x-nikon-nef", + ".nes": "application/x-nes-rom", + ".nfo": "text/x-nfo", + ".not": "text/x-mup", + ".nsc": "application/x-netshow-channel", + ".nsv": "video/x-nsv", + ".o": "application/x-object", + ".obj": "application/x-tgif", + ".ocl": "text/x-ocl", + ".oda": "application/oda", + ".odb": "application/vnd.oasis.opendocument.database", + ".odc": "application/vnd.oasis.opendocument.chart", + ".odf": "application/vnd.oasis.opendocument.formula", + ".odg": "application/vnd.oasis.opendocument.graphics", + ".odi": "application/vnd.oasis.opendocument.image", + ".odm": "application/vnd.oasis.opendocument.text-master", + ".odp": "application/vnd.oasis.opendocument.presentation", + ".ods": "application/vnd.oasis.opendocument.spreadsheet", + ".odt": "application/vnd.oasis.opendocument.text", + ".oga": "audio/ogg", + ".ogg": "video/x-theora+ogg", + ".ogm": "video/x-ogm+ogg", + ".ogv": "video/ogg", + ".ogx": "application/ogg", + ".old": "application/x-trash", + ".oleo": "application/x-oleo", + ".opml": "text/x-opml+xml", + ".ora": "image/openraster", + ".orf": "image/x-olympus-orf", + ".otc": "application/vnd.oasis.opendocument.chart-template", + ".otf": "application/x-font-otf", + ".otg": "application/vnd.oasis.opendocument.graphics-template", + ".oth": "application/vnd.oasis.opendocument.text-web", + ".otp": "application/vnd.oasis.opendocument.presentation-template", + ".ots": "application/vnd.oasis.opendocument.spreadsheet-template", + ".ott": "application/vnd.oasis.opendocument.text-template", + ".owl": "application/rdf+xml", + ".oxt": "application/vnd.openofficeorg.extension", + ".p": "text/x-pascal", + ".p10": "application/pkcs10", + ".p12": "application/x-pkcs12", + ".p7b": "application/x-pkcs7-certificates", + ".p7s": "application/pkcs7-signature", + ".pack": "application/x-java-pack200", + ".pak": "application/x-pak", + ".par2": "application/x-par2", + ".pas": "text/x-pascal", + ".patch": "text/x-patch", + ".pbm": "image/x-portable-bitmap", + ".pcd": "image/x-photo-cd", + ".pcf": "application/x-cisco-vpn-settings", + ".pcf.gz": "application/x-font-pcf", + ".pcf.z": "application/x-font-pcf", + ".pcl": "application/vnd.hp-pcl", + ".pcx": "image/x-pcx", + ".pdb": "chemical/x-pdb", + ".pdc": "application/x-aportisdoc", + ".pdf": "application/pdf", + ".pdf.bz2": "application/x-bzpdf", + ".pdf.gz": "application/x-gzpdf", + ".pef": "image/x-pentax-pef", + ".pem": "application/x-x509-ca-cert", + ".perl": "application/x-perl", + ".pfa": "application/x-font-type1", + ".pfb": "application/x-font-type1", + ".pfx": "application/x-pkcs12", + ".pgm": "image/x-portable-graymap", + ".pgn": "application/x-chess-pgn", + ".pgp": "application/pgp-encrypted", + ".php": "application/x-php", + ".php3": "application/x-php", + ".php4": "application/x-php", + ".pict": "image/x-pict", + ".pict1": "image/x-pict", + ".pict2": "image/x-pict", + ".pickle": "application/python-pickle", + ".pk": "application/x-tex-pk", + ".pkipath": "application/pkix-pkipath", + ".pkr": "application/pgp-keys", + ".pl": "application/x-perl", + ".pla": "audio/x-iriver-pla", + ".pln": "application/x-planperfect", + ".pls": "audio/x-scpls", + ".pm": "application/x-perl", + ".png": "image/png", + ".pnm": "image/x-portable-anymap", + ".pntg": "image/x-macpaint", + ".po": "text/x-gettext-translation", + ".por": "application/x-spss-por", + ".pot": "text/x-gettext-translation-template", + ".ppm": "image/x-portable-pixmap", + ".pps": "application/vnd.ms-powerpoint", + ".ppt": "application/vnd.ms-powerpoint", + ".pptm": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".ppz": "application/vnd.ms-powerpoint", + ".prc": "application/x-palm-database", + ".ps": "application/postscript", + ".ps.bz2": "application/x-bzpostscript", + ".ps.gz": "application/x-gzpostscript", + ".psd": "image/vnd.adobe.photoshop", + ".psf": "audio/x-psf", + ".psf.gz": "application/x-gz-font-linux-psf", + ".psflib": "audio/x-psflib", + ".psid": "audio/prs.sid", + ".psw": "application/x-pocket-word", + ".pw": "application/x-pw", + ".py": "text/x-python", + ".pyc": "application/x-python-bytecode", + ".pyo": "application/x-python-bytecode", + ".qif": "image/x-quicktime", + ".qt": "video/quicktime", + ".qtif": "image/x-quicktime", + ".qtl": "application/x-quicktime-media-link", + ".qtvr": "video/quicktime", + ".ra": "audio/vnd.rn-realaudio", + ".raf": "image/x-fuji-raf", + ".ram": "application/ram", + ".rar": "application/x-rar", + ".ras": "image/x-cmu-raster", + ".raw": "image/x-panasonic-raw", + ".rax": "audio/vnd.rn-realaudio", + ".rb": "application/x-ruby", + ".rdf": "application/rdf+xml", + ".rdfs": "application/rdf+xml", + ".reg": "text/x-ms-regedit", + ".rej": "application/x-reject", + ".rgb": "image/x-rgb", + ".rle": "image/rle", + ".rm": "application/vnd.rn-realmedia", + ".rmj": "application/vnd.rn-realmedia", + ".rmm": "application/vnd.rn-realmedia", + ".rms": "application/vnd.rn-realmedia", + ".rmvb": "application/vnd.rn-realmedia", + ".rmx": "application/vnd.rn-realmedia", + ".roff": "text/troff", + ".rp": "image/vnd.rn-realpix", + ".rpm": "application/x-rpm", + ".rss": "application/rss+xml", + ".rt": "text/vnd.rn-realtext", + ".rtf": "application/rtf", + ".rtx": "text/richtext", + ".rv": "video/vnd.rn-realvideo", + ".rvx": "video/vnd.rn-realvideo", + ".s3m": "audio/x-s3m", + ".sam": "application/x-amipro", + ".sami": "application/x-sami", + ".sav": "application/x-spss-sav", + ".scm": "text/x-scheme", + ".sda": "application/vnd.stardivision.draw", + ".sdc": "application/vnd.stardivision.calc", + ".sdd": "application/vnd.stardivision.impress", + ".sdp": "application/sdp", + ".sds": "application/vnd.stardivision.chart", + ".sdw": "application/vnd.stardivision.writer", + ".sgf": "application/x-go-sgf", + ".sgi": "image/x-sgi", + ".sgl": "application/vnd.stardivision.writer", + ".sgm": "text/sgml", + ".sgml": "text/sgml", + ".sh": "application/x-shellscript", + ".shar": "application/x-shar", + ".shn": "application/x-shorten", + ".siag": "application/x-siag", + ".sid": "audio/prs.sid", + ".sik": "application/x-trash", + ".sis": "application/vnd.symbian.install", + ".sisx": "x-epoc/x-sisx-app", + ".sit": "application/x-stuffit", + ".siv": "application/sieve", + ".sk": "image/x-skencil", + ".sk1": "image/x-skencil", + ".skr": "application/pgp-keys", + ".slk": "text/spreadsheet", + ".smaf": "application/x-smaf", + ".smc": "application/x-snes-rom", + ".smd": "application/vnd.stardivision.mail", + ".smf": "application/vnd.stardivision.math", + ".smi": "application/x-sami", + ".smil": "application/smil", + ".sml": "application/smil", + ".sms": "application/x-sms-rom", + ".snd": "audio/basic", + ".so": "application/x-sharedlib", + ".spc": "application/x-pkcs7-certificates", + ".spd": "application/x-font-speedo", + ".spec": "text/x-rpm-spec", + ".spl": "application/x-shockwave-flash", + ".spx": "audio/x-speex", + ".sql": "text/x-sql", + ".sr2": "image/x-sony-sr2", + ".src": "application/x-wais-source", + ".srf": "image/x-sony-srf", + ".srt": "application/x-subrip", + ".ssa": "text/x-ssa", + ".stc": "application/vnd.sun.xml.calc.template", + ".std": "application/vnd.sun.xml.draw.template", + ".sti": "application/vnd.sun.xml.impress.template", + ".stm": "audio/x-stm", + ".stw": "application/vnd.sun.xml.writer.template", + ".sty": "text/x-tex", + ".sub": "text/x-subviewer", + ".sun": "image/x-sun-raster", + ".sv4cpio": "application/x-sv4cpio", + ".sv4crc": "application/x-sv4crc", + ".svg": "image/svg+xml", + ".svgz": "image/svg+xml-compressed", + ".swf": "application/x-shockwave-flash", + ".sxc": "application/vnd.sun.xml.calc", + ".sxd": "application/vnd.sun.xml.draw", + ".sxg": "application/vnd.sun.xml.writer.global", + ".sxi": "application/vnd.sun.xml.impress", + ".sxm": "application/vnd.sun.xml.math", + ".sxw": "application/vnd.sun.xml.writer", + ".sylk": "text/spreadsheet", + ".t": "text/troff", + ".t2t": "text/x-txt2tags", + ".tar": "application/x-tar", + ".tar.bz": "application/x-bzip-compressed-tar", + ".tar.bz2": "application/x-bzip-compressed-tar", + ".tar.gz": "application/x-compressed-tar", + ".tar.lzma": "application/x-lzma-compressed-tar", + ".tar.lzo": "application/x-tzo", + ".tar.xz": "application/x-xz-compressed-tar", + ".tar.z": "application/x-tarz", + ".tbz": "application/x-bzip-compressed-tar", + ".tbz2": "application/x-bzip-compressed-tar", + ".tcl": "text/x-tcl", + ".tex": "text/x-tex", + ".texi": "text/x-texinfo", + ".texinfo": "text/x-texinfo", + ".tga": "image/x-tga", + ".tgz": "application/x-compressed-tar", + ".theme": "application/x-theme", + ".themepack": "application/x-windows-themepack", + ".tif": "image/tiff", + ".tiff": "image/tiff", + ".tk": "text/x-tcl", + ".tlz": "application/x-lzma-compressed-tar", + ".tnef": "application/vnd.ms-tnef", + ".tnf": "application/vnd.ms-tnef", + ".toc": "application/x-cdrdao-toc", + ".torrent": "application/x-bittorrent", + ".tpic": "image/x-tga", + ".tr": "text/troff", + ".ts": "application/x-linguist", + ".tsv": "text/tab-separated-values", + ".tta": "audio/x-tta", + ".ttc": "application/x-font-ttf", + ".ttf": "application/x-font-ttf", + ".ttx": "application/x-font-ttx", + ".txt": "text/plain", + ".txz": "application/x-xz-compressed-tar", + ".tzo": "application/x-tzo", + ".ufraw": "application/x-ufraw", + ".ui": "application/x-designer", + ".uil": "text/x-uil", + ".ult": "audio/x-mod", + ".uni": "audio/x-mod", + ".uri": "text/x-uri", + ".url": "text/x-uri", + ".ustar": "application/x-ustar", + ".vala": "text/x-vala", + ".vapi": "text/x-vala", + ".vcf": "text/directory", + ".vcs": "text/calendar", + ".vct": "text/directory", + ".vda": "image/x-tga", + ".vhd": "text/x-vhdl", + ".vhdl": "text/x-vhdl", + ".viv": "video/vivo", + ".vivo": "video/vivo", + ".vlc": "audio/x-mpegurl", + ".vob": "video/mpeg", + ".voc": "audio/x-voc", + ".vor": "application/vnd.stardivision.writer", + ".vst": "image/x-tga", + ".wav": "audio/x-wav", + ".wax": "audio/x-ms-asx", + ".wb1": "application/x-quattropro", + ".wb2": "application/x-quattropro", + ".wb3": "application/x-quattropro", + ".wbmp": "image/vnd.wap.wbmp", + ".wcm": "application/vnd.ms-works", + ".wdb": "application/vnd.ms-works", + ".webm": "video/webm", + ".wk1": "application/vnd.lotus-1-2-3", + ".wk3": "application/vnd.lotus-1-2-3", + ".wk4": "application/vnd.lotus-1-2-3", + ".wks": "application/vnd.ms-works", + ".wma": "audio/x-ms-wma", + ".wmf": "image/x-wmf", + ".wml": "text/vnd.wap.wml", + ".wmls": "text/vnd.wap.wmlscript", + ".wmv": "video/x-ms-wmv", + ".wmx": "audio/x-ms-asx", + ".wp": "application/vnd.wordperfect", + ".wp4": "application/vnd.wordperfect", + ".wp5": "application/vnd.wordperfect", + ".wp6": "application/vnd.wordperfect", + ".wpd": "application/vnd.wordperfect", + ".wpg": "application/x-wpg", + ".wpl": "application/vnd.ms-wpl", + ".wpp": "application/vnd.wordperfect", + ".wps": "application/vnd.ms-works", + ".wri": "application/x-mswrite", + ".wrl": "model/vrml", + ".wv": "audio/x-wavpack", + ".wvc": "audio/x-wavpack-correction", + ".wvp": "audio/x-wavpack", + ".wvx": "audio/x-ms-asx", + ".x3f": "image/x-sigma-x3f", + ".xac": "application/x-gnucash", + ".xbel": "application/x-xbel", + ".xbl": "application/xml", + ".xbm": "image/x-xbitmap", + ".xcf": "image/x-xcf", + ".xcf.bz2": "image/x-compressed-xcf", + ".xcf.gz": "image/x-compressed-xcf", + ".xhtml": "application/xhtml+xml", + ".xi": "audio/x-xi", + ".xla": "application/vnd.ms-excel", + ".xlc": "application/vnd.ms-excel", + ".xld": "application/vnd.ms-excel", + ".xlf": "application/x-xliff", + ".xliff": "application/x-xliff", + ".xll": "application/vnd.ms-excel", + ".xlm": "application/vnd.ms-excel", + ".xls": "application/vnd.ms-excel", + ".xlsm": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xlt": "application/vnd.ms-excel", + ".xlw": "application/vnd.ms-excel", + ".xm": "audio/x-xm", + ".xmf": "audio/x-xmf", + ".xmi": "text/x-xmi", + ".xml": "application/xml", + ".xpm": "image/x-xpixmap", + ".xps": "application/vnd.ms-xpsdocument", + ".xsl": "application/xml", + ".xslfo": "text/x-xslfo", + ".xslt": "application/xml", + ".xspf": "application/xspf+xml", + ".xul": "application/vnd.mozilla.xul+xml", + ".xwd": "image/x-xwindowdump", + ".xyz": "chemical/x-pdb", + ".xz": "application/x-xz", + ".w2p": "application/w2p", + ".z": "application/x-compress", + ".zabw": "application/x-abiword", + ".zip": "application/zip", + ".zoo": "application/x-zoo", } -def contenttype(filename, default='text/plain'): +def contenttype(filename, default="text/plain"): """ Returns the Content-Type string matching extension of the given filename. """ - i = filename.rfind('.') + i = filename.rfind(".") if i >= 0: default = CONTENT_TYPE.get(filename[i:].lower(), default) - j = filename.rfind('.', 0, i) + j = filename.rfind(".", 0, i) if j >= 0: default = CONTENT_TYPE.get(filename[j:].lower(), default) - if default.startswith('text/'): - default += '; charset=utf-8' + if default.startswith("text/"): + default += "; charset=utf-8" return default diff --git a/emmett/libs/portalocker.py b/emmett/libs/portalocker.py index 720963e3..480452bf 100644 --- a/emmett/libs/portalocker.py +++ b/emmett/libs/portalocker.py @@ -43,22 +43,23 @@ os_locking = None try: - import google.appengine - os_locking = 'gae' -except: + os_locking = "gae" +except Exception: try: import fcntl - os_locking = 'posix' - except: + + os_locking = "posix" + except Exception: try: + import pywintypes import win32con import win32file - import pywintypes - os_locking = 'windows' - except: + + os_locking = "windows" + except Exception: pass -if os_locking == 'windows': +if os_locking == "windows": LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK LOCK_SH = 0 # the default LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY @@ -69,14 +70,14 @@ def lock(file, flags): hfile = win32file._get_osfhandle(file.fileno()) - win32file.LockFileEx(hfile, flags, 0, 0x7fff0000, __overlapped) + win32file.LockFileEx(hfile, flags, 0, 0x7FFF0000, __overlapped) def unlock(file): hfile = win32file._get_osfhandle(file.fileno()) - win32file.UnlockFileEx(hfile, 0, 0x7fff0000, __overlapped) + win32file.UnlockFileEx(hfile, 0, 0x7FFF0000, __overlapped) -elif os_locking == 'posix': +elif os_locking == "posix": LOCK_EX = fcntl.LOCK_EX LOCK_SH = fcntl.LOCK_SH LOCK_NB = fcntl.LOCK_NB @@ -89,9 +90,9 @@ def unlock(file): else: - #if platform.system() == 'Windows': + # if platform.system() == 'Windows': # logger.error('no file locking, you must install the win32 extensions from: http://sourceforge.net/projects/pywin32/files/') - #elif os_locking != 'gae': + # elif os_locking != 'gae': # logger.debug('no file locking, this will cause problems') LOCK_EX = None @@ -106,18 +107,18 @@ def unlock(file): class LockedFile(object): - def __init__(self, filename, mode='rb'): + def __init__(self, filename, mode="rb"): self.filename = filename self.mode = mode self.file = None - if 'r' in mode: - kwargs = {'encoding': 'utf8'} if 'b' not in mode else {} + if "r" in mode: + kwargs = {"encoding": "utf8"} if "b" not in mode else {} self.file = open(filename, mode, **kwargs) lock(self.file, LOCK_SH) - elif 'w' in mode or 'a' in mode: - self.file = open(filename, mode.replace('w', 'a')) + elif "w" in mode or "a" in mode: + self.file = open(filename, mode.replace("w", "a")) lock(self.file, LOCK_EX) - if 'a' not in mode: + if "a" not in mode: self.file.seek(0) self.file.truncate() else: @@ -148,13 +149,13 @@ def __del__(self): def read_locked(filename): - fp = LockedFile(filename, 'r') + fp = LockedFile(filename, "r") data = fp.read() fp.close() return data def write_locked(filename, data): - fp = LockedFile(filename, 'wb') + fp = LockedFile(filename, "wb") data = fp.write(data) fp.close() diff --git a/emmett/locals.py b/emmett/locals.py index 3a17a447..b99b44e5 100644 --- a/emmett/locals.py +++ b/emmett/locals.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.locals - ------------- +emmett.locals +------------- - Provides shortcuts to `current` object. +Provides shortcuts to `current` object. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from typing import Optional, cast @@ -21,11 +21,12 @@ from .wrappers.response import Response from .wrappers.websocket import Websocket -request = cast(Request, _VProxy[Request](_ctxv, 'request')) -response = cast(Response, _VProxy[Response](_ctxv, 'response')) -session = cast(Optional[sdict], _VProxy[Optional[sdict]](_ctxv, 'session')) -websocket = cast(Websocket, _VProxy[Websocket](_ctxv, 'websocket')) -T = cast(Translator, _OProxy[Translator](current, 'T')) + +request = cast(Request, _VProxy[Request](_ctxv, "request")) +response = cast(Response, _VProxy[Response](_ctxv, "response")) +session = cast(Optional[sdict], _VProxy[Optional[sdict]](_ctxv, "session")) +websocket = cast(Websocket, _VProxy[Websocket](_ctxv, "websocket")) +T = cast(Translator, _OProxy[Translator](current, "T")) def now() -> DateTime: diff --git a/emmett/orm/__init__.py b/emmett/orm/__init__.py index 8a8ffa65..cefe76a6 100644 --- a/emmett/orm/__init__.py +++ b/emmett/orm/__init__.py @@ -1,16 +1,27 @@ from . import _patches from .adapters import adapters as adapters_registry -from .base import Database -from .objects import Field, TransactionOps -from .models import Model from .apis import ( - belongs_to, refers_to, has_one, has_many, - compute, rowattr, rowmethod, - before_insert, before_update, before_delete, - before_save, before_destroy, - before_commit, - after_insert, after_update, after_delete, - after_save, after_destroy, after_commit, - scope + after_delete, + after_destroy, + after_insert, + after_save, + after_update, + before_commit, + before_delete, + before_destroy, + before_insert, + before_save, + before_update, + belongs_to, + compute, + has_many, + has_one, + refers_to, + rowattr, + rowmethod, + scope, ) +from .base import Database +from .models import Model +from .objects import Field, TransactionOps diff --git a/emmett/orm/_patches.py b/emmett/orm/_patches.py index 23c51c01..cfe4f670 100644 --- a/emmett/orm/_patches.py +++ b/emmett/orm/_patches.py @@ -1,85 +1,71 @@ # -*- coding: utf-8 -*- """ - emmett.orm._patches - ------------------- +emmett.orm._patches +------------------- - Provides pyDAL patches. +Provides pyDAL patches. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from emmett_core.utils import cachedprop - -from ..serializers import Serializers - from pydal.adapters.base import BaseAdapter from pydal.connection import ConnectionPool from pydal.helpers.classes import ConnectionConfigurationMixin from pydal.helpers.serializers import Serializers as _Serializers +from ..serializers import Serializers from .adapters import ( - _initialize, _begin, _in_transaction, - _push_transaction, + _initialize, _pop_transaction, + _push_transaction, + _top_transaction, _transaction_depth, - _top_transaction ) from .connection import ( ConnectionManager, PooledConnectionManager, - _connection_init, - _connect_sync, - _connect_loop, - _close_sync, _close_loop, + _close_sync, + _connect_and_configure, + _connect_loop, + _connect_sync, _connection_getter, + _connection_init, _connection_setter, _cursors_getter, - _connect_and_configure ) from .engines.sqlite import SQLite def _patch_adapter_cls(): - setattr(BaseAdapter, '_initialize_', _initialize) - setattr(BaseAdapter, 'in_transaction', _in_transaction) - setattr(BaseAdapter, 'push_transaction', _push_transaction) - setattr(BaseAdapter, 'pop_transaction', _pop_transaction) - setattr(BaseAdapter, 'transaction_depth', _transaction_depth) - setattr(BaseAdapter, 'top_transaction', _top_transaction) - setattr(BaseAdapter, '_connection_manager_cls', PooledConnectionManager) - setattr(BaseAdapter, 'begin', _begin) - setattr(SQLite, '_connection_manager_cls', ConnectionManager) + BaseAdapter._initialize_ = _initialize + BaseAdapter.in_transaction = _in_transaction + BaseAdapter.push_transaction = _push_transaction + BaseAdapter.pop_transaction = _pop_transaction + BaseAdapter.transaction_depth = _transaction_depth + BaseAdapter.top_transaction = _top_transaction + BaseAdapter._connection_manager_cls = PooledConnectionManager + BaseAdapter.begin = _begin + SQLite._connection_manager_cls = ConnectionManager def _patch_adapter_connection(): - setattr(ConnectionPool, '__init__', _connection_init) - setattr(ConnectionPool, 'reconnect', _connect_sync) - setattr(ConnectionPool, 'reconnect_loop', _connect_loop) - setattr(ConnectionPool, 'close', _close_sync) - setattr(ConnectionPool, 'close_loop', _close_loop) - setattr( - ConnectionPool, - 'connection', - property(_connection_getter, _connection_setter) - ) - setattr(ConnectionPool, 'cursors', property(_cursors_getter)) - setattr( - ConnectionConfigurationMixin, - '_reconnect_and_configure', - _connect_and_configure - ) + ConnectionPool.__init__ = _connection_init + ConnectionPool.reconnect = _connect_sync + ConnectionPool.reconnect_loop = _connect_loop + ConnectionPool.close = _close_sync + ConnectionPool.close_loop = _close_loop + ConnectionPool.connection = property(_connection_getter, _connection_setter) + ConnectionPool.cursors = property(_cursors_getter) + ConnectionConfigurationMixin._reconnect_and_configure = _connect_and_configure def _patch_serializers(): - setattr( - _Serializers, - 'json', - cachedprop(lambda _: Serializers.get_for('json'), name='json') - ) + _Serializers.json = cachedprop(lambda _: Serializers.get_for("json"), name="json") _patch_adapter_cls() diff --git a/emmett/orm/adapters.py b/emmett/orm/adapters.py index 089a49a5..9ac806b9 100644 --- a/emmett/orm/adapters.py +++ b/emmett/orm/adapters.py @@ -1,35 +1,20 @@ # -*- coding: utf-8 -*- """ - emmett.orm.adapters - ------------------- +emmett.orm.adapters +------------------- - Provides ORM adapters facilities. +Provides ORM adapters facilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import sys - from functools import wraps from pydal.adapters.base import SQLAdapter -from pydal.adapters.mssql import ( - MSSQL1, - MSSQL3, - MSSQL4, - MSSQL1N, - MSSQL3N, - MSSQL4N -) -from pydal.adapters.postgres import ( - Postgre, - PostgrePsyco, - PostgrePG8000, - PostgreNew, - PostgrePsycoNew, - PostgrePG8000New -) +from pydal.adapters.mssql import MSSQL1, MSSQL1N, MSSQL3, MSSQL3N, MSSQL4, MSSQL4N +from pydal.adapters.postgres import Postgre, PostgreNew, PostgrePG8000, PostgrePG8000New, PostgrePsyco, PostgrePsycoNew from pydal.helpers.classes import SQLALL from pydal.helpers.regex import REGEX_TABLE_DOT_FIELD from pydal.parsers import ParserMethodWrapper, for_type as _parser_for_type @@ -37,37 +22,38 @@ from .engines import adapters from .helpers import GeoFieldWrapper, PasswordFieldWrapper, typed_row_reference -from .objects import Expression, Field, Row, IterRows - - -adapters._registry_.update({ - 'mssql': MSSQL4, - 'mssql2': MSSQL1, - 'mssql3': MSSQL3, - 'mssqln': MSSQL4N, - 'mssqln2': MSSQL1N, - 'mssqln3': MSSQL3N, - 'postgres2': PostgreNew, - 'postgres2:psycopg2': PostgrePsycoNew, - 'postgres2:pg8000': PostgrePG8000New, - 'postgres3': Postgre, - 'postgres3:psycopg2': PostgrePsyco, - 'postgres3:pg8000': PostgrePG8000 -}) +from .objects import Expression, Field, IterRows, Row + + +adapters._registry_.update( + { + "mssql": MSSQL4, + "mssql2": MSSQL1, + "mssql3": MSSQL3, + "mssqln": MSSQL4N, + "mssqln2": MSSQL1N, + "mssqln3": MSSQL3N, + "postgres2": PostgreNew, + "postgres2:psycopg2": PostgrePsycoNew, + "postgres2:pg8000": PostgrePG8000New, + "postgres3": Postgre, + "postgres3:psycopg2": PostgrePsyco, + "postgres3:pg8000": PostgrePG8000, + } +) def _wrap_on_obj(f, adapter): @wraps(f) def wrapped(*args, **kwargs): return f(adapter, *args, **kwargs) + return wrapped def patch_adapter(adapter): #: BaseAdapter interfaces - adapter._expand_all_with_concrete_tables = _wrap_on_obj( - _expand_all_with_concrete_tables, adapter - ) + adapter._expand_all_with_concrete_tables = _wrap_on_obj(_expand_all_with_concrete_tables, adapter) adapter._parse = _wrap_on_obj(_parse, adapter) adapter._parse_expand_colnames = _wrap_on_obj(_parse_expand_colnames, adapter) adapter.iterparse = _wrap_on_obj(iterparse, adapter) @@ -87,44 +73,28 @@ def patch_adapter(adapter): def patch_dialect(dialect): - _create_table_map = { - 'mysql': _create_table_mysql, - 'firebird': _create_table_firebird - } - dialect.create_table = _wrap_on_obj( - _create_table_map.get(dialect.adapter.dbengine, _create_table), dialect - ) + _create_table_map = {"mysql": _create_table_mysql, "firebird": _create_table_firebird} + dialect.create_table = _wrap_on_obj(_create_table_map.get(dialect.adapter.dbengine, _create_table), dialect) dialect.add_foreign_key_constraint = _wrap_on_obj(_add_fk_constraint, dialect) dialect.drop_constraint = _wrap_on_obj(_drop_constraint, dialect) def patch_parser(dialect, parser): - parser.registered['password'] = ParserMethodWrapper( - parser, - _parser_for_type('password')(_parser_password).f + parser.registered["password"] = ParserMethodWrapper(parser, _parser_for_type("password")(_parser_password).f) + parser.registered["reference"] = ParserMethodWrapper( + parser, _parser_for_type("reference")(_parser_reference).f, parser._before_registry_["reference"] ) - parser.registered['reference'] = ParserMethodWrapper( - parser, - _parser_for_type('reference')(_parser_reference).f, - parser._before_registry_['reference'] - ) - if 'geography' in dialect.types: - parser.registered['geography'] = ParserMethodWrapper( - parser, - _parser_for_type('geography')(_parser_geo).f - ) - if 'geometry' in dialect.types: - parser.registered['geometry'] = ParserMethodWrapper( - parser, - _parser_for_type('geometry')(_parser_geo).f - ) + if "geography" in dialect.types: + parser.registered["geography"] = ParserMethodWrapper(parser, _parser_for_type("geography")(_parser_geo).f) + if "geometry" in dialect.types: + parser.registered["geometry"] = ParserMethodWrapper(parser, _parser_for_type("geometry")(_parser_geo).f) def patch_representer(representer): - representer.registered_t['reference'] = TReprMethodWrapper( + representer.registered_t["reference"] = TReprMethodWrapper( representer, - _representer_for_type('reference')(_representer_reference), - representer._tbefore_registry_['reference'] + _representer_for_type("reference")(_representer_reference), + representer._tbefore_registry_["reference"], ) @@ -132,17 +102,14 @@ def insert(adapter, table, fields): query = adapter._insert(table, fields) try: adapter.execute(query) - except: + except Exception: e = sys.exc_info()[1] - if hasattr(table, '_on_insert_error'): + if hasattr(table, "_on_insert_error"): return table._on_insert_error(table, fields, e) raise e if not table._id: - id = { - field.name: val for field, val in fields - if field.name in table._primarykey - } or None - elif table._id.type == 'id': + id = {field.name: val for field, val in fields if field.name in table._primarykey} or None + elif table._id.type == "id": id = adapter.lastrowid(table) else: id = {field.name: val for field, val in fields}.get(table._id.name) @@ -193,7 +160,7 @@ def _select_wcols( orderby_on_limitby=True, for_update=False, outer_scoped=[], - **kwargs + **kwargs, ): return adapter._select_wcols_inner( query, @@ -207,7 +174,7 @@ def _select_wcols( limitby=limitby, orderby_on_limitby=orderby_on_limitby, for_update=for_update, - outer_scoped=outer_scoped + outer_scoped=outer_scoped, ) @@ -215,42 +182,25 @@ def _select_aux(adapter, sql, fields, attributes, colnames): rows = adapter._select_aux_execute(sql) if isinstance(rows, tuple): rows = list(rows) - limitby = attributes.get('limitby', None) or (0,) + limitby = attributes.get("limitby", None) or (0,) rows = adapter.rowslice(rows, limitby[0], None) - return adapter.parse( - rows, - fields, - colnames, - concrete_tables=attributes.get('_concrete_tables', []) - ) + return adapter.parse(rows, fields, colnames, concrete_tables=attributes.get("_concrete_tables", [])) def parse(adapter, rows, fields, colnames, **options): fdata, tables = _parse_expand_colnames(adapter, fields) new_rows = [ _parse( - adapter, - row, - fdata, - tables, - options['concrete_tables'], - fields, - colnames, - options.get('blob_decode', True) - ) for row in rows + adapter, row, fdata, tables, options["concrete_tables"], fields, colnames, options.get("blob_decode", True) + ) + for row in rows ] rowsobj = adapter.db.Rows(adapter.db, new_rows, colnames, rawrows=rows) return rowsobj def iterparse(adapter, sql, fields, colnames, **options): - return IterRows( - adapter.db, - sql, - fields, - options.get('_concrete_tables', []), - colnames - ) + return IterRows(adapter.db, sql, fields, options.get("_concrete_tables", []), colnames) def _parse_expand_colnames(adapter, fieldlist): @@ -269,12 +219,10 @@ def _parse_expand_colnames(adapter, fieldlist): def _parse(adapter, row, fdata, tables, concrete_tables, fields, colnames, blob_decode): - new_row, rows_cls, rows_accum = _build_newrow_wtables( - adapter, tables, concrete_tables - ) + new_row, rows_cls, rows_accum = _build_newrow_wtables(adapter, tables, concrete_tables) extras = adapter.db.Row() #: let's loop over columns - for (idx, colname) in enumerate(colnames): + for idx, colname in enumerate(colnames): value = row[idx] fd = fdata[idx] tablename = None @@ -289,9 +237,7 @@ def _parse(adapter, row, fdata, tables, concrete_tables, fields, colnames, blob_ colset[fieldname] = value #: otherwise we set the value in extras else: - value = adapter.parse_value( - value, fields[idx]._itype, fields[idx].type, blob_decode - ) + value = adapter.parse_value(value, fields[idx]._itype, fields[idx].type, blob_decode) extras[colname] = value new_column_name = adapter._regex_select_as_parser(colname) if new_column_name is not None: @@ -301,13 +247,13 @@ def _parse(adapter, row, fdata, tables, concrete_tables, fields, colnames, blob_ new_row[key] = val._from_engine(rows_accum[key]) #: add extras if needed (eg. operations results) if extras: - new_row['_extra'] = extras + new_row["_extra"] = extras return new_row def _build_newrow_wtables(adapter, tables, concrete_tables): row, cls_map, accum = adapter.db.Row(), {}, {} - for name, table in tables.items(): + for name, _ in tables.items(): cls_map[name] = adapter.db.Row accum[name] = {} for table in concrete_tables: @@ -317,14 +263,14 @@ def _build_newrow_wtables(adapter, tables, concrete_tables): def _create_table(dialect, tablename, fields): - return [ - "CREATE TABLE %s(\n %s\n);" % (dialect.quote(tablename), fields)] + return ["CREATE TABLE %s(\n %s\n);" % (dialect.quote(tablename), fields)] def _create_table_mysql(dialect, tablename, fields): - return ["CREATE TABLE %s(\n %s\n) ENGINE=%s CHARACTER SET utf8;" % ( - dialect.quote(tablename), fields, - dialect.adapter.adapter_args.get('engine', 'InnoDB'))] + return [ + "CREATE TABLE %s(\n %s\n) ENGINE=%s CHARACTER SET utf8;" + % (dialect.quote(tablename), fields, dialect.adapter.adapter_args.get("engine", "InnoDB")) + ] def _create_table_firebird(dialect, tablename, fields): @@ -332,30 +278,25 @@ def _create_table_firebird(dialect, tablename, fields): sequence_name = dialect.sequence_name(tablename) trigger_name = dialect.trigger_name(tablename) trigger_sql = ( - 'create trigger %s for %s active before insert position 0 as\n' - 'begin\n' + "create trigger %s for %s active before insert position 0 as\n" + "begin\n" 'if(new."id" is null) then\n' - 'begin\n' + "begin\n" 'new."id" = gen_id(%s, 1);\n' - 'end\n' - 'end;') - rv.extend([ - 'create generator %s;' % sequence_name, - 'set generator %s to 0;' % sequence_name, - trigger_sql % (trigger_name, dialect.quote(tablename), sequence_name) - ]) + "end\n" + "end;" + ) + rv.extend( + [ + "create generator %s;" % sequence_name, + "set generator %s to 0;" % sequence_name, + trigger_sql % (trigger_name, dialect.quote(tablename), sequence_name), + ] + ) return rv -def _add_fk_constraint( - dialect, - name, - table_local, - table_foreign, - columns_local, - columns_foreign, - on_delete -): +def _add_fk_constraint(dialect, name, table_local, table_foreign, columns_local, columns_foreign, on_delete): return ( f"ALTER TABLE {dialect.quote(table_local)} " f"ADD CONSTRAINT {dialect.quote(name)} " @@ -371,7 +312,7 @@ def _drop_constraint(dialect, name, table): def _parser_reference(parser, value, referee): - if '.' not in referee: + if "." not in referee: value = typed_row_reference(value, parser.adapter.db[referee]) return value @@ -385,7 +326,7 @@ def _parser_password(parser, value): def _representer_reference(representer, value, referenced): - rtname, _, rfname = referenced.partition('.') + rtname, _, rfname = referenced.partition(".") rtable = representer.adapter.db[rtname] if not rfname and rtable._id: rfname = rtable._id.name @@ -394,9 +335,9 @@ def _representer_reference(representer, value, referenced): rtype = rtable[rfname].type if isinstance(value, Row) and getattr(value, "_concrete", False): value = value[(value._model.primary_keys or ["id"])[0]] - if rtype in ('id', 'integer'): + if rtype in ("id", "integer"): return str(int(value)) - if rtype == 'string': + if rtype == "string": return str(value) return representer.adapter.represent(value, rtype) @@ -406,7 +347,8 @@ def _initialize(adapter, *args, **kwargs): adapter._connection_manager.configure( max_connections=adapter.db._pool_size, connect_timeout=adapter.db._connect_timeout, - stale_timeout=adapter.db._keep_alive_timeout) + stale_timeout=adapter.db._keep_alive_timeout, + ) def _begin(adapter): diff --git a/emmett/orm/apis.py b/emmett/orm/apis.py index b0c934ac..523e839a 100644 --- a/emmett/orm/apis.py +++ b/emmett/orm/apis.py @@ -1,19 +1,19 @@ # -*- coding: utf-8 -*- """ - emmett.orm.apis - --------------- +emmett.orm.apis +--------------- - Provides ORM apis. +Provides ORM apis. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from collections import OrderedDict from typing import List from .errors import MissingFieldsForCompute -from .helpers import Reference, Callback +from .helpers import Callback, Reference class belongs_to(Reference): @@ -93,61 +93,62 @@ class rowmethod(rowattr): def before_insert(f): - return Callback(f, '_before_insert') + return Callback(f, "_before_insert") def after_insert(f): - return Callback(f, '_after_insert') + return Callback(f, "_after_insert") def before_update(f): - return Callback(f, '_before_update') + return Callback(f, "_before_update") def after_update(f): - return Callback(f, '_after_update') + return Callback(f, "_after_update") def before_delete(f): - return Callback(f, '_before_delete') + return Callback(f, "_before_delete") def after_delete(f): - return Callback(f, '_after_delete') + return Callback(f, "_after_delete") def before_save(f): - return Callback(f, '_before_save') + return Callback(f, "_before_save") def after_save(f): - return Callback(f, '_after_save') + return Callback(f, "_after_save") def before_destroy(f): - return Callback(f, '_before_destroy') + return Callback(f, "_before_destroy") def after_destroy(f): - return Callback(f, '_after_destroy') + return Callback(f, "_after_destroy") def before_commit(f): - return Callback(f, '_before_commit') + return Callback(f, "_before_commit") def after_commit(f): - return Callback(f, '_after_commit') + return Callback(f, "_after_commit") def _commit_callback_op(kind, op): def _deco(f): - return Callback(f, f'_{kind}_commit_{op}') + return Callback(f, f"_{kind}_commit_{op}") + return _deco -before_commit.operation = lambda op: _commit_callback_op('before', op) -after_commit.operation = lambda op: _commit_callback_op('after', op) +before_commit.operation = lambda op: _commit_callback_op("before", op) +after_commit.operation = lambda op: _commit_callback_op("after", op) class scope(object): diff --git a/emmett/orm/base.py b/emmett/orm/base.py index c9e8c502..fa8729e9 100644 --- a/emmett/orm/base.py +++ b/emmett/orm/base.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.base - --------------- +emmett.orm.base +--------------- - Provides base pyDAL implementation for Emmett. +Provides base pyDAL implementation for Emmett. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -14,12 +14,11 @@ import copyreg import os import threading - from functools import wraps +from emmett_core.serializers import _json_default from pydal import DAL as _pyDAL from pydal._globals import THREAD_LOCAL -from emmett_core.serializers import _json_default from .._shortcuts import uuid as _uuid from ..datastructures import sdict @@ -27,10 +26,10 @@ from ..pipeline import Pipe from ..serializers import xml from .adapters import patch_adapter -from .objects import Table, Field, Set, Row, Rows from .helpers import ConnectionContext, TimingHandler from .models import MetaModel, Model -from .transactions import _atomic, _transaction, _savepoint +from .objects import Field, Row, Rows, Set, Table +from .transactions import _atomic, _savepoint, _transaction class DatabasePipe(Pipe): @@ -51,7 +50,7 @@ async def close(self): class Database(_pyDAL): - serializers = {'json': _json_default, 'xml': xml} + serializers = {"json": _json_default, "xml": xml} logger = None uuid = lambda x: _uuid() @@ -80,19 +79,12 @@ def uri_from_config(config=None): return uri def __new__(cls, app, *args, **kwargs): - config = kwargs.get('config', sdict()) or app.config.db + config = kwargs.get("config", sdict()) or app.config.db uri = config.uri or Database.uri_from_config(config) return super(Database, cls).__new__(cls, uri, *args, **kwargs) def __init__( - self, - app, - config=None, - pool_size=None, - keep_alive_timeout=3600, - connect_timeout=60, - folder=None, - **kwargs + self, app, config=None, pool_size=None, keep_alive_timeout=3600, connect_timeout=60, folder=None, **kwargs ): app.send_signal(Signals.before_database) self.logger = app.log @@ -100,30 +92,24 @@ def __init__( if not config.uri: config.uri = self.uri_from_config(config) if not config.migrations_folder: - config.migrations_folder = 'migrations' + config.migrations_folder = "migrations" self.config = config - self._auto_migrate = self.config.get( - 'auto_migrate', kwargs.pop('auto_migrate', False)) - self._auto_connect = self.config.get( - 'auto_connect', kwargs.pop('auto_connect', None)) - self._use_bigint_on_id_fields = self.config.get( - 'big_id_fields', kwargs.pop('big_id_fields', False)) + self._auto_migrate = self.config.get("auto_migrate", kwargs.pop("auto_migrate", False)) + self._auto_connect = self.config.get("auto_connect", kwargs.pop("auto_connect", None)) + self._use_bigint_on_id_fields = self.config.get("big_id_fields", kwargs.pop("big_id_fields", False)) #: load config data - kwargs['check_reserved'] = self.config.check_reserved or \ - kwargs.get('check_reserved', None) - kwargs['migrate'] = self._auto_migrate - kwargs['driver_args'] = self.config.driver_args or \ - kwargs.get('driver_args', None) - kwargs['adapter_args'] = self.config.adapter_args or \ - kwargs.get('adapter_args', None) + kwargs["check_reserved"] = self.config.check_reserved or kwargs.get("check_reserved", None) + kwargs["migrate"] = self._auto_migrate + kwargs["driver_args"] = self.config.driver_args or kwargs.get("driver_args", None) + kwargs["adapter_args"] = self.config.adapter_args or kwargs.get("adapter_args", None) if self._auto_connect is not None: - kwargs['do_connect'] = self._auto_connect + kwargs["do_connect"] = self._auto_connect else: - kwargs['do_connect'] = os.environ.get('EMMETT_CLI_ENV') == 'true' + kwargs["do_connect"] = os.environ.get("EMMETT_CLI_ENV") == "true" if self._use_bigint_on_id_fields: - kwargs['bigint_id'] = True + kwargs["bigint_id"] = True #: set directory - folder = folder or 'databases' + folder = folder or "databases" folder = os.path.join(app.root_path, folder) if self._auto_migrate: with self._cls_global_lock_: @@ -132,17 +118,14 @@ def __init__( #: set pool_size pool_size = self.config.pool_size or pool_size or 5 self._keep_alive_timeout = ( - keep_alive_timeout if self.config.keep_alive_timeout is None - else self.config.keep_alive_timeout) - self._connect_timeout = ( - connect_timeout if self.config.connect_timeout is None - else self.config.connect_timeout) + keep_alive_timeout if self.config.keep_alive_timeout is None else self.config.keep_alive_timeout + ) + self._connect_timeout = connect_timeout if self.config.connect_timeout is None else self.config.connect_timeout #: add timings storage if requested if config.store_execution_timings: self.execution_handlers.append(TimingHandler) #: finally setup pyDAL instance - super(Database, self).__init__( - self.config.uri, pool_size, folder, **kwargs) + super(Database, self).__init__(self.config.uri, pool_size, folder, **kwargs) patch_adapter(self._adapter) Model._init_inheritable_dicts_() app.send_signal(Signals.after_database, database=self) @@ -153,29 +136,19 @@ def pipe(self): @property def execution_timings(self): - return getattr(THREAD_LOCAL, '_emtdal_timings_', []) + return getattr(THREAD_LOCAL, "_emtdal_timings_", []) def connection_open(self, with_transaction=True, reuse_if_open=True): - return self._adapter.reconnect( - with_transaction=with_transaction, reuse_if_open=reuse_if_open) + return self._adapter.reconnect(with_transaction=with_transaction, reuse_if_open=reuse_if_open) def connection_close(self): self._adapter.close() - def connection( - self, - with_transaction: bool = True, - reuse_if_open: bool = True - ) -> ConnectionContext: - return ConnectionContext( - self, - with_transaction=with_transaction, - reuse_if_open=reuse_if_open - ) + def connection(self, with_transaction: bool = True, reuse_if_open: bool = True) -> ConnectionContext: + return ConnectionContext(self, with_transaction=with_transaction, reuse_if_open=reuse_if_open) def connection_open_loop(self, with_transaction=True, reuse_if_open=True): - return self._adapter.reconnect_loop( - with_transaction=with_transaction, reuse_if_open=reuse_if_open) + return self._adapter.reconnect_loop(with_transaction=with_transaction, reuse_if_open=reuse_if_open) def connection_close_loop(self): return self._adapter.close_loop() @@ -195,15 +168,13 @@ def define_models(self, *models): obj._define_relations_() obj._define_virtuals_() # define table and store in model - args = dict( - migrate=obj.migrate, - format=obj.format, - table_class=Table, - primarykey=obj.primary_keys or ['id'] - ) - model.table = self.define_table( - obj.tablename, *obj.fields, **args - ) + args = { + "migrate": obj.migrate, + "format": obj.format, + "table_class": Table, + "primarykey": obj.primary_keys or ["id"], + } + model.table = self.define_table(obj.tablename, *obj.fields, **args) model.table._model_ = obj # set reference in db for model name self.__setattr__(model.__name__, obj.table) @@ -219,7 +190,7 @@ def where(self, query=None, ignore_common_filters=None, model=None): if isinstance(query, Table): q = self._adapter.id_query(query) elif isinstance(query, Field): - q = (query != None) + q = query != None # noqa: E711 elif isinstance(query, dict): icf = query.get("ignore_common_filters") if icf: @@ -229,14 +200,14 @@ def where(self, query=None, ignore_common_filters=None, model=None): q = self._adapter.id_query(query.table) else: q = query - return Set( - self, q, ignore_common_filters=ignore_common_filters, model=model) + return Set(self, q, ignore_common_filters=ignore_common_filters, model=model) def with_connection(self, f): @wraps(f) def wrapped(*args, **kwargs): with self.connection(): f(*args, **kwargs) + return wrapped def atomic(self): @@ -261,7 +232,7 @@ def rollback(self): def _Database_unpickler(db_uid): fake_app_obj = sdict(config=sdict(db=sdict())) - fake_app_obj.config.db.adapter = '' + fake_app_obj.config.db.adapter = "" return Database(fake_app_obj, db_uid=db_uid) diff --git a/emmett/orm/connection.py b/emmett/orm/connection.py index 1919c38a..a3b8a483 100644 --- a/emmett/orm/connection.py +++ b/emmett/orm/connection.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.orm.connection - --------------------- +emmett.orm.connection +--------------------- - Provides pyDAL connection implementation for Emmett. +Provides pyDAL connection implementation for Emmett. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Parts of this code are inspired to peewee - :copyright: (c) 2010 by Charles Leifer +Parts of this code are inspired to peewee +:copyright: (c) 2010 by Charles Leifer - :license: BSD-3-Clause +:license: BSD-3-Clause """ import asyncio @@ -18,7 +18,6 @@ import heapq import threading import time - from collections import OrderedDict from functools import partial @@ -30,13 +29,13 @@ class ConnectionStateCtxVars: - __slots__ = ('_connection', '_transactions', '_cursors', '_closed') + __slots__ = ("_connection", "_transactions", "_cursors", "_closed") def __init__(self): - self._connection = contextvars.ContextVar('_emt_orm_cs_connection') - self._transactions = contextvars.ContextVar('_emt_orm_cs_transactions') - self._cursors = contextvars.ContextVar('_emt_orm_cs_cursors') - self._closed = contextvars.ContextVar('_emt_orm_cs_closed') + self._connection = contextvars.ContextVar("_emt_orm_cs_connection") + self._transactions = contextvars.ContextVar("_emt_orm_cs_transactions") + self._cursors = contextvars.ContextVar("_emt_orm_cs_cursors") + self._closed = contextvars.ContextVar("_emt_orm_cs_closed") self.reset() @property @@ -69,7 +68,7 @@ def reset(self): class ConnectionState: - __slots__ = ('_connection', '_transactions', '_cursors', '_closed') + __slots__ = ("_connection", "_transactions", "_cursors", "_closed") def __init__(self, connection=None): self.connection = connection @@ -87,14 +86,14 @@ def connection(self, value): class ConnectionStateCtl: - __slots__ = ['_state_obj_var', '_state_load_var'] + __slots__ = ["_state_obj_var", "_state_load_var"] state_cls = ConnectionState def __init__(self): inst_id = id(self) - self._state_obj_var = f'__emt_orm_state_{inst_id}__' - self._state_load_var = f'__emt_orm_state_loaded_{inst_id}__' + self._state_obj_var = f"__emt_orm_state_{inst_id}__" + self._state_load_var = f"__emt_orm_state_loaded_{inst_id}__" @property def _has_ctx(self): @@ -133,7 +132,7 @@ def reset(self): class ConnectionManager: - __slots__ = ['adapter', 'state', '__dict__'] + __slots__ = ["adapter", "state", "__dict__"] state_cls = ConnectionStateCtl def __init__(self, adapter, **kwargs): @@ -157,12 +156,7 @@ def _connection_open_sync(self): return self._connector_sync(), True async def _connection_open_loop(self): - return ( - await self._loop.run_in_executor( - None, self._connector_loop - ), - True - ) + return (await self._loop.run_in_executor(None, self._connector_loop), True) def _connection_close_sync(self, connection, *args, **kwargs): try: @@ -171,9 +165,7 @@ def _connection_close_sync(self, connection, *args, **kwargs): pass async def _connection_close_loop(self, connection, *args, **kwargs): - return await self._loop.run_in_executor( - None, partial(self._connection_close_sync, connection) - ) + return await self._loop.run_in_executor(None, partial(self._connection_close_sync, connection)) connect_sync = _connection_open_sync connect_loop = _connection_open_loop @@ -188,18 +180,16 @@ def __del__(self): class PooledConnectionManager(ConnectionManager): __slots__ = [ - 'max_connections', 'connect_timeout', 'stale_timeout', - 'connections_map', 'connections_sync', - 'in_use', '_lock_sync' + "max_connections", + "connect_timeout", + "stale_timeout", + "connections_map", + "connections_sync", + "in_use", + "_lock_sync", ] - def __init__( - self, - adapter, - max_connections=5, - connect_timeout=0, - stale_timeout=0 - ): + def __init__(self, adapter, max_connections=5, connect_timeout=0, stale_timeout=0): super().__init__(adapter) self.max_connections = max(max_connections, 1) self.connect_timeout = connect_timeout @@ -234,9 +224,7 @@ def connect_sync(self): raise MaxConnectionsExceeded() async def connect_loop(self): - return await asyncio.wait_for( - self._acquire_loop(), self.connect_timeout or None - ) + return await asyncio.wait_for(self._acquire_loop(), self.connect_timeout or None) def _acquire_sync(self): _opened = False @@ -274,9 +262,7 @@ async def _acquire_loop(self): break ts, key = await self.connections_loop.get() if self.stale_timeout and self.is_stale(ts): - await self._connection_close_loop( - self.connections_map.pop(key) - ) + await self._connection_close_loop(self.connections_map.pop(key)) else: conn = self.connections_map[key] break @@ -326,7 +312,7 @@ def _connect_sync(self, with_transaction=True, reuse_if_open=False): if not self._connection_manager.state.closed: if reuse_if_open: return False - raise RuntimeError('Connection already opened.') + raise RuntimeError("Connection already opened.") self.connection, _opened = self._connection_manager.connect_sync() if _opened: self.after_connection_hook() @@ -340,7 +326,7 @@ async def _connect_loop(self, with_transaction=True, reuse_if_open=False): if not self._connection_manager.state.closed: if reuse_if_open: return False - raise RuntimeError('Connection already opened.') + raise RuntimeError("Connection already opened.") self.connection, _opened = await self._connection_manager.connect_loop() if _opened: self.after_connection_hook() @@ -350,7 +336,7 @@ async def _connect_loop(self, with_transaction=True, reuse_if_open=False): return True -def _close_sync(self, action='commit', really=True): +def _close_sync(self, action="commit", really=True): is_open = not self._connection_manager.state.closed if not is_open: return is_open @@ -369,7 +355,7 @@ def _close_sync(self, action='commit', really=True): return is_open -async def _close_loop(self, action='commit', really=True): +async def _close_loop(self, action="commit", really=True): is_open = not self._connection_manager.state.closed if not is_open: return is_open diff --git a/emmett/orm/engines/__init__.py b/emmett/orm/engines/__init__.py index 1a1702a6..a612ca93 100644 --- a/emmett/orm/engines/__init__.py +++ b/emmett/orm/engines/__init__.py @@ -1,6 +1,3 @@ from pydal.adapters import adapters -from . import ( - postgres, - sqlite -) +from . import postgres, sqlite diff --git a/emmett/orm/engines/postgres.py b/emmett/orm/engines/postgres.py index dcc47f88..b6bbc174 100644 --- a/emmett/orm/engines/postgres.py +++ b/emmett/orm/engines/postgres.py @@ -1,19 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.orm.engines.postgres - --------------------------- +emmett.orm.engines.postgres +--------------------------- - Provides ORM PostgreSQL engine specific features. +Provides ORM PostgreSQL engine specific features. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from pydal.adapters.postgres import ( - PostgreBoolean, - PostgrePsycoBoolean, - PostgrePG8000Boolean -) +from pydal.adapters.postgres import PostgreBoolean, PostgrePG8000Boolean, PostgrePsycoBoolean from pydal.dialects import register_expression, sqltype_for from pydal.dialects.postgre import PostgreDialectBooleanJSON from pydal.helpers.serializers import serializers @@ -27,98 +23,94 @@ class JSONBPostgreDialect(PostgreDialectBooleanJSON): - @sqltype_for('jsonb') + @sqltype_for("jsonb") def type_jsonb(self): - return 'jsonb' + return "jsonb" def _jcontains(self, field, data, query_env={}): - return '(%s @> %s)' % ( + return "(%s @> %s)" % ( self.expand(field, query_env=query_env), - self.expand(data, field.type, query_env=query_env) + self.expand(data, field.type, query_env=query_env), ) - @register_expression('jcontains') + @register_expression("jcontains") def _jcontains_expr(self, expr, data): return Query(expr.db, self._jcontains, expr, data) def _jin(self, field, data, query_env={}): - return '(%s <@ %s)' % ( + return "(%s <@ %s)" % ( self.expand(field, query_env=query_env), - self.expand(data, field.type, query_env=query_env) + self.expand(data, field.type, query_env=query_env), ) - @register_expression('jin') + @register_expression("jin") def _jin_expr(self, expr, data): return Query(expr.db, self._jin, expr, data) def _jget_common(self, op, field, data, query_env): if not isinstance(data, int): - _dtype = field.type if isinstance(data, (dict, list)) else 'string' + _dtype = field.type if isinstance(data, (dict, list)) else "string" data = self.expand(data, field_type=_dtype, query_env=query_env) - return '%s %s %s' % ( - self.expand(field, query_env=query_env), - op, - str(data) - ) + return "%s %s %s" % (self.expand(field, query_env=query_env), op, str(data)) def _jget(self, field, data, query_env={}): - return self._jget_common('->', field, data, query_env=query_env) + return self._jget_common("->", field, data, query_env=query_env) - @register_expression('jget') + @register_expression("jget") def _jget_expr(self, expr, data): return Expression(expr.db, self._jget, expr, data, expr.type) def _jgetv(self, field, data, query_env={}): - return self._jget_common('->>', field, data, query_env=query_env) + return self._jget_common("->>", field, data, query_env=query_env) - @register_expression('jgetv') + @register_expression("jgetv") def _jgetv_expr(self, expr, data): - return Expression(expr.db, self._jgetv, expr, data, 'string') + return Expression(expr.db, self._jgetv, expr, data, "string") def _jpath(self, field, data, query_env={}): - return '%s #> %s' % ( + return "%s #> %s" % ( self.expand(field, query_env=query_env), - self.expand(data, field_type='string', query_env=query_env) + self.expand(data, field_type="string", query_env=query_env), ) - @register_expression('jpath') + @register_expression("jpath") def _jpath_expr(self, expr, data): return Expression(expr.db, self._jpath, expr, data, expr.type) def _jpathv(self, field, data, query_env={}): - return '%s #>> %s' % ( + return "%s #>> %s" % ( self.expand(field, query_env=query_env), - self.expand(data, field_type='string', query_env=query_env) + self.expand(data, field_type="string", query_env=query_env), ) - @register_expression('jpathv') + @register_expression("jpathv") def _jpathv_expr(self, expr, data): - return Expression(expr.db, self._jpathv, expr, data, 'string') + return Expression(expr.db, self._jpathv, expr, data, "string") def _jhas(self, field, data, all=False, query_env={}): - _op, _ftype = '?', 'string' + _op, _ftype = "?", "string" if isinstance(data, list): - _op = '?&' if all else '?|' - _ftype = 'list:string' - return '%s %s %s' % ( + _op = "?&" if all else "?|" + _ftype = "list:string" + return "%s %s %s" % ( self.expand(field, query_env=query_env), _op, - self.expand(data, field_type=_ftype, query_env=query_env) + self.expand(data, field_type=_ftype, query_env=query_env), ) - @register_expression('jhas') + @register_expression("jhas") def _jhas_expr(self, expr, data, all=False): return Query(expr.db, self._jhas, expr, data, all=all) class JSONBPostgreParser(PostgreBooleanAutoJSONParser): - @parse_type('jsonb') + @parse_type("jsonb") def _jsonb(self, value): return value class JSONBPostgreRepresenter(PostgreArraysRepresenter): - @repr_type('jsonb') + @repr_type("jsonb") def _jsonb(self, value): return serializers.json(value) @@ -145,9 +137,9 @@ def _insert(self, table, fields): retval = table._id._rname return self.dialect.insert( table._rname, - ','.join(el[0]._rname for el in fields), - ','.join(self.expand(v, f.type) for f, v in fields), - retval + ",".join(el[0]._rname for el in fields), + ",".join(self.expand(v, f.type) for f, v in fields), + retval, ) return self.dialect.insert_empty(table._rname) @@ -159,16 +151,16 @@ def lastrowid(self, table): return self.cursor.fetchone()[0] -@adapters.register_for('postgres') +@adapters.register_for("postgres") class PostgresAdapter(PostgresAdapterMixin, PostgreBoolean): pass -@adapters.register_for('postgres:psycopg2') +@adapters.register_for("postgres:psycopg2") class PostgresPsycoPG2Adapter(PostgresAdapterMixin, PostgrePsycoBoolean): pass -@adapters.register_for('postgres:pg8000') +@adapters.register_for("postgres:pg8000") class PostgresPG8000Adapter(PostgresAdapterMixin, PostgrePG8000Boolean): pass diff --git a/emmett/orm/engines/sqlite.py b/emmett/orm/engines/sqlite.py index a979ce76..2b6b5d81 100644 --- a/emmett/orm/engines/sqlite.py +++ b/emmett/orm/engines/sqlite.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.engines.sqlite - ------------------------- +emmett.orm.engines.sqlite +------------------------- - Provides ORM SQLite engine specific features. +Provides ORM SQLite engine specific features. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from pydal.adapters.sqlite import SQLite as _SQLite @@ -14,27 +14,21 @@ from . import adapters -@adapters.register_for('sqlite', 'sqlite:memory') +@adapters.register_for("sqlite", "sqlite:memory") class SQLite(_SQLite): def _initialize_(self, do_connect): super()._initialize_(do_connect) - self.driver_args['isolation_level'] = None + self.driver_args["isolation_level"] = None def begin(self, lock_type=None): - statement = 'BEGIN %s;' % lock_type if lock_type else 'BEGIN;' + statement = "BEGIN %s;" % lock_type if lock_type else "BEGIN;" self.execute(statement) def delete(self, table, query): - deleted = ( - [x[table._id.name] for x in self.db(query).select(table._id)] - if table._id else [] - ) + deleted = [x[table._id.name] for x in self.db(query).select(table._id)] if table._id else [] counter = super(_SQLite, self).delete(table, query) if table._id and counter: for field in table._referenced_by: - if ( - field.type == 'reference ' + table._dalname and - field.ondelete == 'CASCADE' - ): + if field.type == "reference " + table._dalname and field.ondelete == "CASCADE": self.db(field.belongs(deleted)).delete() return counter diff --git a/emmett/orm/errors.py b/emmett/orm/errors.py index 691bf876..293e9e35 100644 --- a/emmett/orm/errors.py +++ b/emmett/orm/errors.py @@ -1,39 +1,33 @@ # -*- coding: utf-8 -*- """ - emmett.orm.errors - ----------------- +emmett.orm.errors +----------------- - Provides some error wrappers. +Provides some error wrappers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ class MaxConnectionsExceeded(RuntimeError): def __init__(self): - super().__init__('Exceeded maximum connections') + super().__init__("Exceeded maximum connections") -class MissingFieldsForCompute(RuntimeError): - ... +class MissingFieldsForCompute(RuntimeError): ... -class SaveException(RuntimeError): - ... +class SaveException(RuntimeError): ... -class InsertFailureOnSave(SaveException): - ... +class InsertFailureOnSave(SaveException): ... -class UpdateFailureOnSave(SaveException): - ... +class UpdateFailureOnSave(SaveException): ... -class DestroyException(RuntimeError): - ... +class DestroyException(RuntimeError): ... -class ValidationError(RuntimeError): - ... +class ValidationError(RuntimeError): ... diff --git a/emmett/orm/geo.py b/emmett/orm/geo.py index 8ecb3e6a..df752609 100644 --- a/emmett/orm/geo.py +++ b/emmett/orm/geo.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.geo - -------------- +emmett.orm.geo +-------------- - Provides geographic facilities. +Provides geographic facilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from .helpers import GeoFieldWrapper @@ -17,9 +17,7 @@ def Point(x, y): def Line(*coordinates): - return GeoFieldWrapper( - "LINESTRING(%s)" % ','.join("%f %f" % point for point in coordinates) - ) + return GeoFieldWrapper("LINESTRING(%s)" % ",".join("%f %f" % point for point in coordinates)) def Polygon(*coordinates_groups): @@ -29,46 +27,30 @@ def Polygon(*coordinates_groups): except Exception: pass return GeoFieldWrapper( - "POLYGON(%s)" % ( - ",".join([ - "(%s)" % ",".join("%f %f" % point for point in group) - for group in coordinates_groups - ]) - ) + "POLYGON(%s)" + % (",".join(["(%s)" % ",".join("%f %f" % point for point in group) for group in coordinates_groups])) ) def MultiPoint(*points): - return GeoFieldWrapper( - "MULTIPOINT(%s)" % ( - ",".join([ - "(%f %f)" % point for point in points - ]) - ) - ) + return GeoFieldWrapper("MULTIPOINT(%s)" % (",".join(["(%f %f)" % point for point in points]))) def MultiLine(*lines): return GeoFieldWrapper( - "MULTILINESTRING(%s)" % ( - ",".join([ - "(%s)" % ",".join("%f %f" % point for point in line) - for line in lines - ]) - ) + "MULTILINESTRING(%s)" % (",".join(["(%s)" % ",".join("%f %f" % point for point in line) for line in lines])) ) def MultiPolygon(*polygons): return GeoFieldWrapper( - "MULTIPOLYGON(%s)" % ( - ",".join([ - "(%s)" % ( - ",".join([ - "(%s)" % ",".join("%f %f" % point for point in group) - for group in polygon - ]) - ) for polygon in polygons - ]) + "MULTIPOLYGON(%s)" + % ( + ",".join( + [ + "(%s)" % (",".join(["(%s)" % ",".join("%f %f" % point for point in group) for group in polygon])) + for polygon in polygons + ] + ) ) ) diff --git a/emmett/orm/helpers.py b/emmett/orm/helpers.py index 949df374..58713725 100644 --- a/emmett/orm/helpers.py +++ b/emmett/orm/helpers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.helpers - ------------------ +emmett.orm.helpers +------------------ - Provides ORM helpers. +Provides ORM helpers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -15,23 +15,23 @@ import operator import re import time - from functools import reduce, wraps from typing import TYPE_CHECKING, Any, Callable +from emmett_core.utils import cachedprop from pydal._globals import THREAD_LOCAL from pydal.helpers.classes import ExecutionHandler from pydal.objects import Field as _Field -from emmett_core.utils import cachedprop from ..datastructures import sdict + if TYPE_CHECKING: from .objects import Table class RowReferenceMeta: - __slots__ = ['table', 'pk', 'caster'] + __slots__ = ["table", "pk", "caster"] def __init__(self, table: Table, caster: Callable[[Any], Any]): self.table = table @@ -39,15 +39,14 @@ def __init__(self, table: Table, caster: Callable[[Any], Any]): self.caster = caster def fetch(self, val): - return self.table._db(self.table._id == self.caster(val)).select( - limitby=(0, 1), - orderby_on_limitby=False - ).first() + return ( + self.table._db(self.table._id == self.caster(val)).select(limitby=(0, 1), orderby_on_limitby=False).first() + ) class RowReferenceMultiMeta: - __slots__ = ['table', 'pks', 'pks_idx', 'caster', 'casters'] - _casters = {'integer': int, 'string': str} + __slots__ = ["table", "pks", "pks_idx", "caster", "casters"] + _casters = {"integer": int, "string": str} def __init__(self, table: Table) -> None: self.table = table @@ -58,15 +57,10 @@ def __init__(self, table: Table) -> None: def fetch(self, val): query = reduce( - operator.and_, [ - self.table[pk] == self.casters[pk](self.caster.__getitem__(val, idx)) - for pk, idx in self.pks_idx.items() - ] + operator.and_, + [self.table[pk] == self.casters[pk](self.caster.__getitem__(val, idx)) for pk, idx in self.pks_idx.items()], ) - return self.table._db(query).select( - limitby=(0, 1), - orderby_on_limitby=False - ).first() + return self.table._db(query).select(limitby=(0, 1), orderby_on_limitby=False).first() class RowReferenceMixin: @@ -75,8 +69,7 @@ def _allocate_(self): self._refrecord = self._refmeta.fetch(self) if not self._refrecord: raise RuntimeError( - "Using a recursive select but encountered a broken " + - "reference: %s %r" % (self._table, self) + "Using a recursive select but encountered a broken " + "reference: %s %r" % (self._table, self) ) def __getattr__(self, key: str) -> Any: @@ -92,7 +85,7 @@ def get(self, key: str, default: Any = None) -> Any: return self.__getattr__(key, default) def __setattr__(self, key: str, value: Any): - if key.startswith('_'): + if key.startswith("_"): self._refmeta.caster.__setattr__(self, key, value) return self._allocate_() @@ -118,16 +111,16 @@ def __repr__(self) -> str: class RowReferenceInt(RowReferenceMixin, int): def __new__(cls, id, table: Table, *args: Any, **kwargs: Any): rv = super().__new__(cls, id, *args, **kwargs) - int.__setattr__(rv, '_refmeta', RowReferenceMeta(table, int)) - int.__setattr__(rv, '_refrecord', None) + int.__setattr__(rv, "_refmeta", RowReferenceMeta(table, int)) + int.__setattr__(rv, "_refrecord", None) return rv class RowReferenceStr(RowReferenceMixin, str): def __new__(cls, id, table: Table, *args: Any, **kwargs: Any): rv = super().__new__(cls, id, *args, **kwargs) - str.__setattr__(rv, '_refmeta', RowReferenceMeta(table, str)) - str.__setattr__(rv, '_refrecord', None) + str.__setattr__(rv, "_refmeta", RowReferenceMeta(table, str)) + str.__setattr__(rv, "_refrecord", None) return rv @@ -135,15 +128,13 @@ class RowReferenceMulti(RowReferenceMixin, tuple): def __new__(cls, id, table: Table, *args: Any, **kwargs: Any): tupid = tuple(id[key] for key in table._primarykey) rv = super().__new__(cls, tupid, *args, **kwargs) - tuple.__setattr__(rv, '_refmeta', RowReferenceMultiMeta(table)) - tuple.__setattr__(rv, '_refrecord', None) + tuple.__setattr__(rv, "_refmeta", RowReferenceMultiMeta(table)) + tuple.__setattr__(rv, "_refrecord", None) return rv def __getattr__(self, key: str) -> Any: if key in self._refmeta.pks: - return self._refmeta.casters[key]( - tuple.__getitem__(self, self._refmeta.pks_idx[key]) - ) + return self._refmeta.casters[key](tuple.__getitem__(self, self._refmeta.pks_idx[key])) if key in self._refmeta.table: self._allocate_() if self._refrecord: @@ -152,9 +143,7 @@ def __getattr__(self, key: str) -> Any: def __getitem__(self, key): if key in self._refmeta.pks: - return self._refmeta.casters[key]( - tuple.__getitem__(self, self._refmeta.pks_idx[key]) - ) + return self._refmeta.casters[key](tuple.__getitem__(self, self._refmeta.pks_idx[key])) self._allocate_() return self._refrecord.get(key, None) @@ -167,22 +156,22 @@ class GeoFieldWrapper(str): "POLYGON": "Polygon", "MULTIPOINT": "MultiPoint", "MULTILINESTRING": "MultiLineString", - "MULTIPOLYGON": "MultiPolygon" + "MULTIPOLYGON": "MultiPolygon", } def __new__(cls, value, *args: Any, **kwargs: Any): geometry, raw_coords = value.strip()[:-1].split("(", 1) rv = super().__new__(cls, value, *args, **kwargs) coords = cls._parse_coords_block(raw_coords) - str.__setattr__(rv, '_geometry', geometry.strip()) - str.__setattr__(rv, '_coordinates', coords) + str.__setattr__(rv, "_geometry", geometry.strip()) + str.__setattr__(rv, "_coordinates", coords) return rv @classmethod def _parse_coords_block(cls, v): groups = [] parens_match = cls._rule_parens.match(v) - parens = parens_match.group(1) if parens_match else '' + parens = parens_match.group(1) if parens_match else "" if parens: for element in v.split(parens): if not element: @@ -192,9 +181,7 @@ def _parse_coords_block(cls, v): groups.append(f"{parens}{element}"[1:shift]) if not groups: return cls._parse_coords_group(v) - return tuple( - cls._parse_coords_block(group) for group in groups - ) + return tuple(cls._parse_coords_block(group) for group in groups) @staticmethod def _parse_coords_group(v): @@ -225,17 +212,13 @@ def coordinates(self): @property def groups(self): if not self._geometry.startswith("MULTI"): - return tuple() + return () return tuple( - self.__class__(f"{self._geometry[5:]}({self._repr_coords(coords)[0]})") - for coords in self._coordinates + self.__class__(f"{self._geometry[5:]}({self._repr_coords(coords)[0]})") for coords in self._coordinates ) def __json__(self): - return { - "type": self._json_geom_map[self._geometry], - "coordinates": self._coordinates - } + return {"type": self._json_geom_map[self._geometry], "coordinates": self._coordinates} class PasswordFieldWrapper(str): @@ -244,27 +227,24 @@ class PasswordFieldWrapper(str): class Reference(object): def __init__(self, *args, **params): - self.reference = [arg for arg in args] + self.reference = list(args) self.params = params self.refobj[id(self)] = self def __call__(self, func): - if self.__class__.__name__ not in ['has_one', 'has_many']: - raise SyntaxError( - '%s cannot be used as a decorator' % self.__class__.__name__) + if self.__class__.__name__ not in ["has_one", "has_many"]: + raise SyntaxError("%s cannot be used as a decorator" % self.__class__.__name__) if not callable(func): - raise SyntaxError('Argument must be callable') + raise SyntaxError("Argument must be callable") if self.reference: - raise SyntaxError( - "When using %s as decorator, you must use the 'field' option" % - self.__class__.__name__) - new_reference = {func.__name__: {'method': func}} - field = self.params.get('field') + raise SyntaxError("When using %s as decorator, you must use the 'field' option" % self.__class__.__name__) + new_reference = {func.__name__: {"method": func}} + field = self.params.get("field") if field: - new_reference[func.__name__]['field'] = field - cast = self.params.get('cast') + new_reference[func.__name__]["field"] = field + cast = self.params.get("cast") if cast: - new_reference[func.__name__]['cast'] = cast + new_reference[func.__name__]["cast"] = cast self.reference = [new_reference] return self @@ -345,10 +325,11 @@ def _get_belongs(self, modelname, value): def belongs_query(self): return reduce( - operator.and_, [ + operator.and_, + [ self.model.table[local] == self.model.db[self.ref.model][foreign] for local, foreign in self.ref.coupled_fields - ] + ], ) @staticmethod @@ -361,17 +342,10 @@ def many_query(ref, rid): components.append(element.cast(ref.cast)) else: components.append(element) - return reduce( - operator.and_, [ - field == components[idx] - for idx, field in enumerate(ref.fields_instances) - ] - ) + return reduce(operator.and_, [field == components[idx] for idx, field in enumerate(ref.fields_instances)]) def _many(self, ref, rid): - return ref.dbset.where( - self._patch_query_with_scopes(ref, self.many_query(ref, rid)) - ).query + return ref.dbset.where(self._patch_query_with_scopes(ref, self.many_query(ref, rid))).query def many(self, row=None): return self._many(self.ref, self._make_refid(row)) @@ -381,17 +355,11 @@ def via(self, row=None): rid = self._make_refid(row) sname = self.model.__class__.__name__ stack = [] - midrel = self.model._hasmany_ref_.get( - self.ref.via, - self.model._hasone_ref_.get(self.ref.via) - ) + midrel = self.model._hasmany_ref_.get(self.ref.via, self.model._hasone_ref_.get(self.ref.via)) stack.append(self.ref) while midrel.via is not None: stack.insert(0, midrel) - midrel = self.model._hasmany_ref_.get( - midrel.via, - self.model._hasone_ref_.get(midrel.via) - ) + midrel = self.model._hasmany_ref_.get(midrel.via, self.model._hasone_ref_.get(midrel.via)) query = self._many(midrel, rid) step_model = midrel.table_name sel_field = db[step_model].ALL @@ -405,12 +373,11 @@ def via(self, row=None): last_belongs = step_model last_via = via _query = reduce( - operator.and_, [ - ( - db[belongs_model.model][foreign] == - db[step_model][local] - ) for local, foreign in belongs_model.coupled_fields - ] + operator.and_, + [ + (db[belongs_model.model][foreign] == db[step_model][local]) + for local, foreign in belongs_model.coupled_fields + ], ) sel_field = db[belongs_model.model].ALL step_model = belongs_model.model @@ -418,10 +385,7 @@ def via(self, row=None): #: shortcut way last_belongs = None rname = via.field or via.name - midrel = db[step_model]._model_._hasmany_ref_.get( - rname, - db[step_model]._model_._hasone_ref_.get(rname) - ) + midrel = db[step_model]._model_._hasmany_ref_.get(rname, db[step_model]._model_._hasone_ref_.get(rname)) if midrel.via: nested = RelationBuilder(midrel, midrel.model_class) nested_data = nested.via() @@ -429,19 +393,13 @@ def via(self, row=None): step_model = midrel.model_class.tablename else: _query = self._many( - midrel, [ - db[step_model][step_field] - for step_field in ( - db[step_model]._model_.primary_keys or ["id"] - ) - ] + midrel, + [db[step_model][step_field] for step_field in (db[step_model]._model_.primary_keys or ["id"])], ) step_model = midrel.table_name sel_field = db[step_model].ALL query = query & _query - query = via.dbset.where( - self._patch_query_with_scopes_on(via, query, step_model) - ).query + query = via.dbset.where(self._patch_query_with_scopes_on(via, query, step_model)).query return query, sel_field, sname, rid, last_belongs, last_via @@ -464,8 +422,7 @@ def __call__(self): class TimingHandler(ExecutionHandler): def _timings(self): - THREAD_LOCAL._emtdal_timings_ = getattr( - THREAD_LOCAL, '_emtdal_timings_', []) + THREAD_LOCAL._emtdal_timings_ = getattr(THREAD_LOCAL, "_emtdal_timings_", []) return THREAD_LOCAL._emtdal_timings_ @cachedprop @@ -481,7 +438,7 @@ def after_execute(self, command): class ConnectionContext: - __slots__ = ['db', 'conn', 'with_transaction', 'reuse_if_open'] + __slots__ = ["db", "conn", "with_transaction", "reuse_if_open"] def __init__(self, db, with_transaction=True, reuse_if_open=True): self.db = db @@ -490,10 +447,7 @@ def __init__(self, db, with_transaction=True, reuse_if_open=True): self.reuse_if_open = reuse_if_open def __enter__(self): - self.conn = self.db.connection_open( - with_transaction=self.with_transaction, - reuse_if_open=self.reuse_if_open - ) + self.conn = self.db.connection_open(with_transaction=self.with_transaction, reuse_if_open=self.reuse_if_open) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -503,8 +457,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): async def __aenter__(self): self.conn = await self.db.connection_open_loop( - with_transaction=self.with_transaction, - reuse_if_open=self.reuse_if_open + with_transaction=self.with_transaction, reuse_if_open=self.reuse_if_open ) return self @@ -515,7 +468,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def decamelize(name): - return "_".join(re.findall('[A-Z][^A-Z]*', name)).lower() + return "_".join(re.findall("[A-Z][^A-Z]*", name)).lower() def camelize(name): @@ -529,17 +482,16 @@ def make_tablename(classname): def wrap_scope_on_set(dbset, model_instance, scope): @wraps(scope) def wrapped(*args, **kwargs): - return dbset.where( - scope(model_instance, *args, **kwargs), - model=model_instance.__class__) + return dbset.where(scope(model_instance, *args, **kwargs), model=model_instance.__class__) + return wrapped def wrap_scope_on_model(scope): @wraps(scope) def wrapped(cls, *args, **kwargs): - return cls.db.where( - scope(cls._instance_(), *args, **kwargs), model=cls) + return cls.db.where(scope(cls._instance_(), *args, **kwargs), model=cls) + return wrapped @@ -547,27 +499,22 @@ def wrap_virtual_on_model(model, virtual): @wraps(virtual) def wrapped(row, *args, **kwargs): return virtual(model, row, *args, **kwargs) + return wrapped def typed_row_reference(id: Any, table: Table): field_type = table._id.type if table._id else None - return { - 'id': RowReferenceInt, - 'integer': RowReferenceInt, - 'string': RowReferenceStr, - None: RowReferenceMulti - }[field_type](id, table) + return {"id": RowReferenceInt, "integer": RowReferenceInt, "string": RowReferenceStr, None: RowReferenceMulti}[ + field_type + ](id, table) def typed_row_reference_from_record(record: Any, model: Any): field_type = model.table._id.type if model.table._id else None - refcls = { - 'id': RowReferenceInt, - 'integer': RowReferenceInt, - 'string': RowReferenceStr, - None: RowReferenceMulti - }[field_type] + refcls = {"id": RowReferenceInt, "integer": RowReferenceInt, "string": RowReferenceStr, None: RowReferenceMulti}[ + field_type + ] if len(model._fieldset_pk) > 1: id = {pk: record[pk] for pk in model._fieldset_pk} else: @@ -578,7 +525,7 @@ def typed_row_reference_from_record(record: Any, model: Any): def _rowref_pickler(obj): - return obj._refmeta.caster, (obj.__pure__(), ) + return obj._refmeta.caster, (obj.__pure__(),) copyreg.pickle(RowReferenceInt, _rowref_pickler) diff --git a/emmett/orm/migrations/__init__.py b/emmett/orm/migrations/__init__.py index 5b9db84a..87321f19 100644 --- a/emmett/orm/migrations/__init__.py +++ b/emmett/orm/migrations/__init__.py @@ -1,2 +1,2 @@ -from .base import Migration, Column +from .base import Column, Migration from .operations import * diff --git a/emmett/orm/migrations/base.py b/emmett/orm/migrations/base.py index 7c0d6e15..7178d4a3 100644 --- a/emmett/orm/migrations/base.py +++ b/emmett/orm/migrations/base.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.base - -------------------------- +emmett.orm.migrations.base +-------------------------- - Provides base migrations objects. +Provides base migrations objects. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -14,10 +14,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Type from ...datastructures import sdict -from .. import Database, Model, Field -from .engine import MetaEngine, Engine +from .. import Database, Field, Model +from .engine import Engine, MetaEngine from .helpers import WrappedOperation, _feasible_as_dbms_default + if TYPE_CHECKING: from .operations import Operation @@ -32,13 +33,11 @@ class Migration: skip_on_compare: bool = False @classmethod - def register_operation( - cls, - name: str - ) -> Callable[[Type[Operation]], Type[Operation]]: + def register_operation(cls, name: str) -> Callable[[Type[Operation]], Type[Operation]]: def wrap(op_cls: Type[Operation]) -> Type[Operation]: cls._registered_ops_[name] = op_cls return op_cls + return wrap def __init__(self, app: Any, db: Database, is_meta: bool = False): @@ -56,14 +55,7 @@ def __getattr__(self, name: str) -> WrappedOperation: class Column(sdict): - def __init__( - self, - name: str, - type: str = 'string', - unique: bool = False, - notnull: bool = False, - **kwargs: Any - ): + def __init__(self, name: str, type: str = "string", unique: bool = False, notnull: bool = False, **kwargs: Any): self.name = name self.type = type self.unique = unique @@ -76,7 +68,7 @@ def _fk_type(self, db: Database, tablename: str): if self.name not in db[tablename]._model_._belongs_ref_: return ref = db[tablename]._model_._belongs_ref_[self.name] - if ref.ftype != 'id': + if ref.ftype != "id": self.type = ref.ftype self.length = db[ref.model][ref.fk].length self.on_delete = None @@ -90,7 +82,7 @@ def from_field(cls, field: Field) -> Column: field.notnull, length=field.length, ondelete=field.ondelete, - **field._ormkw + **field._ormkw, ) if _feasible_as_dbms_default(field.default): rv.default = field.default @@ -98,7 +90,4 @@ def from_field(cls, field: Field) -> Column: return rv def __repr__(self) -> str: - return "%s(%s)" % ( - self.__class__.__name__, - ", ".join(["%s=%r" % (k, v) for k, v in self.items()]) - ) + return "%s(%s)" % (self.__class__.__name__, ", ".join(["%s=%r" % (k, v) for k, v in self.items()])) diff --git a/emmett/orm/migrations/commands.py b/emmett/orm/migrations/commands.py index 6d055ecb..64019389 100644 --- a/emmett/orm/migrations/commands.py +++ b/emmett/orm/migrations/commands.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.commands - ------------------------------ +emmett.orm.migrations.commands +------------------------------ - Provides command interfaces for migrations. +Provides command interfaces for migrations. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -16,9 +16,9 @@ import click from ...datastructures import sdict -from .base import Database, Schema, Column +from .base import Column, Database, Schema from .helpers import DryRunDatabase, make_migration_id, to_tuple -from .operations import MigrationOp, UpgradeOps, DowngradeOps +from .operations import DowngradeOps, MigrationOp, UpgradeOps from .scripts import ScriptDir @@ -30,14 +30,7 @@ def __init__(self, app: Any, dals: List[Database]): def _load_envs(self, dals): for dal in dals: - self.envs.append( - sdict( - db=dal, - scriptdir=ScriptDir( - self.app, dal.config.migrations_folder - ) - ) - ) + self.envs.append(sdict(db=dal, scriptdir=ScriptDir(self.app, dal.config.migrations_folder))) def load_schema(self, ctx): ctx.db.define_models(Schema) @@ -53,6 +46,7 @@ def _ensure_schema_table_(self, ctx): ctx.db.rollback() from .engine import Engine from .operations import CreateTableOp + op = CreateTableOp.from_table(self._build_schema_metatable_(ctx)) op.engine = Engine(ctx.db) op.run() @@ -61,13 +55,11 @@ def _ensure_schema_table_(self, ctx): @staticmethod def _build_schema_metatable_(ctx): from .generation import MetaTable + columns = [] for field in list(ctx.db.Schema): columns.append(Column.from_field(field)) - return MetaTable( - ctx.db.Schema._tablename, - columns - ) + return MetaTable(ctx.db.Schema._tablename, columns) @staticmethod def _load_current_revision_(ctx): @@ -81,226 +73,152 @@ def _load_current_revision_(ctx): @staticmethod def _log_store_new(revid): - click.echo( - " ".join([ - "> Adding revision", - click.style(revid, fg="cyan", bold=True), - "to schema" - ]) - ) + click.echo(" ".join(["> Adding revision", click.style(revid, fg="cyan", bold=True), "to schema"])) @staticmethod def _log_store_del(revid): - click.echo( - " ".join([ - "> Removing revision", - click.style(revid, fg="cyan", bold=True), - "from schema" - ]) - ) + click.echo(" ".join(["> Removing revision", click.style(revid, fg="cyan", bold=True), "from schema"])) @staticmethod def _log_store_upd(revid_src, revid_dst): click.echo( - " ".join([ - "> Updating schema revision from", - click.style(revid_src, fg="cyan", bold=True), - "to", - click.style(revid_dst, fg="cyan", bold=True), - ]) + " ".join( + [ + "> Updating schema revision from", + click.style(revid_src, fg="cyan", bold=True), + "to", + click.style(revid_dst, fg="cyan", bold=True), + ] + ) ) @staticmethod def _log_dry_run(msg): - click.secho(msg, fg='yellow') + click.secho(msg, fg="yellow") def _store_current_revision_(self, ctx, source, dest): - _store_logs = { - 'new': self._log_store_new, - 'del': self._log_store_del, - 'upd': self._log_store_upd - } + _store_logs = {"new": self._log_store_new, "del": self._log_store_del, "upd": self._log_store_upd} source = to_tuple(source) dest = to_tuple(dest) if not source and dest: - _store_logs['new'](dest[0]) + _store_logs["new"](dest[0]) ctx.db.Schema.insert(version=dest[0]) ctx.db.commit() ctx._current_revision_ = [dest[0]] return if not dest and source: - _store_logs['del'](source[0]) + _store_logs["del"](source[0]) ctx.db(ctx.db.Schema.version == source[0]).delete() ctx.db.commit() ctx._current_revision_ = [] return if len(source) > 1: if len(source) > 2: - ctx.db( - ctx.db.Schema.version.belongs( - source[1:])).delete() - _store_logs['del'](source[1:]) + ctx.db(ctx.db.Schema.version.belongs(source[1:])).delete() + _store_logs["del"](source[1:]) else: - ctx.db( - ctx.db.Schema.version == source[1]).delete() - _store_logs['del'](source[1]) - ctx.db(ctx.db.Schema.version == source[0]).update( - version=dest[0] - ) - _store_logs['upd'](source[0], dest[0]) + ctx.db(ctx.db.Schema.version == source[1]).delete() + _store_logs["del"](source[1]) + ctx.db(ctx.db.Schema.version == source[0]).update(version=dest[0]) + _store_logs["upd"](source[0], dest[0]) ctx._current_revision_ = [dest[0]] else: if list(source) != ctx._current_revision_: ctx.db.Schema.insert(version=dest[0]) - _store_logs['new'](dest[0]) + _store_logs["new"](dest[0]) ctx._current_revision_.append(dest[0]) else: - ctx.db( - ctx.db.Schema.version == source[0] - ).update( - version=dest[0] - ) - _store_logs['upd'](source[0], dest[0]) + ctx.db(ctx.db.Schema.version == source[0]).update(version=dest[0]) + _store_logs["upd"](source[0], dest[0]) ctx._current_revision_ = [dest[0]] ctx.db.commit() @staticmethod def _generate_migration_script(ctx, migration, head): from .generation import Renderer + upgrades, downgrades = Renderer.render_migration(migration) ctx.scriptdir.generate_revision( - migration.rev_id, migration.message, head, upgrades=upgrades, - downgrades=downgrades + migration.rev_id, migration.message, head, upgrades=upgrades, downgrades=downgrades ) def generate(self, message, head): from .generation import Generator + for ctx in self.envs: upgrade_ops = Generator.generate_from(ctx.db, ctx.scriptdir, head) revid = make_migration_id() - migration = MigrationOp( - revid, upgrade_ops, upgrade_ops.reverse(), message - ) + migration = MigrationOp(revid, upgrade_ops, upgrade_ops.reverse(), message) self._generate_migration_script(ctx, migration, head) - click.echo( - " ".join([ - "> Generated migration for revision", - click.style(revid, fg="cyan", bold=True) - ]) - ) + click.echo(" ".join(["> Generated migration for revision", click.style(revid, fg="cyan", bold=True)])) def new(self, message, head): for ctx in self.envs: source_rev = ctx.scriptdir.get_revision(head) revid = make_migration_id() - migration = MigrationOp( - revid, UpgradeOps(), DowngradeOps(), message - ) - self._generate_migration_script( - ctx, migration, source_rev.revision - ) - click.echo( - " ".join([ - "> Created new migration for revision", - click.style(revid, fg="cyan", bold=True) - ]) - ) + migration = MigrationOp(revid, UpgradeOps(), DowngradeOps(), message) + self._generate_migration_script(ctx, migration, source_rev.revision) + click.echo(" ".join(["> Created new migration for revision", click.style(revid, fg="cyan", bold=True)])) def history(self, base, head, verbose): for ctx in self.envs: click.echo("> Migrations history:") lines = [] - for sc in ctx.scriptdir.walk_revisions( - base=base or "base", - head=head or "heads" - ): - lines.append( - sc.cmd_format( - verbose=verbose, include_doc=True, include_parents=True - ) - ) + for sc in ctx.scriptdir.walk_revisions(base=base or "base", head=head or "heads"): + lines.append(sc.cmd_format(verbose=verbose, include_doc=True, include_parents=True)) for line in lines: click.echo(line) if not lines: - click.secho( - "No migrations for the selected application.", fg="yellow" - ) + click.secho("No migrations for the selected application.", fg="yellow") def status(self, verbose): for ctx in self.envs: self.load_schema(ctx) - click.echo( - " ".join([ - "> Current revision(s) for", - click.style(ctx.db._uri, bold=True) - ]) - ) + click.echo(" ".join(["> Current revision(s) for", click.style(ctx.db._uri, bold=True)])) lines = [] for rev in ctx.scriptdir.get_revisions(ctx._current_revision_): lines.append(rev.cmd_format(verbose)) for line in lines: click.echo(line) if not lines: - click.secho( - "No revision state found on the schema.", fg="yellow" - ) + click.secho("No revision state found on the schema.", fg="yellow") def up(self, rev_id, dry_run=False): log_verb = "Previewing" if dry_run else "Performing" for ctx in self.envs: self.load_schema(ctx) start_point = ctx._current_revision_ - revisions = ctx.scriptdir.get_upgrade_revs( - rev_id, start_point - ) - click.echo( - " ".join([ - f"> {log_verb} upgrades against", - click.style(ctx.db._uri, bold=True) - ]) - ) - db = ( - DryRunDatabase(ctx.db, self._log_dry_run) if dry_run else - ctx.db - ) + revisions = ctx.scriptdir.get_upgrade_revs(rev_id, start_point) + click.echo(" ".join([f"> {log_verb} upgrades against", click.style(ctx.db._uri, bold=True)])) + db = DryRunDatabase(ctx.db, self._log_dry_run) if dry_run else ctx.db with db.connection(): for revision in revisions: - click.echo( - " ".join([ - f"> {log_verb} upgrade:", - click.style(str(revision), fg="cyan", bold=True) - ]) - ) + click.echo(" ".join([f"> {log_verb} upgrade:", click.style(str(revision), fg="cyan", bold=True)])) migration = revision.migration_class(self.app, db) try: migration.up() db.commit() if dry_run: continue - self._store_current_revision_( - ctx, migration.revises, migration.revision - ) + self._store_current_revision_(ctx, migration.revises, migration.revision) click.echo( - "".join([ - click.style( - "> Succesfully upgraded to revision ", - fg="green" - ), - click.style( - revision.revision, fg="cyan", bold=True - ), - click.style(f": {revision.doc}", fg="green") - ]) + "".join( + [ + click.style("> Succesfully upgraded to revision ", fg="green"), + click.style(revision.revision, fg="cyan", bold=True), + click.style(f": {revision.doc}", fg="green"), + ] + ) ) except Exception: db.rollback() click.echo( - " ".join([ - click.style("> Failed upgrading to", fg="red"), - click.style( - revision.revision, fg="red", bold=True - ), - ]) + " ".join( + [ + click.style("> Failed upgrading to", fg="red"), + click.style(revision.revision, fg="red", bold=True), + ] + ) ) raise @@ -309,58 +227,37 @@ def down(self, rev_id, dry_run=False): for ctx in self.envs: self.load_schema(ctx) start_point = ctx._current_revision_ - revisions = ctx.scriptdir.get_downgrade_revs( - rev_id, start_point) - click.echo( - " ".join([ - f"> {log_verb} downgrades against", - click.style(ctx.db._uri, bold=True) - ]) - ) - db = ( - DryRunDatabase(ctx.db, self._log_dry_run) if dry_run else - ctx.db - ) + revisions = ctx.scriptdir.get_downgrade_revs(rev_id, start_point) + click.echo(" ".join([f"> {log_verb} downgrades against", click.style(ctx.db._uri, bold=True)])) + db = DryRunDatabase(ctx.db, self._log_dry_run) if dry_run else ctx.db with db.connection(): for revision in revisions: - click.echo( - " ".join([ - f"> {log_verb} downgrade:", - click.style(str(revision), fg="cyan", bold=True) - ]) - ) + click.echo(" ".join([f"> {log_verb} downgrade:", click.style(str(revision), fg="cyan", bold=True)])) migration = revision.migration_class(self.app, db) try: migration.down() db.commit() if dry_run: continue - self._store_current_revision_( - ctx, migration.revision, migration.revises - ) + self._store_current_revision_(ctx, migration.revision, migration.revises) click.echo( - "".join([ - click.style( - "> Succesfully downgraded from revision ", - fg="green" - ), - click.style( - revision.revision, fg="cyan", bold=True - ), - click.style(f": {revision.doc}", fg="green") - ]) + "".join( + [ + click.style("> Succesfully downgraded from revision ", fg="green"), + click.style(revision.revision, fg="cyan", bold=True), + click.style(f": {revision.doc}", fg="green"), + ] + ) ) except Exception: db.rollback() click.echo( - " ".join([ - click.style( - "> Failed downgrading from", fg="red" - ), - click.style( - revision.revision, fg="red", bold=True - ), - ]) + " ".join( + [ + click.style("> Failed downgrading from", fg="red"), + click.style(revision.revision, fg="red", bold=True), + ] + ) ) raise @@ -373,32 +270,29 @@ def set(self, rev_id, auto_confirm=False): click.secho("> No matching revision found", fg="red") return click.echo( - " ".join([ - click.style("> Setting revision to", fg="yellow"), - click.style(target_revision.revision, bold=True, fg="yellow"), - click.style("against", fg="yellow"), - click.style(ctx.db._uri, bold=True, fg="yellow") - ]) + " ".join( + [ + click.style("> Setting revision to", fg="yellow"), + click.style(target_revision.revision, bold=True, fg="yellow"), + click.style("against", fg="yellow"), + click.style(ctx.db._uri, bold=True, fg="yellow"), + ] + ) ) if not auto_confirm: if not click.confirm("Do you want to continue?"): click.echo("Aborting") return with ctx.db.connection(): - self._store_current_revision_( - ctx, current_revision, target_revision.revision - ) + self._store_current_revision_(ctx, current_revision, target_revision.revision) click.echo( - "".join([ - click.style( - "> Succesfully set revision to ", - fg="green" - ), - click.style( - target_revision.revision, fg="cyan", bold=True - ), - click.style(f": {target_revision.doc}", fg="green") - ]) + "".join( + [ + click.style("> Succesfully set revision to ", fg="green"), + click.style(target_revision.revision, fg="cyan", bold=True), + click.style(f": {target_revision.doc}", fg="green"), + ] + ) ) @@ -413,9 +307,7 @@ def new(app, dals, message, head): def history(app, dals, rev_range, verbose): if rev_range is not None: if ":" not in rev_range: - raise Exception( - "History range requires [start]:[end], " - "[start]:, or :[end]") + raise Exception("History range requires [start]:[end], " "[start]:, or :[end]") base, head = rev_range.strip().split(":") else: base = head = None diff --git a/emmett/orm/migrations/engine.py b/emmett/orm/migrations/engine.py index c52f2df4..565e0704 100644 --- a/emmett/orm/migrations/engine.py +++ b/emmett/orm/migrations/engine.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.engine - ---------------------------- +emmett.orm.migrations.engine +---------------------------- - Provides migration engine for pyDAL. +Provides migration engine for pyDAL. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -15,6 +15,7 @@ from ...datastructures import sdict + if TYPE_CHECKING: from .base import Column, Database from .generation import MetaData @@ -48,40 +49,27 @@ def drop_index(self, name, table_name): self.db.drop_index(table_name, name) def create_foreign_key_constraint( - self, - name, - table_name, - column_names, - foreign_table_name, - foreign_keys, - on_delete + self, name, table_name, column_names, foreign_table_name, foreign_keys, on_delete ): self.db.create_foreign_key_constraint( - table_name, - name, - column_names, - foreign_table_name, - foreign_keys, - on_delete + table_name, name, column_names, foreign_table_name, foreign_keys, on_delete ) def drop_foreign_key_constraint(self, name, table_name): self.db.drop_foreign_key_constraint(table_name, name) @staticmethod - def _parse_column_changes( - changes: List[Tuple[str, str, str, Dict[str, Any], Any, Any]] - ) -> Dict[str, List[Any]]: + def _parse_column_changes(changes: List[Tuple[str, str, str, Dict[str, Any], Any, Any]]) -> Dict[str, List[Any]]: rv = {} for change in changes: if change[0] == "modify_type": - rv['type'] = [change[4], change[5], change[3]['existing_length']] + rv["type"] = [change[4], change[5], change[3]["existing_length"]] elif change[0] == "modify_length": - rv['length'] = [change[4], change[5], change[3]['existing_type']] + rv["length"] = [change[4], change[5], change[3]["existing_type"]] elif change[0] == "modify_notnull": - rv['notnull'] = [change[4], change[5]] + rv["notnull"] = [change[4], change[5]] elif change[0] == "modify_default": - rv['default'] = [change[4], change[5], change[3]['existing_type']] + rv["default"] = [change[4], change[5], change[3]["existing_type"]] else: rv[change[0].split("modify_")[-1]] = [change[4], change[5], change[3]] return rv @@ -110,7 +98,7 @@ def create_table(self, name, columns, primary_keys, **kwargs): def drop_table(self, name): adapt_v = sdict(_rname=self.dialect.quote(name)) - sql_list = self.dialect.drop_table(adapt_v, 'cascade') + sql_list = self.dialect.drop_table(adapt_v, "cascade") for sql in sql_list: self._log_and_exec(sql) @@ -123,11 +111,7 @@ def drop_column(self, tablename, colname): self._log_and_exec(sql) def alter_column(self, table_name, column_name, changes): - sql = self._alter_column_sql( - table_name, - column_name, - self._parse_column_changes(changes) - ) + sql = self._alter_column_sql(table_name, column_name, self._parse_column_changes(changes)) if sql is not None: self._log_and_exec(sql) @@ -150,7 +134,7 @@ def create_foreign_key_constraint( column_names: List[str], foreign_table_name: str, foreign_keys: List[str], - on_delete: str + on_delete: str, ): sql = self.dialect.add_foreign_key_constraint( name, table_name, foreign_table_name, column_names, foreign_keys, on_delete @@ -165,104 +149,74 @@ def _gen_reference(self, tablename, column): referenced = column.type[10:].strip() constraint_name = self.dialect.constraint_name(tablename, column.name) try: - rtablename, rfieldname = referenced.split('.') + rtablename, rfieldname = referenced.split(".") except Exception: rtablename = referenced - rfieldname = 'id' + rfieldname = "id" if not rtablename: rtablename = tablename - csql_info = dict( - index_name=self.dialect.quote(column.name + '__idx'), - field_name=self.dialect.quote(column.name), - constraint_name=self.dialect.quote(constraint_name), - foreign_key='%s (%s)' % ( - self.dialect.quote(rtablename), - self.dialect.quote(rfieldname) - ), - on_delete_action=column.ondelete) - csql_info['null'] = ( - ' NOT NULL' if column.notnull else - self.dialect.allow_null - ) - csql_info['unique'] = ' UNIQUE' if column.unique else '' - csql = self.adapter.types['reference'] % csql_info + csql_info = { + "index_name": self.dialect.quote(column.name + "__idx"), + "field_name": self.dialect.quote(column.name), + "constraint_name": self.dialect.quote(constraint_name), + "foreign_key": "%s (%s)" % (self.dialect.quote(rtablename), self.dialect.quote(rfieldname)), + "on_delete_action": column.ondelete, + } + csql_info["null"] = " NOT NULL" if column.notnull else self.dialect.allow_null + csql_info["unique"] = " UNIQUE" if column.unique else "" + csql = self.adapter.types["reference"] % csql_info return csql def _gen_primary_key(self, fields, primary_keys=[]): if primary_keys: - fields.append( - self.dialect.primary_key( - ', '.join([self.dialect.quote(pk) for pk in primary_keys]) - ) - ) + fields.append(self.dialect.primary_key(", ".join([self.dialect.quote(pk) for pk in primary_keys]))) def _gen_geo(self, column_type, geometry_type, srid, dimension): - if not hasattr(self.adapter, 'srid'): - raise RuntimeError('Adapter does not support geometry') + if not hasattr(self.adapter, "srid"): + raise RuntimeError("Adapter does not support geometry") if column_type not in self.adapter.types: - raise SyntaxError( - f'Field: unknown field type: {column_type}' - ) + raise SyntaxError(f"Field: unknown field type: {column_type}") return "{ctype}({gtype},{srid},{dimension})".format( ctype=self.adapter.types[column_type], gtype=geometry_type, srid=srid or self.adapter.srid, - dimension=dimension or 2 + dimension=dimension or 2, ) - def _new_column_sql( - self, - tablename: str, - column: Column, - primary_key: bool = False - ) -> str: - if column.type.startswith('reference'): + def _new_column_sql(self, tablename: str, column: Column, primary_key: bool = False) -> str: + if column.type.startswith("reference"): csql = self._gen_reference(tablename, column) - elif column.type.startswith('list:reference'): + elif column.type.startswith("list:reference"): csql = self.adapter.types[column.type[:14]] - elif column.type.startswith('decimal'): - precision, scale = map(int, column.type[8:-1].split(',')) - csql = self.adapter.types[column.type[:7]] % dict( - precision=precision, scale=scale - ) - elif column.type.startswith('geo'): - csql = self._gen_geo( - column.type, - column.geometry_type, - column.srid, - column.dimension - ) + elif column.type.startswith("decimal"): + precision, scale = map(int, column.type[8:-1].split(",")) + csql = self.adapter.types[column.type[:7]] % {"precision": precision, "scale": scale} + elif column.type.startswith("geo"): + csql = self._gen_geo(column.type, column.geometry_type, column.srid, column.dimension) elif column.type not in self.adapter.types: - raise SyntaxError( - f'Field: unknown field type: {column.type} for {column.nmae}' - ) + raise SyntaxError(f"Field: unknown field type: {column.type} for {column.nmae}") else: - csql = self.adapter.types[column.type] % {'length': column.length} - if self.adapter.dbengine not in ('firebird', 'informix', 'oracle'): + csql = self.adapter.types[column.type] % {"length": column.length} + if self.adapter.dbengine not in ("firebird", "informix", "oracle"): cprops = "%(notnull)s%(default)s%(unique)s%(pk)s%(qualifier)s" else: cprops = "%(default)s%(notnull)s%(unique)s%(pk)s%(qualifier)s" - if not column.type.startswith(('id', 'reference')): + if not column.type.startswith(("id", "reference")): csql += cprops % { - 'notnull': ' NOT NULL' if column.notnull else self.dialect.allow_null, - 'default': ( - ' DEFAULT %s' % self.adapter.represent(column.default, column.type) - if column.default is not None else '' + "notnull": " NOT NULL" if column.notnull else self.dialect.allow_null, + "default": ( + " DEFAULT %s" % self.adapter.represent(column.default, column.type) + if column.default is not None + else "" ), - 'unique': ' UNIQUE' if column.unique else '', - 'pk': ' PRIMARY KEY' if primary_key else '', - 'qualifier': ( - ' %s' % column.custom_qualifier if column.custom_qualifier else '' - ) + "unique": " UNIQUE" if column.unique else "", + "pk": " PRIMARY KEY" if primary_key else "", + "qualifier": (" %s" % column.custom_qualifier if column.custom_qualifier else ""), } return csql def _new_table_sql( - self, - tablename: str, - columns: List[Column], - primary_keys: List[str] = [], - id_col: str ='id' + self, tablename: str, columns: List[Column], primary_keys: List[str] = [], id_col: str = "id" ) -> str: # TODO: # - SQLCustomType @@ -270,39 +224,34 @@ def _new_table_sql( fields = [] for column in columns: csql = self._new_column_sql( - tablename, - column, - primary_key=( - column.name in primary_keys if not composed_primary_key else False - ) + tablename, column, primary_key=(column.name in primary_keys if not composed_primary_key else False) ) - fields.append('%s %s' % (self.dialect.quote(column.name), csql)) + fields.append("%s %s" % (self.dialect.quote(column.name), csql)) # backend-specific extensions to fields - if self.adapter.dbengine == 'mysql': + if self.adapter.dbengine == "mysql": if not primary_keys: primary_keys.append(id_col) elif not composed_primary_key: primary_keys.clear() self._gen_primary_key(fields, primary_keys) - fields = ',\n '.join(fields) + fields = ",\n ".join(fields) return self.dialect.create_table(tablename, fields) def _add_column_sql(self, tablename, column): csql = self._new_column_sql(tablename, column) - return 'ALTER TABLE %(tname)s ADD %(cname)s %(sql)s;' % { - 'tname': self.dialect.quote(tablename), - 'cname': self.dialect.quote(column.name), - 'sql': csql + return "ALTER TABLE %(tname)s ADD %(cname)s %(sql)s;" % { + "tname": self.dialect.quote(tablename), + "cname": self.dialect.quote(column.name), + "sql": csql, } def _drop_column_sql(self, table_name, column_name): if self.adapter.dbengine == "firebird": - sql = 'ALTER TABLE %s DROP %s;' + sql = "ALTER TABLE %s DROP %s;" else: - sql = 'ALTER TABLE %s DROP COLUMN %s;' - return sql % ( - self.dialect.quote(table_name), self.dialect.quote(column_name)) + sql = "ALTER TABLE %s DROP COLUMN %s;" + return sql % (self.dialect.quote(table_name), self.dialect.quote(column_name)) def _represent_changes(self, changes, field): geo_attrs = ("geometry_type", "srid", "dimension") @@ -311,62 +260,46 @@ def _represent_changes(self, changes, field): geo_changes[key] = changes.pop(key) geo_data.update(geo_changes[key][3]) - if 'default' in changes and changes['default'][1] is not None: - ftype = changes['default'][2] or field.type - if 'type' in changes: - ftype = changes['type'][1] - changes['default'][1] = self.adapter.represent( - changes['default'][1], ftype) - if 'type' in changes: - changes.pop('length', None) - coltype = changes['type'][1] - if coltype.startswith('reference'): - raise NotImplementedError( - 'Type change on reference fields is not supported.' - ) - elif coltype.startswith('decimal'): - precision, scale = map(int, coltype[8:-1].split(',')) - csql = self.adapter.types[coltype[:7]] % \ - dict(precision=precision, scale=scale) - elif coltype.startswith('geo'): + if "default" in changes and changes["default"][1] is not None: + ftype = changes["default"][2] or field.type + if "type" in changes: + ftype = changes["type"][1] + changes["default"][1] = self.adapter.represent(changes["default"][1], ftype) + if "type" in changes: + changes.pop("length", None) + coltype = changes["type"][1] + if coltype.startswith("reference"): + raise NotImplementedError("Type change on reference fields is not supported.") + elif coltype.startswith("decimal"): + precision, scale = map(int, coltype[8:-1].split(",")) + csql = self.adapter.types[coltype[:7]] % {"precision": precision, "scale": scale} + elif coltype.startswith("geo"): gen_attrs = [] for key in geo_attrs: - val = ( - geo_changes.get(f"{key}", (None, None))[1] or - geo_data[f"existing_{key}"] - ) + val = geo_changes.get(f"{key}", (None, None))[1] or geo_data[f"existing_{key}"] gen_attrs.append(val) csql = self._gen_geo(coltype, gen_attrs) else: - csql = self.adapter.types[coltype] % { - 'length': changes['type'][2] or field.length - } - changes['type'][1] = csql - elif 'length' in changes: - change = changes.pop('length') + csql = self.adapter.types[coltype] % {"length": changes["type"][2] or field.length} + changes["type"][1] = csql + elif "length" in changes: + change = changes.pop("length") ftype = change[2] or field.type - changes['type'] = [None, self.adapter.types[ftype] % {'length': change[1]}] + changes["type"] = [None, self.adapter.types[ftype] % {"length": change[1]}] elif geo_changes: coltype = geo_data["existing_type"] or field.type gen_attrs = [] for key in geo_attrs: - val = ( - geo_changes.get(f"{key}", (None, None))[1] or - geo_data[f"existing_{key}"] - ) + val = geo_changes.get(f"{key}", (None, None))[1] or geo_data[f"existing_{key}"] gen_attrs.append(val) - changes['type'] = [None, self._gen_geo(coltype, *gen_attrs)] - + changes["type"] = [None, self._gen_geo(coltype, *gen_attrs)] def _alter_column_sql(self, table_name, column_name, changes): - sql = 'ALTER TABLE %(tname)s ALTER COLUMN %(cname)s %(changes)s;' + sql = "ALTER TABLE %(tname)s ALTER COLUMN %(cname)s %(changes)s;" sql_changes_map = { - 'type': "%s" if self.adapter.dbengine in ["mysql", "mssql"] else "TYPE %s", - 'notnull': { - True: "SET NOT NULL", - False: "DROP NOT NULL" - }, - 'default': ["SET DEFAULT %s", "DROP DEFAULT"] + "type": "%s" if self.adapter.dbengine in ["mysql", "mssql"] else "TYPE %s", + "notnull": {True: "SET NOT NULL", False: "DROP NOT NULL"}, + "default": ["SET DEFAULT %s", "DROP DEFAULT"], } field = self.db[table_name][column_name] self._represent_changes(changes, field) @@ -376,16 +309,13 @@ def _alter_column_sql(self, table_name, column_name, changes): if isinstance(change_sql, dict): sql_changes.append(change_sql[change_val[1]]) elif isinstance(change_sql, list): - sql_changes.append( - change_sql[0] % change_val[1] if change_val[1] is not None else - change_sql[1] - ) + sql_changes.append(change_sql[0] % change_val[1] if change_val[1] is not None else change_sql[1]) else: sql_changes.append(change_sql % change_val[1]) if not sql_changes: return None return sql % { - 'tname': self.dialect.quote(table_name), - 'cname': self.dialect.quote(column_name), - 'changes': " ".join(sql_changes) + "tname": self.dialect.quote(table_name), + "cname": self.dialect.quote(column_name), + "changes": " ".join(sql_changes), } diff --git a/emmett/orm/migrations/exceptions.py b/emmett/orm/migrations/exceptions.py index fc1d3198..eb86e00c 100644 --- a/emmett/orm/migrations/exceptions.py +++ b/emmett/orm/migrations/exceptions.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.exceptions - -------------------------------- +emmett.orm.migrations.exceptions +-------------------------------- - Provides exceptions for migration operations. +Provides exceptions for migration operations. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ @@ -19,8 +19,7 @@ def __init__(self, lower, upper): self.lower = lower self.upper = upper super(RangeNotAncestorError, self).__init__( - "Revision %s is not an ancestor of revision %s" % - (lower or "base", upper or "base") + "Revision %s is not an ancestor of revision %s" % (lower or "base", upper or "base") ) @@ -29,8 +28,7 @@ def __init__(self, heads, argument): self.heads = heads self.argument = argument super(MultipleHeads, self).__init__( - "Multiple heads are present for given argument '%s'; " - "%s" % (argument, ", ".join(heads)) + "Multiple heads are present for given argument '%s'; " "%s" % (argument, ", ".join(heads)) ) diff --git a/emmett/orm/migrations/generation.py b/emmett/orm/migrations/generation.py index 59ce6467..2ce14a9a 100644 --- a/emmett/orm/migrations/generation.py +++ b/emmett/orm/migrations/generation.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.generation - -------------------------------- +emmett.orm.migrations.generation +-------------------------------- - Provides generation utils for migrations. +Provides generation utils for migrations. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) - :copyright: (c) 2009-2015 by Michael Bayer +Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) +:copyright: (c) 2009-2015 by Michael Bayer - :license: BSD-3-Clause +:license: BSD-3-Clause """ from __future__ import annotations @@ -22,7 +22,7 @@ from ...datastructures import OrderedSet from ..objects import Rows, Table from .base import Column, Database -from .helpers import Dispatcher, DEFAULT_VALUE +from .helpers import DEFAULT_VALUE, Dispatcher from .operations import ( AddColumnOp, AlterColumnOp, @@ -36,19 +36,13 @@ MigrationOp, OpContainer, Operation, - UpgradeOps + UpgradeOps, ) from .scripts import ScriptDir class MetaTable: - def __init__( - self, - name: str, - columns: List[Column] = [], - primary_keys: List[str] = [], - **kw: Any - ): + def __init__(self, name: str, columns: List[Column] = [], primary_keys: List[str] = [], **kw: Any): self.name = name self.columns = OrderedDict() for column in columns: @@ -72,25 +66,14 @@ def __delitem__(self, name: str): del self.columns[name] def __repr__(self) -> str: - return "Table(%r, %s)" % ( - self.name, - ", ".join(["%s" % column for column in self.columns.values()]) - ) + return "Table(%r, %s)" % (self.name, ", ".join(["%s" % column for column in self.columns.values()])) def insert(self, *args, **kwargs) -> Any: return None class MetaIndex: - def __init__( - self, - table_name: str, - name: str, - fields: List[str], - expressions: List[str], - unique: bool, - **kw: Any - ): + def __init__(self, table_name: str, name: str, fields: List[str], expressions: List[str], unique: bool, **kw: Any): self.table_name = table_name self.name = name self.fields = fields @@ -100,16 +83,13 @@ def __init__( @property def where(self) -> Optional[str]: - return self.kw.get('where') + return self.kw.get("where") def __repr__(self) -> str: - opts = [('expressions', self.expressions), ('unique', self.unique)] + opts = [("expressions", self.expressions), ("unique", self.unique)] for key, val in self.kw.items(): opts.append((key, val)) - return "Index(%r, %r, %s)" % ( - self.name, self.fields, - ", ".join(["%s=%r" % (opt[0], opt[1]) for opt in opts]) - ) + return "Index(%r, %r, %s)" % (self.name, self.fields, ", ".join(["%s=%r" % (opt[0], opt[1]) for opt in opts])) class MetaForeignKey: @@ -121,7 +101,7 @@ def __init__( foreign_table_name: str, foreign_keys: List[str], on_delete: str, - **kw + **kw, ): self.table_name = table_name self.name = name @@ -150,7 +130,7 @@ def __repr__(self) -> str: self.foreign_table_name, self.column_names, self.foreign_keys, - self.on_delete + self.on_delete, ) @@ -189,13 +169,7 @@ def where(self, *args, **kwargs) -> MetaDataSet: def __call__(self, *args, **kwargs): return self.where(*args, **kwargs) - def create_table( - self, - name: str, - columns: List[Column], - primary_keys: List[str], - **kw: Any - ): + def create_table(self, name: str, columns: List[Column], primary_keys: List[str], **kw: Any): self.tables[name] = MetaTable(name, columns, primary_keys, **kw) def drop_table(self, name: str): @@ -211,13 +185,7 @@ def change_column(self, table_name: str, column_name: str, changes: Dict[str, An self.tables[table_name][column_name].update(**changes) def create_index( - self, - table_name: str, - index_name: str, - fields: List[str], - expressions: List[str], - unique: bool, - **kw: Any + self, table_name: str, index_name: str, fields: List[str], expressions: List[str], unique: bool, **kw: Any ): self.tables[table_name].indexes[index_name] = MetaIndex( table_name, index_name, fields, expressions, unique, **kw @@ -233,15 +201,10 @@ def create_foreign_key_constraint( column_names: List[str], foreign_table_name: str, foreign_keys: List[str], - on_delete: str + on_delete: str, ): self.tables[table_name].foreign_keys[constraint_name] = MetaForeignKey( - table_name, - constraint_name, - column_names, - foreign_table_name, - foreign_keys, - on_delete + table_name, constraint_name, column_names, foreign_table_name, foreign_keys, on_delete ) def drop_foreign_key_constraint(self, table_name: str, constraint_name: str): @@ -261,10 +224,8 @@ def make_ops(self) -> List[Operation]: def _build_metatable(self, dbtable: Table): return MetaTable( dbtable._tablename, - [ - Column.from_field(field) for field in list(dbtable) - ], - primary_keys=list(dbtable._primary_keys) + [Column.from_field(field) for field in list(dbtable)], + primary_keys=list(dbtable._primary_keys), ) def _build_metaindex(self, dbtable: Table, index_name: str) -> MetaIndex: @@ -272,15 +233,15 @@ def _build_metaindex(self, dbtable: Table, index_name: str) -> MetaIndex: dbindex = model._indexes_[index_name] kw = {} with self.db._adapter.index_expander(): - if 'where' in dbindex: - kw['where'] = str(dbindex['where']) + if "where" in dbindex: + kw["where"] = str(dbindex["where"]) rv = MetaIndex( model.tablename, index_name, - [field for field in dbindex['fields']], - [str(expr) for expr in dbindex['expressions']], - dbindex['unique'], - **kw + list(dbindex["fields"]), + [str(expr) for expr in dbindex["expressions"]], + dbindex["unique"], + **kw, ) return rv @@ -288,12 +249,7 @@ def _build_metafk(self, dbtable: Table, fk_name: str) -> MetaForeignKey: model = dbtable._model_ dbfk = model._foreign_keys_[fk_name] return MetaForeignKey( - model.tablename, - fk_name, - dbfk['fields_local'], - dbfk['table'], - dbfk['fields_foreign'], - dbfk['on_delete'] + model.tablename, fk_name, dbfk["fields_local"], dbfk["table"], dbfk["fields_foreign"], dbfk["on_delete"] ) def tables(self): @@ -322,64 +278,37 @@ def table(self, dbtable: Table, metatable: MetaTable): self.indexes_and_uniques(dbtable, metatable) self.foreign_keys(dbtable, metatable) - def indexes_and_uniques( - self, - dbtable: Table, - metatable: MetaTable, - ops_stack: Optional[List[Operation]] = None - ): + def indexes_and_uniques(self, dbtable: Table, metatable: MetaTable, ops_stack: Optional[List[Operation]] = None): ops = ops_stack if ops_stack is not None else self.ops - db_index_names = OrderedSet( - [idxname for idxname in dbtable._model_._indexes_.keys()] - ) + db_index_names = OrderedSet(list(dbtable._model_._indexes_.keys())) meta_index_names = OrderedSet(list(metatable.indexes)) #: removed indexes for index_name in meta_index_names.difference(db_index_names): ops.append(DropIndexOp.from_index(metatable.indexes[index_name])) #: new indexs for index_name in db_index_names.difference(meta_index_names): - ops.append( - CreateIndexOp.from_index( - self._build_metaindex(dbtable, index_name) - ) - ) + ops.append(CreateIndexOp.from_index(self._build_metaindex(dbtable, index_name))) #: existing indexes for index_name in meta_index_names.intersection(db_index_names): metaindex = metatable.indexes[index_name] dbindex = self._build_metaindex(dbtable, index_name) if any( - getattr(metaindex, key) != getattr(dbindex, key) - for key in ['fields', 'expressions', 'unique', 'kw'] + getattr(metaindex, key) != getattr(dbindex, key) for key in ["fields", "expressions", "unique", "kw"] ): ops.append(DropIndexOp.from_index(metaindex)) ops.append(CreateIndexOp.from_index(dbindex)) # TODO: uniques - def foreign_keys( - self, - dbtable: Table, - metatable: MetaTable, - ops_stack: Optional[List[Operation]] = None - ): + def foreign_keys(self, dbtable: Table, metatable: MetaTable, ops_stack: Optional[List[Operation]] = None): ops = ops_stack if ops_stack is not None else self.ops - db_fk_names = OrderedSet( - [fkname for fkname in dbtable._model_._foreign_keys_.keys()] - ) + db_fk_names = OrderedSet(list(dbtable._model_._foreign_keys_.keys())) meta_fk_names = OrderedSet(list(metatable.foreign_keys)) #: removed fks for fk_name in meta_fk_names.difference(db_fk_names): - ops.append( - DropForeignKeyConstraintOp.from_foreign_key( - metatable.foreign_keys[fk_name] - ) - ) + ops.append(DropForeignKeyConstraintOp.from_foreign_key(metatable.foreign_keys[fk_name])) #: new fks for fk_name in db_fk_names.difference(meta_fk_names): - ops.append( - CreateForeignKeyConstraintOp.from_foreign_key( - self._build_metafk(dbtable, fk_name) - ) - ) + ops.append(CreateForeignKeyConstraintOp.from_foreign_key(self._build_metafk(dbtable, fk_name))) #: existing fks for fk_name in meta_fk_names.intersection(db_fk_names): metafk = metatable.foreign_keys[fk_name] @@ -389,29 +318,22 @@ def foreign_keys( ops.append(CreateForeignKeyConstraintOp.from_foreign_key(dbfk)) def columns(self, dbtable: Table, metatable: MetaTable): - db_column_names = OrderedSet([fname for fname in dbtable.fields]) + db_column_names = OrderedSet(list(dbtable.fields)) meta_column_names = OrderedSet(metatable.fields) #: new columns for column_name in db_column_names.difference(meta_column_names): - self.ops.append(AddColumnOp.from_column_and_tablename( - dbtable._tablename, Column.from_field(dbtable[column_name]) - )) + self.ops.append( + AddColumnOp.from_column_and_tablename(dbtable._tablename, Column.from_field(dbtable[column_name])) + ) #: existing columns for column_name in meta_column_names.intersection(db_column_names): self.ops.append(AlterColumnOp(dbtable._tablename, column_name)) - self.column( - Column.from_field(dbtable[column_name]), - metatable.columns[column_name] - ) + self.column(Column.from_field(dbtable[column_name]), metatable.columns[column_name]) if not self.ops[-1].has_changes(): self.ops.pop() #: removed columns for column_name in meta_column_names.difference(db_column_names): - self.ops.append( - DropColumnOp.from_column_and_tablename( - dbtable._tablename, metatable.columns[column_name] - ) - ) + self.ops.append(DropColumnOp.from_column_and_tablename(dbtable._tablename, metatable.columns[column_name])) def column(self, dbcolumn: Column, metacolumn: Column): self.notnulls(dbcolumn, metacolumn) @@ -431,9 +353,7 @@ def types(self, dbcolumn: Column, metacolumn: Column): def lengths(self, dbcolumn: Column, metacolumn: Column): self.ops[-1].existing_length = metacolumn.length - if any( - field.type == "string" for field in [dbcolumn, metacolumn] - ) and dbcolumn.length != metacolumn.length: + if any(field.type == "string" for field in [dbcolumn, metacolumn]) and dbcolumn.length != metacolumn.length: self.ops[-1].modify_length = dbcolumn.length def notnulls(self, dbcolumn: Column, metacolumn: Column): @@ -461,12 +381,8 @@ def __init__(self, db: Database, scriptdir: ScriptDir, head: str): self._load_head_to_meta() def _load_head_to_meta(self): - for revision in reversed( - list(self.scriptdir.walk_revisions("base", self.head)) - ): - migration = revision.migration_class( - None, self.meta, is_meta=True - ) + for revision in reversed(list(self.scriptdir.walk_revisions("base", self.head))): + migration = revision.migration_class(None, self.meta, is_meta=True) if migration.skip_on_compare: continue migration.up() @@ -475,12 +391,7 @@ def generate(self) -> UpgradeOps: return Comparator.compare(self.db, self.meta) @classmethod - def generate_from( - cls, - dal: Database, - scriptdir: ScriptDir, - head: str - ) -> UpgradeOps: + def generate_from(cls, dal: Database, scriptdir: ScriptDir, head: str) -> UpgradeOps: return cls(dal, scriptdir, head).generate() @@ -501,10 +412,7 @@ def render_opcontainer(self, op_container: OpContainer) -> List[str]: @classmethod def render_migration(cls, migration_op: MigrationOp): r = cls() - return ( - r.render_opcontainer(migration_op.upgrade_ops), - r.render_opcontainer(migration_op.downgrade_ops) - ) + return (r.render_opcontainer(migration_op.upgrade_ops), r.render_opcontainer(migration_op.downgrade_ops)) renderers = Dispatcher() @@ -514,10 +422,7 @@ def render_migration(cls, migration_op: MigrationOp): def _add_table(op: CreateTableOp) -> str: table = op.to_table() - args = [ - col for col in [_render_column(col) for col in table.columns.values()] - if col - ] + args = [col for col in [_render_column(col) for col in table.columns.values()] if col] # + sorted([ # rcons for rcons in [ # _render_constraint(cons) for cons in table.constraints] @@ -527,18 +432,19 @@ def _add_table(op: CreateTableOp) -> str: indent = " " * 12 if len(args) > 255: - args = '*[' + (',\n' + indent).join(args) + ']' + args = "*[" + (",\n" + indent).join(args) + "]" else: - args = (',\n' + indent).join(args) + args = (",\n" + indent).join(args) text = ( - "self.create_table(\n" + indent + "%(tablename)r,\n" + indent + "%(args)s,\n" + - indent + "primary_keys=%(primary_keys)r" - ) % { - 'tablename': op.table_name, - 'args': args, - 'primary_keys': table.primary_keys - } + "self.create_table(\n" + + indent + + "%(tablename)r,\n" + + indent + + "%(args)s,\n" + + indent + + "primary_keys=%(primary_keys)r" + ) % {"tablename": op.table_name, "args": args, "primary_keys": table.primary_keys} for k in sorted(op.kw): text += ",\n" + indent + "%s=%r" % (k.replace(" ", "_"), op.kw[k]) text += ")" @@ -547,9 +453,7 @@ def _add_table(op: CreateTableOp) -> str: @renderers.dispatch_for(DropTableOp) def _drop_table(op: DropTableOp) -> str: - text = "self.drop_table(%(tname)r" % { - "tname": op.table_name - } + text = "self.drop_table(%(tname)r" % {"tname": op.table_name} text += ")" return text @@ -568,7 +472,7 @@ def _render_column(column: Column) -> str: if column.type in ("string", "password", "upload"): opts.append(("length", column.length)) - elif column.type.startswith('reference'): + elif column.type.startswith("reference"): opts.append(("ondelete", column.ondelete)) elif column.type.startswith("geo"): for key in ("geometry_type", "srid", "dimension"): @@ -577,37 +481,24 @@ def _render_column(column: Column) -> str: kw_str = "" if opts: - kw_str = ", %s" % \ - ", ".join(["%s=%r" % (key, val) for key, val in opts]) - return "migrations.Column(%(name)r, %(type)r%(kw)s)" % { - 'name': column.name, - 'type': column.type, - 'kw': kw_str - } + kw_str = ", %s" % ", ".join(["%s=%r" % (key, val) for key, val in opts]) + return "migrations.Column(%(name)r, %(type)r%(kw)s)" % {"name": column.name, "type": column.type, "kw": kw_str} @renderers.dispatch_for(AddColumnOp) def _add_column(op: AddColumnOp) -> str: - return "self.add_column(%(tname)r, %(column)s)" % { - "tname": op.table_name, - "column": _render_column(op.column) - } + return "self.add_column(%(tname)r, %(column)s)" % {"tname": op.table_name, "column": _render_column(op.column)} @renderers.dispatch_for(DropColumnOp) def _drop_column(op: DropTableOp) -> str: - return "self.drop_column(%(tname)r, %(cname)r)" % { - "tname": op.table_name, - "cname": op.column_name - } + return "self.drop_column(%(tname)r, %(cname)r)" % {"tname": op.table_name, "cname": op.column_name} @renderers.dispatch_for(AlterColumnOp) def _alter_column(op: AlterColumnOp) -> str: indent = " " * 12 - text = "self.alter_column(%(tname)r, %(cname)r" % { - 'tname': op.table_name, - 'cname': op.column_name} + text = "self.alter_column(%(tname)r, %(cname)r" % {"tname": op.table_name, "cname": op.column_name} if op.existing_type is not None: text += ",\n%sexisting_type=%r" % (indent, op.existing_type) @@ -637,47 +528,31 @@ def _alter_column(op: AlterColumnOp) -> str: def _add_index(op: CreateIndexOp) -> str: kw_str = "" if op.kw: - kw_str = ", %s" % ", ".join( - ["%s=%r" % (key, val) for key, val in op.kw.items()]) + kw_str = ", %s" % ", ".join(["%s=%r" % (key, val) for key, val in op.kw.items()]) return "self.create_index(%(iname)r, %(tname)r, %(idata)s)" % { "tname": op.table_name, "iname": op.index_name, - "idata": "%r, expressions=%r, unique=%s%s" % ( - op.fields, op.expressions, op.unique, kw_str) + "idata": "%r, expressions=%r, unique=%s%s" % (op.fields, op.expressions, op.unique, kw_str), } @renderers.dispatch_for(DropIndexOp) def _drop_index(op: DropIndexOp) -> str: - return "self.drop_index(%(iname)r, %(tname)r)" % { - "tname": op.table_name, - "iname": op.index_name - } + return "self.drop_index(%(iname)r, %(tname)r)" % {"tname": op.table_name, "iname": op.index_name} @renderers.dispatch_for(CreateForeignKeyConstraintOp) def _add_fk_constraint(op: CreateForeignKeyConstraintOp) -> str: kw_str = "" if op.kw: - kw_str = ", %s" % ", ".join( - ["%s=%r" % (key, val) for key, val in op.kw.items()] - ) + kw_str = ", %s" % ", ".join(["%s=%r" % (key, val) for key, val in op.kw.items()]) return "self.create_foreign_key(%s%s)" % ( - "%r, %r, %r, %r, %r, on_delete=%r" % ( - op.constraint_name, - op.table_name, - op.foreign_table_name, - op.column_names, - op.foreign_keys, - op.on_delete - ), - kw_str + "%r, %r, %r, %r, %r, on_delete=%r" + % (op.constraint_name, op.table_name, op.foreign_table_name, op.column_names, op.foreign_keys, op.on_delete), + kw_str, ) @renderers.dispatch_for(DropForeignKeyConstraintOp) def _drop_fk_constraint(op: DropForeignKeyConstraintOp) -> str: - return "self.drop_foreign_key(%(cname)r, %(tname)r)" % { - "tname": op.table_name, - "cname": op.constraint_name - } + return "self.drop_foreign_key(%(cname)r, %(tname)r)" % {"tname": op.table_name, "cname": op.constraint_name} diff --git a/emmett/orm/migrations/helpers.py b/emmett/orm/migrations/helpers.py index 0f804b7b..fd9ed6a6 100644 --- a/emmett/orm/migrations/helpers.py +++ b/emmett/orm/migrations/helpers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.helpers - ----------------------------- +emmett.orm.migrations.helpers +----------------------------- - Provides helpers for migrations. +Provides helpers for migrations. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -21,6 +21,7 @@ from ...datastructures import _unique_list from .base import Database + if TYPE_CHECKING: from .engine import MetaEngine from .operations import Operation @@ -50,12 +51,12 @@ def __init__(self): self._registry: Dict[Type[Operation], Callable[[Operation], str]] = {} def dispatch_for( - self, - target: Type[Operation] + self, target: Type[Operation] ) -> Callable[[Callable[[Operation], str]], Callable[[Operation], str]]: def wrap(fn: Callable[[Operation], str]) -> Callable[[Operation], str]: self._registry[target] = fn return fn + return wrap def dispatch(self, obj: Operation): @@ -98,11 +99,11 @@ def to_tuple(x, default=None): if x is None: return default elif isinstance(x, str): - return (x, ) + return (x,) elif isinstance(x, Iterable): return tuple(x) else: - return (x, ) + return (x,) def tuple_or_value(val): diff --git a/emmett/orm/migrations/operations.py b/emmett/orm/migrations/operations.py index b2558c6b..965e60b3 100644 --- a/emmett/orm/migrations/operations.py +++ b/emmett/orm/migrations/operations.py @@ -1,30 +1,30 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.operations - -------------------------------- +emmett.orm.migrations.operations +-------------------------------- - Provides operations handlers for migrations. +Provides operations handlers for migrations. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) - :copyright: (c) 2009-2015 by Michael Bayer +Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) +:copyright: (c) 2009-2015 by Michael Bayer - :license: BSD-3-Clause +:license: BSD-3-Clause """ from __future__ import annotations import re - from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from .base import Migration, Column +from .base import Column, Migration from .helpers import DEFAULT_VALUE + if TYPE_CHECKING: from .engine import MetaEngine - from .generation import MetaTable, MetaIndex, MetaForeignKey + from .generation import MetaForeignKey, MetaIndex, MetaTable class Operation: @@ -52,7 +52,7 @@ def as_diffs(self): @classmethod def _ops_as_diffs(cls, migrations): for op in migrations.ops: - if hasattr(op, 'ops'): + if hasattr(op, "ops"): for sub_op in cls._ops_as_diffs(op): yield sub_op else: @@ -66,10 +66,7 @@ def __init__(self, table_name: str, ops: List[Operation]): self.table_name = table_name def reverse(self) -> ModifyTableOps: - return ModifyTableOps( - self.table_name, - ops=list(reversed([op.reverse() for op in self.ops])) - ) + return ModifyTableOps(self.table_name, ops=list(reversed([op.reverse() for op in self.ops]))) class UpgradeOps(OpContainer): @@ -79,9 +76,7 @@ def __init__(self, ops: List[Operation] = [], upgrade_token: str = "upgrades"): self.upgrade_token = upgrade_token def reverse(self) -> DowngradeOps: - return DowngradeOps( - ops=list(reversed([op.reverse() for op in self.ops])) - ) + return DowngradeOps(ops=list(reversed([op.reverse() for op in self.ops]))) class DowngradeOps(OpContainer): @@ -91,9 +86,7 @@ def __init__(self, ops: List[Operation] = [], downgrade_token: str = "downgrades self.downgrade_token = downgrade_token def reverse(self): - return UpgradeOps( - ops=list(reversed([op.reverse() for op in self.ops])) - ) + return UpgradeOps(ops=list(reversed([op.reverse() for op in self.ops]))) class MigrationOp(Operation): @@ -104,7 +97,7 @@ def __init__( downgrade_ops: DowngradeOps, message: Optional[str] = None, head: Optional[str] = None, - splice: Any = None + splice: Any = None, ): self.rev_id = rev_id self.message = message @@ -122,7 +115,7 @@ def __init__( columns: List[Column], primary_keys: List[str] = [], _orig_table: Optional[MetaTable] = None, - **kw: Any + **kw: Any, ): self.table_name = table_name self.columns = columns @@ -139,45 +132,28 @@ def to_diff_tuple(self) -> Tuple[str, MetaTable]: @classmethod def from_table(cls, table: MetaTable) -> CreateTableOp: return cls( - table.name, - [table[colname] for colname in table.fields], - list(table.primary_keys), - _orig_table=table + table.name, [table[colname] for colname in table.fields], list(table.primary_keys), _orig_table=table ) def to_table(self, migration_context: Any = None) -> MetaTable: if self._orig_table is not None: return self._orig_table from .generation import MetaTable - return MetaTable( - self.table_name, - self.columns, - self.primary_keys, - **self.kw - ) + + return MetaTable(self.table_name, self.columns, self.primary_keys, **self.kw) @classmethod - def create_table( - cls, - table_name: str, - *columns: Column, - **kw: Any - ) -> CreateTableOp: + def create_table(cls, table_name: str, *columns: Column, **kw: Any) -> CreateTableOp: return cls(table_name, columns, **kw) def run(self): - self.engine.create_table( - self.table_name, self.columns, self.primary_keys, **self.kw - ) + self.engine.create_table(self.table_name, self.columns, self.primary_keys, **self.kw) @Migration.register_operation("drop_table") class DropTableOp(Operation): def __init__( - self, - table_name: str, - table_kw: Optional[Dict[str, Any]] = None, - _orig_table: Optional[MetaTable] = None + self, table_name: str, table_kw: Optional[Dict[str, Any]] = None, _orig_table: Optional[MetaTable] = None ): self.table_name = table_name self.table_kw = table_kw or {} @@ -188,9 +164,7 @@ def to_diff_tuple(self) -> Tuple[str, MetaTable]: def reverse(self) -> CreateTableOp: if self._orig_table is None: - raise ValueError( - "operation is not reversible; original table is not present" - ) + raise ValueError("operation is not reversible; original table is not present") return CreateTableOp.from_table(self._orig_table) @classmethod @@ -201,10 +175,8 @@ def to_table(self) -> MetaTable: if self._orig_table is not None: return self._orig_table from .generation import MetaTable - return MetaTable( - self.table_name, - **self.table_kw - ) + + return MetaTable(self.table_name, **self.table_kw) @classmethod def drop_table(cls, table_name: str, **kw: Any) -> DropTableOp: @@ -230,7 +202,7 @@ def rename_table(cls, old_table_name: str, new_table_name: str) -> RenameTableOp return cls(old_table_name, new_table_name) def run(self): - raise NotImplementedError('Table renaming is currently not supported.') + raise NotImplementedError("Table renaming is currently not supported.") @Migration.register_operation("add_column") @@ -262,13 +234,7 @@ def run(self): @Migration.register_operation("drop_column") class DropColumnOp(AlterTableOp): - def __init__( - self, - table_name: str, - column_name: str, - _orig_column: Optional[Column] = None, - **kw: Any - ): + def __init__(self, table_name: str, column_name: str, _orig_column: Optional[Column] = None, **kw: Any): super().__init__(table_name) self.column_name = column_name self.kw = kw @@ -279,13 +245,9 @@ def to_diff_tuple(self) -> Tuple[str, str, Column]: def reverse(self) -> AddColumnOp: if self._orig_column is None: - raise ValueError( - "operation is not reversible; original column is not present" - ) + raise ValueError("operation is not reversible; original column is not present") - return AddColumnOp.from_column_and_tablename( - self.table_name, self._orig_column - ) + return AddColumnOp.from_column_and_tablename(self.table_name, self._orig_column) @classmethod def from_column_and_tablename(cls, tname: str, col: Column) -> DropColumnOp: @@ -319,7 +281,7 @@ def __init__( modify_name: Optional[str] = None, modify_type: Optional[str] = None, modify_length: Optional[int] = None, - **kw: Any + **kw: Any, ): super().__init__(table_name) self.column_name = column_name @@ -348,13 +310,10 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: "existing_length": self.existing_length, "existing_notnull": self.existing_notnull, "existing_default": self.existing_default, - **{ - nkey: nval for nkey, nval in self.kw.items() - if nkey.startswith('existing_') - } + **{nkey: nval for nkey, nval in self.kw.items() if nkey.startswith("existing_")}, }, self.existing_type, - self.modify_type + self.modify_type, ) ) @@ -367,10 +326,10 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: { "existing_type": self.existing_type, "existing_notnull": self.existing_notnull, - "existing_default": self.existing_default + "existing_default": self.existing_default, }, self.existing_length, - self.modify_length + self.modify_length, ) ) @@ -380,12 +339,9 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: "modify_notnull", tname, cname, - { - "existing_type": self.existing_type, - "existing_default": self.existing_default - }, + {"existing_type": self.existing_type, "existing_default": self.existing_default}, self.existing_notnull, - self.modify_notnull + self.modify_notnull, ) ) @@ -395,12 +351,9 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: "modify_default", tname, cname, - { - "existing_notnull": self.existing_notnull, - "existing_type": self.existing_type - }, + {"existing_notnull": self.existing_notnull, "existing_type": self.existing_type}, self.existing_default, - self.modify_default + self.modify_default, ) ) @@ -414,13 +367,10 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: cname, { "existing_type": self.existing_type, - **{ - nkey: nval for nkey, nval in self.kw.items() - if nkey.startswith('existing_') - } + **{nkey: nval for nkey, nval in self.kw.items() if nkey.startswith("existing_")}, }, self.kw.get(f"existing_{attr}"), - val + val, ) ) @@ -428,47 +378,42 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: def has_changes(self) -> bool: hc = ( - self.modify_notnull is not None or - self.modify_default is not DEFAULT_VALUE or - self.modify_type is not None or - self.modify_length is not None + self.modify_notnull is not None + or self.modify_default is not DEFAULT_VALUE + or self.modify_type is not None + or self.modify_length is not None ) if hc: return True for kw in self.kw: - if kw.startswith('modify_'): + if kw.startswith("modify_"): return True return False def reverse(self) -> AlterColumnOp: kw = self.kw.copy() - kw['existing_type'] = self.existing_type - kw['existing_length'] = self.existing_length - kw['existing_notnull'] = self.existing_notnull - kw['existing_default'] = self.existing_default + kw["existing_type"] = self.existing_type + kw["existing_length"] = self.existing_length + kw["existing_notnull"] = self.existing_notnull + kw["existing_default"] = self.existing_default if self.modify_type is not None: - kw['modify_type'] = self.modify_type + kw["modify_type"] = self.modify_type if self.modify_length is not None: - kw['modify_length'] = self.modify_length + kw["modify_length"] = self.modify_length if self.modify_notnull is not None: - kw['modify_notnull'] = self.modify_notnull + kw["modify_notnull"] = self.modify_notnull if self.modify_default is not DEFAULT_VALUE: - kw['modify_default'] = self.modify_default + kw["modify_default"] = self.modify_default - all_keys = set(m.group(1) for m in [ - re.match(r'^(?:existing_|modify_)(.+)$', k) - for k in kw - ] if m) + all_keys = {m.group(1) for m in [re.match(r"^(?:existing_|modify_)(.+)$", k) for k in kw] if m} for k in all_keys: - if 'modify_%s' % k in kw: - swap = kw['existing_%s' % k] - kw['existing_%s' % k] = kw['modify_%s' % k] - kw['modify_%s' % k] = swap + if "modify_%s" % k in kw: + swap = kw["existing_%s" % k] + kw["existing_%s" % k] = kw["modify_%s" % k] + kw["modify_%s" % k] = swap - return self.__class__( - self.table_name, self.column_name, **kw - ) + return self.__class__(self.table_name, self.column_name, **kw) @classmethod def alter_column( @@ -484,7 +429,7 @@ def alter_column( existing_length: Optional[int] = None, existing_default: Any = None, existing_notnull: Optional[bool] = None, - **kw: Any + **kw: Any, ) -> AlterColumnOp: return cls( table_name, @@ -498,13 +443,11 @@ def alter_column( modify_length=length, modify_default=default, modify_notnull=notnull, - **kw + **kw, ) def run(self): - self.engine.alter_column( - self.table_name, self.column_name, self.to_diff_tuple() - ) + self.engine.alter_column(self.table_name, self.column_name, self.to_diff_tuple()) @Migration.register_operation("create_index") @@ -517,7 +460,7 @@ def __init__( expressions: List[str] = [], unique: bool = False, _orig_index: Optional[MetaIndex] = None, - **kw: Any + **kw: Any, ): self.index_name = index_name self.table_name = table_name @@ -536,22 +479,15 @@ def to_diff_tuple(self) -> Tuple[str, MetaIndex]: @classmethod def from_index(cls, index: MetaIndex) -> CreateIndexOp: return cls( - index.name, index.table_name, index.fields, index.expressions, - index.unique, _orig_index=index, **index.kw + index.name, index.table_name, index.fields, index.expressions, index.unique, _orig_index=index, **index.kw ) def to_index(self) -> MetaIndex: if self._orig_index is not None: return self._orig_index from .generation import MetaIndex - return MetaIndex( - self.table_name, - self.index_name, - self.fields, - self.expressions, - self.unique, - **self.kw - ) + + return MetaIndex(self.table_name, self.index_name, self.fields, self.expressions, self.unique, **self.kw) @classmethod def create_index( @@ -561,29 +497,19 @@ def create_index( fields: List[str] = [], expressions: List[str] = [], unique: bool = False, - **kw: Any + **kw: Any, ) -> CreateIndexOp: return cls(index_name, table_name, fields, expressions, unique, **kw) def run(self): self.engine.create_index( - self.index_name, - self.table_name, - self.fields, - self.expressions, - self.unique, - **self.kw + self.index_name, self.table_name, self.fields, self.expressions, self.unique, **self.kw ) @Migration.register_operation("drop_index") class DropIndexOp(Operation): - def __init__( - self, - index_name: str, - table_name: Optional[str] = None, - _orig_index: Optional[MetaIndex] = None - ): + def __init__(self, index_name: str, table_name: Optional[str] = None, _orig_index: Optional[MetaIndex] = None): self.index_name = index_name self.table_name = table_name self._orig_index = _orig_index @@ -593,9 +519,7 @@ def to_diff_tuple(self) -> Tuple[str, MetaIndex]: def reverse(self) -> CreateIndexOp: if self._orig_index is None: - raise ValueError( - "operation is not reversible; original index is not present" - ) + raise ValueError("operation is not reversible; original index is not present") return CreateIndexOp.from_index(self._orig_index) @classmethod @@ -606,6 +530,7 @@ def to_index(self) -> MetaIndex: if self._orig_index is not None: return self._orig_index from .generation import MetaIndex + return MetaIndex(self.table_name, self.index_name, [], [], False) @classmethod @@ -627,7 +552,7 @@ def __init__( foreign_keys: List[str], on_delete: str, _orig_fk: Optional[MetaForeignKey] = None, - **kw: Any + **kw: Any, ): super().__init__(table_name) self.constraint_name = name @@ -647,10 +572,7 @@ def to_diff_tuple(self) -> Tuple[str, MetaForeignKey]: return ("create_fk_constraint", self.to_foreign_key()) @classmethod - def from_foreign_key( - cls, - foreign_key: MetaForeignKey - ) -> CreateForeignKeyConstraintOp: + def from_foreign_key(cls, foreign_key: MetaForeignKey) -> CreateForeignKeyConstraintOp: return cls( foreign_key.name, foreign_key.table_name, @@ -658,7 +580,7 @@ def from_foreign_key( foreign_key.column_names, foreign_key.foreign_keys, foreign_key.on_delete, - _orig_fk=foreign_key + _orig_fk=foreign_key, ) def to_foreign_key(self) -> MetaForeignKey: @@ -666,13 +588,14 @@ def to_foreign_key(self) -> MetaForeignKey: return self._orig_fk from .generation import MetaForeignKey + return MetaForeignKey( self.table_name, self.constraint_name, self.column_names, self.foreign_table_name, self.foreign_keys, - self.on_delete + self.on_delete, ) @classmethod @@ -683,7 +606,7 @@ def create_foreign_key( foreign_table_name: str, column_names: List[str], foreign_keys: List[str], - on_delete: str + on_delete: str, ) -> CreateForeignKeyConstraintOp: return cls( name=name, @@ -691,7 +614,7 @@ def create_foreign_key( foreign_table_name=foreign_table_name, column_names=column_names, foreign_keys=foreign_keys, - on_delete=on_delete + on_delete=on_delete, ) def run(self): @@ -701,19 +624,13 @@ def run(self): self.column_names, self.foreign_table_name, self.foreign_keys, - self.on_delete + self.on_delete, ) @Migration.register_operation("drop_foreign_key") class DropForeignKeyConstraintOp(AlterTableOp): - def __init__( - self, - name: str, - table_name: str, - _orig_fk: Optional[MetaForeignKey] = None, - **kw: Any - ): + def __init__(self, name: str, table_name: str, _orig_fk: Optional[MetaForeignKey] = None, **kw: Any): super().__init__(table_name) self.constraint_name = name self.kw = kw @@ -721,38 +638,26 @@ def __init__( def reverse(self) -> CreateForeignKeyConstraintOp: if self._orig_fk is None: - raise ValueError( - "operation is not reversible; original constraint is not present" - ) + raise ValueError("operation is not reversible; original constraint is not present") return CreateForeignKeyConstraintOp.from_foreign_key(self._orig_fk) def to_diff_tuple(self) -> Tuple[str, MetaForeignKey]: return ("drop_fk_constraint", self.to_foreign_key()) @classmethod - def from_foreign_key( - cls, - foreign_key: MetaForeignKey - ) -> DropForeignKeyConstraintOp: - return cls( - foreign_key.name, - foreign_key.table_name, - _orig_fk=foreign_key - ) + def from_foreign_key(cls, foreign_key: MetaForeignKey) -> DropForeignKeyConstraintOp: + return cls(foreign_key.name, foreign_key.table_name, _orig_fk=foreign_key) def to_foreign_key(self): if self._orig_fk is not None: return self._orig_fk from .generation import MetaForeignKey - return MetaForeignKey(self.table_name, self.constraint_name, [], '', [], '') + + return MetaForeignKey(self.table_name, self.constraint_name, [], "", [], "") @classmethod - def drop_foreign_key( - cls, - name: str, - table_name: str - ) -> DropForeignKeyConstraintOp: + def drop_foreign_key(cls, name: str, table_name: str) -> DropForeignKeyConstraintOp: return DropForeignKeyConstraintOp(name, table_name) def run(self): diff --git a/emmett/orm/migrations/revisions.py b/emmett/orm/migrations/revisions.py index 2dac4b36..56fd4edd 100644 --- a/emmett/orm/migrations/revisions.py +++ b/emmett/orm/migrations/revisions.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.revisions - ------------------------------- +emmett.orm.migrations.revisions +------------------------------- - Provides revisions logic for migrations. +Provides revisions logic for migrations. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) - :copyright: (c) 2009-2015 by Michael Bayer +Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) +:copyright: (c) 2009-2015 by Michael Bayer - :license: BSD-3-Clause +:license: BSD-3-Clause """ from collections import deque @@ -18,10 +18,8 @@ from emmett_core.utils import cachedprop from ...datastructures import OrderedSet -from .helpers import to_tuple, tuple_or_value, dedupe_tuple -from .exceptions import ( - RevisionError, RangeNotAncestorError, ResolutionError, MultipleHeads -) +from .exceptions import MultipleHeads, RangeNotAncestorError, ResolutionError, RevisionError +from .helpers import dedupe_tuple, to_tuple, tuple_or_value class Revision(object): @@ -29,31 +27,24 @@ class Revision(object): _all_nextrev = frozenset() revision = None down_revision = None - #dependencies = None - #branch_labels = None + # dependencies = None + # branch_labels = None - def __init__(self, revision, down_revision, - dependencies=None, branch_labels=None): + def __init__(self, revision, down_revision, dependencies=None, branch_labels=None): self.revision = revision self.down_revision = tuple_or_value(down_revision) - #self.dependencies = tuple_or_value(dependencies) - #self._resolved_dependencies = () - #self._orig_branch_labels = to_tuple(branch_labels, default=()) - #self.branch_labels = set(self._orig_branch_labels) + # self.dependencies = tuple_or_value(dependencies) + # self._resolved_dependencies = () + # self._orig_branch_labels = to_tuple(branch_labels, default=()) + # self.branch_labels = set(self._orig_branch_labels) def __repr__(self): - args = [ - repr(self.revision), - repr(self.down_revision) - ] + args = [repr(self.revision), repr(self.down_revision)] # if self.dependencies: # args.append("dependencies=%r" % self.dependencies) # if self.branch_labels: # args.append("branch_labels=%r" % self.branch_labels) - return "%s(%s)" % ( - self.__class__.__name__, - ", ".join(args) - ) + return "%s(%s)" % (self.__class__.__name__, ", ".join(args)) def add_nextrev(self, revision): self._all_nextrev = self._all_nextrev.union([revision.revision]) @@ -63,7 +54,7 @@ def add_nextrev(self, revision): @property def _all_down_revisions(self): return to_tuple(self.down_revision, default=()) - #+ self._resolved_dependencies + # + self._resolved_dependencies @property def _versioned_down_revisions(self): @@ -112,14 +103,11 @@ def _revision_map(self): self.bases = () self._real_bases = () - #has_branch_labels = set() - #has_depends_on = set() + # has_branch_labels = set() + # has_depends_on = set() for revision in self._generator(): - if revision.revision in rmap: - self.app.log.warn( - "Revision %s is present more than once" % revision.revision - ) + self.app.log.warn("Revision %s is present more than once" % revision.revision) rmap[revision.revision] = revision # if revision.branch_labels: # has_branch_labels.add(revision) @@ -128,9 +116,9 @@ def _revision_map(self): heads.add(revision.revision) _real_heads.add(revision.revision) if revision.is_base: - self.bases += (revision.revision, ) - self._real_bases += (revision.revision, ) - #if revision._is_real_base: + self.bases += (revision.revision,) + self._real_bases += (revision.revision,) + # if revision._is_real_base: # self._real_bases += (revision.revision, ) # for revision in has_branch_labels: @@ -142,9 +130,7 @@ def _revision_map(self): for rev in rmap.values(): for downrev in rev._all_down_revisions: if downrev not in rmap: - self.app.log.warn( - "Revision %s referenced from %s is not present" % - (downrev, rev)) + self.app.log.warn("Revision %s referenced from %s is not present" % (downrev, rev)) down_revision = rmap[downrev] down_revision.add_nextrev(rev) if downrev in rev._versioned_down_revisions: @@ -190,14 +176,14 @@ def get_current_head(self): def _resolve_revision_number(self, rid): self._revision_map - if rid == 'heads': + if rid == "heads": return self._real_heads - elif rid == 'head': + elif rid == "head": current_head = self.get_current_head() if current_head: - return (current_head, ) + return (current_head,) return () - elif rid == 'base' or rid is None: + elif rid == "base" or rid is None: return () else: return to_tuple(rid, default=None) @@ -207,19 +193,15 @@ def _revision_for_ident(self, resolved_id): revision = self._revision_map[resolved_id] except KeyError: # do a partial lookup - revs = [x for x in self._revision_map - if x and x.startswith(resolved_id)] + revs = [x for x in self._revision_map if x and x.startswith(resolved_id)] if not revs: - raise ResolutionError( - "No such revision or branch '%s'" % resolved_id, - resolved_id) + raise ResolutionError("No such revision or branch '%s'" % resolved_id, resolved_id) elif len(revs) > 1: raise ResolutionError( "Multiple revisions start " - "with '%s': %s..." % ( - resolved_id, - ", ".join("'%s'" % r for r in revs[0:3]) - ), resolved_id) + "with '%s': %s..." % (resolved_id, ", ".join("'%s'" % r for r in revs[0:3])), + resolved_id, + ) else: revision = self._revision_map[revs[0]] return revision @@ -238,65 +220,53 @@ def get_revisions(self, rid): return sum([self.get_revisions(id_elem) for id_elem in rid], ()) else: resolved_id = self._resolve_revision_number(rid) - return tuple( - self._revision_for_ident(rev_id) for rev_id in resolved_id) + return tuple(self._revision_for_ident(rev_id) for rev_id in resolved_id) def add_revision(self, revision, _replace=False): map_ = self._revision_map if not _replace and revision.revision in map_: - self.app.log.warn( - "Revision %s is present more than once" % revision.revision) + self.app.log.warn("Revision %s is present more than once" % revision.revision) elif _replace and revision.revision not in map_: raise Exception("revision %s not in map" % revision.revision) map_[revision.revision] = revision if revision.is_base: - self.bases += (revision.revision, ) - self._real_bases += (revision.revision, ) + self.bases += (revision.revision,) + self._real_bases += (revision.revision,) # if revision._is_real_base: # self._real_bases += (revision.revision, ) for downrev in revision._all_down_revisions: if downrev not in map_: - self.app.log.warn( - "Revision %s referenced from %s is not present" - % (downrev, revision) - ) + self.app.log.warn("Revision %s referenced from %s is not present" % (downrev, revision)) map_[downrev].add_nextrev(revision) if revision._is_real_head: self._real_heads = tuple( - head for head in self._real_heads - if head not in - set(revision._all_down_revisions).union([revision.revision]) + head + for head in self._real_heads + if head not in set(revision._all_down_revisions).union([revision.revision]) ) + (revision.revision,) if revision.is_head: self.heads = tuple( - head for head in self.heads - if head not in - set(revision._versioned_down_revisions).union( - [revision.revision]) + head + for head in self.heads + if head not in set(revision._versioned_down_revisions).union([revision.revision]) ) + (revision.revision,) - def iterate_revisions(self, upper, lower, implicit_base=False, - inclusive=False): + def iterate_revisions(self, upper, lower, implicit_base=False, inclusive=False): #: iterate through script revisions, starting at the given upper # revision identifier and ending at the lower. - return self._iterate_revisions( - upper, lower, inclusive=inclusive, implicit_base=implicit_base) + return self._iterate_revisions(upper, lower, inclusive=inclusive, implicit_base=implicit_base) def _get_ancestor_nodes(self, targets, map_=None, check=False): fn = lambda rev: rev._versioned_down_revisions - return self._iterate_related_revisions( - fn, targets, map_=map_, check=check - ) + return self._iterate_related_revisions(fn, targets, map_=map_, check=check) def _get_descendant_nodes(self, targets, map_=None, check=False): fn = lambda rev: rev.nextrev - return self._iterate_related_revisions( - fn, targets, map_=map_, check=check - ) + return self._iterate_related_revisions(fn, targets, map_=map_, check=check) def _iterate_related_revisions(self, fn, targets, map_, check=False): if map_ is None: @@ -305,7 +275,6 @@ def _iterate_related_revisions(self, fn, targets, map_, check=False): seen = set() todo = deque() for target in targets: - todo.append(target) if check: per_target = set() @@ -318,16 +287,14 @@ def _iterate_related_revisions(self, fn, targets, map_, check=False): if rev in seen: continue seen.add(rev) - todo.extend( - map_[rev_id] for rev_id in fn(rev)) + todo.extend(map_[rev_id] for rev_id in fn(rev)) yield rev if check and per_target.intersection(targets).difference([target]): raise RevisionError( - "Requested revision %s overlaps with " - "other requested revisions" % target.revision) + "Requested revision %s overlaps with " "other requested revisions" % target.revision + ) - def _iterate_revisions(self, upper, lower, inclusive=True, - implicit_base=False): + def _iterate_revisions(self, upper, lower, inclusive=True, implicit_base=False): #: iterate revisions from upper to lower. requested_lowers = self.get_revisions(lower) uppers = dedupe_tuple(self.get_revisions(upper)) @@ -338,16 +305,10 @@ def _iterate_revisions(self, upper, lower, inclusive=True, upper_ancestors = set(self._get_ancestor_nodes(uppers, check=True)) if implicit_base and requested_lowers: - lower_ancestors = set( - self._get_ancestor_nodes(requested_lowers) - ) - lower_descendants = set( - self._get_descendant_nodes(requested_lowers) - ) + lower_ancestors = set(self._get_ancestor_nodes(requested_lowers)) + lower_descendants = set(self._get_descendant_nodes(requested_lowers)) base_lowers = set() - candidate_lowers = upper_ancestors.\ - difference(lower_ancestors).\ - difference(lower_descendants) + candidate_lowers = upper_ancestors.difference(lower_ancestors).difference(lower_descendants) for rev in candidate_lowers: for downrev in rev._all_down_revisions: if self._revision_map[downrev] in candidate_lowers: @@ -364,10 +325,8 @@ def _iterate_revisions(self, upper, lower, inclusive=True, lowers = requested_lowers # represents all nodes we will produce - total_space = set( - rev.revision for rev in upper_ancestors).intersection( - rev.revision for rev - in self._get_descendant_nodes(lowers, check=True) + total_space = {rev.revision for rev in upper_ancestors}.intersection( + rev.revision for rev in self._get_descendant_nodes(lowers, check=True) ) if not total_space: @@ -375,27 +334,23 @@ def _iterate_revisions(self, upper, lower, inclusive=True, # organize branch points to be consumed separately from # member nodes - branch_todo = set( - rev for rev in - (self._revision_map[rev] for rev in total_space) - if rev._is_real_branch_point and - len(total_space.intersection(rev._all_nextrev)) > 1 - ) + branch_todo = { + rev + for rev in (self._revision_map[rev] for rev in total_space) + if rev._is_real_branch_point and len(total_space.intersection(rev._all_nextrev)) > 1 + } # it's not possible for any "uppers" to be in branch_todo, # because the ._all_nextrev of those nodes is not in total_space - #assert not branch_todo.intersection(uppers) + # assert not branch_todo.intersection(uppers) - todo = deque( - r for r in uppers if r.revision in total_space) + todo = deque(r for r in uppers if r.revision in total_space) # iterate for total_space being emptied out total_space_modified = True while total_space: - if not total_space_modified: - raise RevisionError( - "Dependency resolution failed; iteration can't proceed") + raise RevisionError("Dependency resolution failed; iteration can't proceed") total_space_modified = False # when everything non-branch pending is consumed, # add to the todo any branch nodes that have no @@ -403,13 +358,10 @@ def _iterate_revisions(self, upper, lower, inclusive=True, if not todo: todo.extendleft( sorted( - ( - rev for rev in branch_todo - if not rev._all_nextrev.intersection(total_space) - ), + (rev for rev in branch_todo if not rev._all_nextrev.intersection(total_space)), # favor "revisioned" branch points before # dependent ones - key=lambda rev: 0 if rev.is_branch_point else 1 + key=lambda rev: 0 if rev.is_branch_point else 1, ) ) branch_todo.difference_update(todo) @@ -421,11 +373,13 @@ def _iterate_revisions(self, upper, lower, inclusive=True, # do depth first for elements within branches, # don't consume any actual branch nodes - todo.extendleft([ - self._revision_map[downrev] - for downrev in reversed(rev._all_down_revisions) - if self._revision_map[downrev] not in branch_todo and - downrev in total_space]) + todo.extendleft( + [ + self._revision_map[downrev] + for downrev in reversed(rev._all_down_revisions) + if self._revision_map[downrev] not in branch_todo and downrev in total_space + ] + ) if not inclusive and rev in requested_lowers: continue diff --git a/emmett/orm/migrations/scripts.py b/emmett/orm/migrations/scripts.py index eee7b6fb..452cf90c 100644 --- a/emmett/orm/migrations/scripts.py +++ b/emmett/orm/migrations/scripts.py @@ -1,22 +1,21 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.scripts - ----------------------------- +emmett.orm.migrations.scripts +----------------------------- - Provides scripts interface for migrations. +Provides scripts interface for migrations. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) - :copyright: (c) 2009-2015 by Michael Bayer +Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) +:copyright: (c) 2009-2015 by Michael Bayer - :license: BSD-3-Clause +:license: BSD-3-Clause """ import os import re import sys - from contextlib import contextmanager from datetime import datetime from importlib import resources @@ -26,30 +25,25 @@ from ...html import asis from . import __name__ as __pkg__ from .base import Migration -from .exceptions import ( - RangeNotAncestorError, MultipleHeads, ResolutionError, RevisionError -) -from .helpers import tuple_rev_as_scalar, format_with_comma +from .exceptions import MultipleHeads, RangeNotAncestorError, ResolutionError, RevisionError +from .helpers import format_with_comma, tuple_rev_as_scalar from .revisions import Revision, RevisionsMap class ScriptDir(object): - _slug_re = re.compile(r'\w+') + _slug_re = re.compile(r"\w+") _default_file_template = "%(rev)s_%(slug)s" def __init__(self, app, migrations_folder=None): self.app = app - self.path = os.path.join( - app.root_path, migrations_folder or 'migrations') + self.path = os.path.join(app.root_path, migrations_folder or "migrations") if not os.path.exists(self.path): os.mkdir(self.path) self.cwd = os.path.dirname(__file__) - self.file_template = self.app.config.migrations.file_template or \ - self._default_file_template - self.truncate_slug_length = \ - self.app.config.migrations.filename_len or 40 + self.file_template = self.app.config.migrations.file_template or self._default_file_template + self.truncate_slug_length = self.app.config.migrations.filename_len or 40 self.revision_map = RevisionsMap(self.app, self._load_revisions) - self.templater = Renoir(path=self.cwd, mode='plain') + self.templater = Renoir(path=self.cwd, mode="plain") def _load_revisions(self): sys.path.insert(0, self.path) @@ -60,10 +54,7 @@ def _load_revisions(self): yield script @contextmanager - def _catch_revision_errors( - self, - ancestor=None, multiple_heads=None, start=None, end=None, - resolution=None): + def _catch_revision_errors(self, ancestor=None, multiple_heads=None, start=None, end=None, resolution=None): try: yield except RangeNotAncestorError as rna: @@ -85,25 +76,20 @@ def _catch_revision_errors( "argument '%(head_arg)s'; please " "specify a specific target revision, " "'@%(head_arg)s' to " - "narrow to a specific head, or 'heads' for all heads") - multiple_heads = multiple_heads % { - "head_arg": end or mh.argument, - "heads": str(mh.heads) - } + "narrow to a specific head, or 'heads' for all heads" + ) + multiple_heads = multiple_heads % {"head_arg": end or mh.argument, "heads": str(mh.heads)} raise Exception(multiple_heads) except ResolutionError as re: if resolution is None: - resolution = "Can't locate revision identified by '%s'" % ( - re.argument - ) + resolution = "Can't locate revision identified by '%s'" % (re.argument) raise Exception(resolution) except RevisionError as err: raise Exception(err.args[0]) def walk_revisions(self, base="base", head="heads"): with self._catch_revision_errors(start=base, end=head): - for rev in self.revision_map.iterate_revisions( - head, base, inclusive=True): + for rev in self.revision_map.iterate_revisions(head, base, inclusive=True): yield rev def get_revision(self, revid): @@ -116,42 +102,41 @@ def get_revisions(self, revid): def get_upgrade_revs(self, destination, current_rev): with self._catch_revision_errors( - ancestor="Destination %(end)s is not a valid upgrade " - "target from current head(s)", end=destination): - revs = self.revision_map.iterate_revisions( - destination, current_rev, implicit_base=True) + ancestor="Destination %(end)s is not a valid upgrade " "target from current head(s)", end=destination + ): + revs = self.revision_map.iterate_revisions(destination, current_rev, implicit_base=True) return reversed(list(revs)) def get_downgrade_revs(self, destination, current_rev): with self._catch_revision_errors( - ancestor="Destination %(end)s is not a valid downgrade " - "target from current head(s)", end=destination): - revs = self.revision_map.iterate_revisions( - current_rev, destination) + ancestor="Destination %(end)s is not a valid downgrade " "target from current head(s)", end=destination + ): + revs = self.revision_map.iterate_revisions(current_rev, destination) return list(revs) def _rev_filename(self, revid, message, creation_date): slug = "_".join(self._slug_re.findall(message or "")).lower() if len(slug) > self.truncate_slug_length: - slug = slug[:self.truncate_slug_length].rsplit('_', 1)[0] + '_' + slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_" filename = "%s.py" % ( - self.file_template % { - 'rev': revid, - 'slug': slug, - 'year': creation_date.year, - 'month': creation_date.month, - 'day': creation_date.day, - 'hour': creation_date.hour, - 'minute': creation_date.minute, - 'second': creation_date.second + self.file_template + % { + "rev": revid, + "slug": slug, + "year": creation_date.year, + "month": creation_date.month, + "day": creation_date.day, + "hour": creation_date.hour, + "minute": creation_date.minute, + "second": creation_date.second, } ) return filename def _generate_template(self, filename, ctx): - tmpl_source = resources.read_text(__pkg__, 'migration.tmpl') + tmpl_source = resources.read_text(__pkg__, "migration.tmpl") rendered = self.templater._render(source=tmpl_source, context=ctx) - with open(os.path.join(self.path, filename), 'w') as f: + with open(os.path.join(self.path, filename), "w") as f: f.write(rendered) def generate_revision(self, revid, message, head=None, splice=False, **kw): @@ -171,11 +156,13 @@ def generate_revision(self, revid, message, head=None, splice=False, **kw): if head is None: head = "head" - with self._catch_revision_errors(multiple_heads=( - "Multiple heads are present; please specify the head " - "revision on which the new revision should be based, " - "or perform a merge." - )): + with self._catch_revision_errors( + multiple_heads=( + "Multiple heads are present; please specify the head " + "revision on which the new revision should be based, " + "or perform a merge." + ) + ): heads = self.revision_map.get_revisions(head) if len(set(heads)) != len(heads): @@ -190,11 +177,10 @@ def generate_revision(self, revid, message, head=None, splice=False, **kw): if head is not None and not head.is_head: raise Exception( "Revision %s is not a head revision; please specify " - "--splice to create a new branch from this revision" - % head.revision) + "--splice to create a new branch from this revision" % head.revision + ) - down_migration = tuple( - h.revision if h is not None else None for h in heads) + down_migration = tuple(h.revision if h is not None else None for h in heads) down_migration_var = tuple_rev_as_scalar(down_migration) if isinstance(down_migration_var, str): @@ -202,16 +188,16 @@ def generate_revision(self, revid, message, head=None, splice=False, **kw): else: down_migration_var = str(down_migration_var) - template_ctx = dict( - asis=asis, - up_migration=revid, - down_migration=down_migration_var, - creation_date=creation_date, - down_migration_str=", ".join(r for r in down_migration), - message=message if message is not None else ("empty message"), - upgrades=kw.get('upgrades', ['pass']), - downgrades=kw.get('downgrades', ['pass']) - ) + template_ctx = { + "asis": asis, + "up_migration": revid, + "down_migration": down_migration_var, + "creation_date": creation_date, + "down_migration_str": ", ".join(r for r in down_migration), + "message": message if message is not None else ("empty message"), + "upgrades": kw.get("upgrades", ["pass"]), + "downgrades": kw.get("downgrades", ["pass"]), + } self._generate_template(rev_filename, template_ctx) script = Script._from_filename(self, rev_filename) @@ -220,7 +206,7 @@ def generate_revision(self, revid, message, head=None, splice=False, **kw): class Script(Revision): - _only_source_rev_file = re.compile(r'(?!__init__)(.*\.py)$') + _only_source_rev_file = re.compile(r"(?!__init__)(.*\.py)$") migration_class = None path = None @@ -228,10 +214,7 @@ def __init__(self, module, migration_class, path): self.module = module self.migration_class = migration_class self.path = path - super(Script, self).__init__( - self.migration_class.revision, - self.migration_class.revises - ) + super(Script, self).__init__(self.migration_class.revision, self.migration_class.revises) @property def doc(self): @@ -251,22 +234,16 @@ def log_entry(self): " (mergepoint)" if self.is_merge_point else "", ) if self.is_merge_point: - entry += "Merges: %s\n" % (self._format_down_revision(), ) + entry += "Merges: %s\n" % (self._format_down_revision(),) else: - entry += "Parent: %s\n" % (self._format_down_revision(), ) + entry += "Parent: %s\n" % (self._format_down_revision(),) if self.is_branch_point: - entry += "Branches into: %s\n" % ( - format_with_comma(self.nextrev)) + entry += "Branches into: %s\n" % (format_with_comma(self.nextrev)) entry += "Path: %s\n" % (self.path,) - entry += "\n%s\n" % ( - "\n".join( - " %s" % para - for para in self.longdoc.splitlines() - ) - ) + entry += "\n%s\n" % ("\n".join(" %s" % para for para in self.longdoc.splitlines())) return entry def __str__(self): @@ -276,21 +253,17 @@ def __str__(self): " (head)" if self.is_head else "", " (branchpoint)" if self.is_branch_point else "", " (mergepoint)" if self.is_merge_point else "", - self.doc) + self.doc, + ) - def _head_only( - self, include_doc=False, - include_parents=False, tree_indicators=True, - head_indicators=True): + def _head_only(self, include_doc=False, include_parents=False, tree_indicators=True, head_indicators=True): text = self.revision if include_parents: - text = "%s -> %s" % ( - self._format_down_revision(), text) + text = "%s -> %s" % (self._format_down_revision(), text) if head_indicators or tree_indicators: text += "%s%s" % ( " (head)" if self._is_real_head else "", - " (effective head)" if self.is_head and - not self._is_real_head else "" + " (effective head)" if self.is_head and not self._is_real_head else "", ) if tree_indicators: text += "%s%s" % ( @@ -301,17 +274,11 @@ def _head_only( text += ", %s" % self.doc return text - def cmd_format( - self, - verbose, - include_doc=False, - include_parents=False, tree_indicators=True): + def cmd_format(self, verbose, include_doc=False, include_parents=False, tree_indicators=True): if verbose: return self.log_entry else: - return self._head_only( - include_doc, - include_parents, tree_indicators) + return self._head_only(include_doc, include_parents, tree_indicators) def _format_down_revision(self): if not self.down_revision: @@ -325,14 +292,13 @@ def _from_filename(cls, scriptdir, filename): if not py_match: return None py_filename = py_match.group(1) - py_module = py_filename.split('.py')[0] + py_module = py_filename.split(".py")[0] __import__(py_module) module = sys.modules[py_module] - migration_class = getattr(module, 'Migration', None) + migration_class = getattr(module, "Migration", None) if migration_class is None: for v in module.__dict__.values(): if isinstance(v, Migration): migration_class = v break - return Script( - module, migration_class, os.path.join(scriptdir.path, filename)) + return Script(module, migration_class, os.path.join(scriptdir.path, filename)) diff --git a/emmett/orm/migrations/utils.py b/emmett/orm/migrations/utils.py index 6ae72560..38e04599 100644 --- a/emmett/orm/migrations/utils.py +++ b/emmett/orm/migrations/utils.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.utilities - ------------------------------- +emmett.orm.migrations.utilities +------------------------------- - Provides some migration utilities. +Provides some migration utilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from ..base import Database @@ -22,7 +22,7 @@ def _load_head_to_meta(self): class RuntimeMigration(MigrationOp): def __init__(self, engine: Engine, ops: UpgradeOps): - super().__init__('runtime', ops, ops.reverse(), 'runtime') + super().__init__("runtime", ops, ops.reverse(), "runtime") self.engine = engine for op in self.upgrade_ops.ops: op.engine = self.engine diff --git a/emmett/orm/models.py b/emmett/orm/models.py index 186612e1..b76def18 100644 --- a/emmett/orm/models.py +++ b/emmett/orm/models.py @@ -1,39 +1,23 @@ # -*- coding: utf-8 -*- """ - emmett.orm.models - ----------------- +emmett.orm.models +----------------- - Provides model layer for Emmet's ORM. +Provides model layer for Emmet's ORM. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import operator import types - from collections import OrderedDict from functools import reduce from typing import Any, Callable from ..datastructures import sdict -from .apis import ( - compute, - rowattr, - rowmethod, - scope, - belongs_to, - refers_to, - has_one, - has_many -) -from .errors import ( - InsertFailureOnSave, - SaveException, - UpdateFailureOnSave, - ValidationError, - DestroyException -) +from .apis import belongs_to, compute, has_many, has_one, refers_to, rowattr, rowmethod, scope +from .errors import DestroyException, InsertFailureOnSave, SaveException, UpdateFailureOnSave, ValidationError from .helpers import ( Callback, ReferenceData, @@ -45,23 +29,30 @@ typed_row_reference, typed_row_reference_from_record, wrap_scope_on_model, - wrap_virtual_on_model + wrap_virtual_on_model, ) from .objects import Field, StructuredRow -from .wrappers import HasOneWrap, HasOneViaWrap, HasManyWrap, HasManyViaWrap +from .wrappers import HasManyViaWrap, HasManyWrap, HasOneViaWrap, HasOneWrap class MetaModel(type): _inheritable_dict_attrs_ = [ - 'indexes', 'validation', ('fields_rw', {'id': False}), 'foreign_keys', - 'default_values', 'update_values', 'repr_values', - 'form_labels', 'form_info', 'form_widgets' + "indexes", + "validation", + ("fields_rw", {"id": False}), + "foreign_keys", + "default_values", + "update_values", + "repr_values", + "form_labels", + "form_info", + "form_widgets", ] def __new__(cls, name, bases, attrs): new_class = type.__new__(cls, name, bases, attrs) #: collect declared attributes - tablename = attrs.get('tablename') + tablename = attrs.get("tablename") fields = [] vfields = [] computations = [] @@ -83,8 +74,7 @@ def __new__(cls, name, bases, attrs): elif isinstance(value, scope): declared_scopes[key] = value declared_relations = sdict( - belongs=OrderedDict(), refers=OrderedDict(), - hasone=OrderedDict(), hasmany=OrderedDict() + belongs=OrderedDict(), refers=OrderedDict(), hasone=OrderedDict(), hasmany=OrderedDict() ) for ref in belongs_to._references_.values(): for item in ref.reference: @@ -132,25 +122,22 @@ def __new__(cls, name, bases, attrs): all_computations = OrderedDict() all_callbacks = OrderedDict() all_scopes = {} - all_relations = sdict( - belongs=OrderedDict(), refers=OrderedDict(), - hasone=OrderedDict(), hasmany=OrderedDict() - ) + all_relations = sdict(belongs=OrderedDict(), refers=OrderedDict(), hasone=OrderedDict(), hasmany=OrderedDict()) super_vfields = OrderedDict() for base in reversed(new_class.__mro__[1:]): - if hasattr(base, '_declared_fields_'): + if hasattr(base, "_declared_fields_"): all_fields.update(base._declared_fields_) - if hasattr(base, '_declared_virtuals_'): + if hasattr(base, "_declared_virtuals_"): all_vfields.update(base._declared_virtuals_) super_vfields.update(base._declared_virtuals_) - if hasattr(base, '_declared_computations_'): + if hasattr(base, "_declared_computations_"): all_computations.update(base._declared_computations_) - if hasattr(base, '_declared_callbacks_'): + if hasattr(base, "_declared_callbacks_"): all_callbacks.update(base._declared_callbacks_) - if hasattr(base, '_declared_scopes_'): + if hasattr(base, "_declared_scopes_"): all_scopes.update(base._declared_scopes_) for key in list(all_relations): - attrkey = '_declared_' + key + '_ref_' + attrkey = "_declared_" + key + "_ref_" if hasattr(base, attrkey): all_relations[key].update(getattr(base, attrkey)) #: compose 'all' attributes @@ -237,13 +224,13 @@ def __new__(cls): return super(Model, cls).__new__(cls) def __init__(self): - if not hasattr(self, 'migrate'): - self.migrate = self.config.get('migrate', self.db._migrate) - if not hasattr(self, 'format'): + if not hasattr(self, "migrate"): + self.migrate = self.config.get("migrate", self.db._migrate) + if not hasattr(self, "format"): self.format = None - if not hasattr(self, 'primary_keys'): + if not hasattr(self, "primary_keys"): self.primary_keys = [] - self._fieldset_pk = set(self.primary_keys or ['id']) + self._fieldset_pk = set(self.primary_keys or ["id"]) @property def config(self): @@ -253,7 +240,7 @@ def __parse_relation_via(self, via): if via is None: return via rv = sdict() - splitted = via.split('.') + splitted = via.split(".") rv.via = splitted[0] if len(splitted) > 1: rv.field = splitted[1] @@ -290,30 +277,27 @@ def __build_relation_modelname(self, name, relation, singularize): relation.model = relation.model[:-1] def __build_relation_fieldnames(self, relation): - splitted = relation.model.split('.') + splitted = relation.model.split(".") relation.model = splitted[0] if len(splitted) > 1: relation.fields = [splitted[1]] else: if len(self.primary_keys) > 1: - relation.fields = [ - f"{decamelize(self.__class__.__name__)}_{pk}" - for pk in self.primary_keys - ] + relation.fields = [f"{decamelize(self.__class__.__name__)}_{pk}" for pk in self.primary_keys] else: relation.fields = [decamelize(self.__class__.__name__)] def __parse_relation_dict(self, rel, singularize): - if 'scope' in rel.model: - rel.scope = rel.model['scope'] - if 'where' in rel.model: - rel.where = rel.model['where'] - if 'via' in rel.model: - rel.update(self.__parse_relation_via(rel.model['via'])) + if "scope" in rel.model: + rel.scope = rel.model["scope"] + if "where" in rel.model: + rel.where = rel.model["where"] + if "via" in rel.model: + rel.update(self.__parse_relation_via(rel.model["via"])) del rel.model else: - if 'target' in rel.model: - rel.model = rel.model['target'] + if "target" in rel.model: + rel.model = rel.model["target"] if not isinstance(rel.model, str): self.__build_relation_modelname(rel.name, rel, singularize) @@ -323,19 +307,16 @@ def __parse_many_relation(self, item, singularize=True): rv.name = list(item)[0] rv.model = item[rv.name] if isinstance(rv.model, dict): - if 'method' in rv.model: - if 'field' in rv.model: - rv.fields = [rv.model['field']] + if "method" in rv.model: + if "field" in rv.model: + rv.fields = [rv.model["field"]] else: if len(self.primary_keys) > 1: - rv.fields = [ - f"{decamelize(self.__class__.__name__)}_{pk}" - for pk in self.primary_keys - ] + rv.fields = [f"{decamelize(self.__class__.__name__)}_{pk}" for pk in self.primary_keys] else: rv.fields = [decamelize(self.__class__.__name__)] - rv.cast = rv.model.get('cast') - rv.method = rv.model['method'] + rv.cast = rv.model.get("cast") + rv.method = rv.model["method"] del rv.model else: self.__parse_relation_dict(rv, singularize) @@ -348,18 +329,15 @@ def __parse_many_relation(self, item, singularize=True): if rv.model == "self": rv.model = self.__class__.__name__ if not rv.via: - rv.reverse = ( - rv.fields[0] if len(rv.fields) == 1 else - decamelize(self.__class__.__name__) - ) + rv.reverse = rv.fields[0] if len(rv.fields) == 1 else decamelize(self.__class__.__name__) return rv def _define_props_(self): #: create pydal's Field elements self.fields = [] - if not self.primary_keys and 'id' not in self._all_fields_: - idfield = Field('id')._make_field('id', model=self) - setattr(self.__class__, 'id', idfield) + if not self.primary_keys and "id" not in self._all_fields_: + idfield = Field("id")._make_field("id", model=self) + self.__class__.id = idfield self.fields.append(idfield) for name, obj in self._all_fields_.items(): if obj.modelname is not None: @@ -372,10 +350,7 @@ def __find_matching_fk_definition(self, rfields, lfields, rmodel): if not set(rfields).issubset(set(rmodel.primary_keys)): return match for key, val in self.foreign_keys.items(): - if ( - set(val["foreign_fields"]) == set(rmodel.primary_keys) and - set(lfields).issubset(set(val["fields"])) - ): + if set(val["foreign_fields"]) == set(rmodel.primary_keys) and set(lfields).issubset(set(val["fields"])): match = key break return match @@ -383,13 +358,10 @@ def __find_matching_fk_definition(self, rfields, lfields, rmodel): def _define_relations_(self): self._virtual_relations_ = OrderedDict() self._compound_relations_ = {} - bad_args_error = ( - "belongs_to, has_one and has_many " - "only accept strings or dicts as arguments" - ) + bad_args_error = "belongs_to, has_one and has_many " "only accept strings or dicts as arguments" #: belongs_to and refers_to are mapped with 'reference' type Field _references = [] - _reference_keys = ['_all_belongs_ref_', '_all_refers_ref_'] + _reference_keys = ["_all_belongs_ref_", "_all_refers_ref_"] belongs_references = {} belongs_fks = {} for key in _reference_keys: @@ -397,24 +369,18 @@ def _define_relations_(self): _references.append(list(getattr(self, key).values())) else: _references.append([]) - is_belongs, ondelete = True, 'cascade' + is_belongs, ondelete = True, "cascade" for _references_obj in _references: for item in _references_obj: if not isinstance(item, (str, dict)): raise RuntimeError(bad_args_error) reference = self.__parse_belongs_relation(item, ondelete) reference.is_refers = not is_belongs - refmodel = ( - self.db[reference.model]._model_ - if reference.model != self.__class__.__name__ - else self - ) + refmodel = self.db[reference.model]._model_ if reference.model != self.__class__.__name__ else self ref_multi_pk = len(refmodel._fieldset_pk) > 1 fk_def_key, fks_data, multi_fk = None, {}, [] if ref_multi_pk and reference.fk: - fk_def_key = self.__find_matching_fk_definition( - [reference.fk], [reference.name], refmodel - ) + fk_def_key = self.__find_matching_fk_definition([reference.fk], [reference.name], refmodel) if not fk_def_key: raise SyntaxError( f"{self.__class__.__name__}.{reference.name} relation " @@ -445,10 +411,9 @@ def _define_relations_(self): local_fields=fks_data["fields"], foreign_fields=fks_data["foreign_fields"], coupled_fields=[ - (local, fks_data["foreign_fields"][idx]) - for idx, local in enumerate(fks_data["fields"]) + (local, fks_data["foreign_fields"][idx]) for idx, local in enumerate(fks_data["fields"]) ], - is_refers=reference.is_refers + is_refers=reference.is_refers, ) self._compound_relations_[reference.name] = sdict( model=reference.model, @@ -465,40 +430,31 @@ def _define_relations_(self): local_fields=[reference.name], foreign_fields=[reference.fk], coupled_fields=[(reference.name, reference.fk)], - is_refers=reference.is_refers + is_refers=reference.is_refers, ) if not fk_def_key and fks_data: - self.foreign_keys[reference.name] = self.foreign_keys.get( - reference.name - ) or fks_data + self.foreign_keys[reference.name] = self.foreign_keys.get(reference.name) or fks_data for reference in references: if reference.model != self.__class__.__name__: tablename = self.db[reference.model]._tablename else: tablename = self.tablename fieldobj = Field( - ( - f"reference {tablename}" if not ref_multi_pk else - f"reference {tablename}.{reference.fk}" - ), + (f"reference {tablename}" if not ref_multi_pk else f"reference {tablename}.{reference.fk}"), ondelete=reference.on_delete, - _isrefers=not is_belongs + _isrefers=not is_belongs, ) setattr(self.__class__, reference.name, fieldobj) - self.fields.append( - getattr(self, reference.name)._make_field( - reference.name, self - ) - ) + self.fields.append(getattr(self, reference.name)._make_field(reference.name, self)) belongs_references[reference.name] = reference is_belongs = False - ondelete = 'nullify' - setattr(self.__class__, '_belongs_ref_', belongs_references) - setattr(self.__class__, '_belongs_fks_', belongs_fks) + ondelete = "nullify" + self.__class__._belongs_ref_ = belongs_references + self.__class__._belongs_fks_ = belongs_fks #: has_one are mapped with rowattr hasone_references = {} - if hasattr(self, '_all_hasone_ref_'): - for item in getattr(self, '_all_hasone_ref_').values(): + if hasattr(self, "_all_hasone_ref_"): + for item in getattr(self, "_all_hasone_ref_").values(): if not isinstance(item, (str, dict)): raise RuntimeError(bad_args_error) reference = self.__parse_many_relation(item, False) @@ -508,15 +464,13 @@ def _define_relations_(self): else: #: maps has_one('thing'), has_one({'thing': 'othername'}) wrapper = HasOneWrap - self._virtual_relations_[reference.name] = rowattr( - reference.name - )(wrapper(reference)) + self._virtual_relations_[reference.name] = rowattr(reference.name)(wrapper(reference)) hasone_references[reference.name] = reference - setattr(self.__class__, '_hasone_ref_', hasone_references) + self.__class__._hasone_ref_ = hasone_references #: has_many are mapped with rowattr hasmany_references = {} - if hasattr(self, '_all_hasmany_ref_'): - for item in getattr(self, '_all_hasmany_ref_').values(): + if hasattr(self, "_all_hasmany_ref_"): + for item in getattr(self, "_all_hasmany_ref_").values(): if not isinstance(item, (str, dict)): raise RuntimeError(bad_args_error) reference = self.__parse_many_relation(item) @@ -526,11 +480,9 @@ def _define_relations_(self): else: #: maps has_many('things'), has_many({'things': 'othername'}) wrapper = HasManyWrap - self._virtual_relations_[reference.name] = rowattr( - reference.name - )(wrapper(reference)) + self._virtual_relations_[reference.name] = rowattr(reference.name)(wrapper(reference)) hasmany_references[reference.name] = reference - setattr(self.__class__, '_hasmany_ref_', hasmany_references) + self.__class__._hasmany_ref_ = hasmany_references self.__define_fks() def __define_fks(self): @@ -538,15 +490,8 @@ def __define_fks(self): implicit_defs = {} grouped_rels = {} for rname, rel in self._belongs_ref_.items(): - rmodel = ( - self.db[rel.model]._model_ - if rel.model != self.__class__.__name__ - else self - ) - if ( - not rmodel.primary_keys and - getattr(rmodel, list(rmodel._fieldset_pk)[0]).type == 'id' - ): + rmodel = self.db[rel.model]._model_ if rel.model != self.__class__.__name__ else self + if not rmodel.primary_keys and getattr(rmodel, list(rmodel._fieldset_pk)[0]).type == "id": continue if len(rmodel._fieldset_pk) > 1: match = self.__find_matching_fk_definition([rel.fk], [rel.name], rmodel) @@ -557,46 +502,43 @@ def __define_fks(self): "needs to be defined into `foreign_keys`." ) crels = grouped_rels[match] = grouped_rels.get( - match, { - 'rels': {}, - 'table': rmodel.tablename, - 'on_delete': self.foreign_keys[match].get( - "on_delete", "cascade" - ) - } + match, + { + "rels": {}, + "table": rmodel.tablename, + "on_delete": self.foreign_keys[match].get("on_delete", "cascade"), + }, ) - crels['rels'][rname] = rel + crels["rels"][rname] = rel else: # NOTE: we need this since pyDAL doesn't support id/refs types != int implicit_defs[rname] = { - 'table': rmodel.tablename, - 'fields_local': [rname], - 'fields_foreign': [rel.fk], - 'on_delete': Field._internal_delete[rel.on_delete] + "table": rmodel.tablename, + "fields_local": [rname], + "fields_foreign": [rel.fk], + "on_delete": Field._internal_delete[rel.on_delete], } - for rname, rel in implicit_defs.items(): - constraint_name = self.__create_fk_contraint_name( - rel['table'], *rel['fields_local'] - ) + for _, rel in implicit_defs.items(): + constraint_name = self.__create_fk_contraint_name(rel["table"], *rel["fields_local"]) self._foreign_keys_[constraint_name] = {**rel} for crels in grouped_rels.values(): constraint_name = self.__create_fk_contraint_name( - crels['table'], *[rel.name for rel in crels['rels'].values()] + crels["table"], *[rel.name for rel in crels["rels"].values()] ) self._foreign_keys_[constraint_name] = { - 'table': crels['table'], - 'fields_local': [rel.name for rel in crels['rels'].values()], - 'fields_foreign': [rel.fk for rel in crels['rels'].values()], - 'on_delete': Field._internal_delete[crels['on_delete']] + "table": crels["table"], + "fields_local": [rel.name for rel in crels["rels"].values()], + "fields_foreign": [rel.fk for rel in crels["rels"].values()], + "on_delete": Field._internal_delete[crels["on_delete"]], } def _define_virtuals_(self): self._all_rowattrs_ = {} self._all_rowmethods_ = {} self._super_rowmethods_ = {} - err = 'rowattr or rowmethod cannot have the name of an existent field!' + err = "rowattr or rowmethod cannot have the name of an existent field!" field_names = [field.name for field in self.fields] - for attr in ['_virtual_relations_', '_all_virtuals_']: + for attr in ["_virtual_relations_", "_all_virtuals_"]: for obj in getattr(self, attr, {}).values(): if obj.field_name in field_names: raise RuntimeError(err) @@ -616,47 +558,36 @@ def _define_virtuals_(self): def _set_row_persistence_id(self, row, ret): row.id = ret.id - object.__setattr__(row, '_concrete', True) + object.__setattr__(row, "_concrete", True) def _set_row_persistence_pk(self, row, ret): row[self.primary_keys[0]] = ret[self.primary_keys[0]] - object.__setattr__(row, '_concrete', True) + object.__setattr__(row, "_concrete", True) def _set_row_persistence_pks(self, row, ret): for field_name in self.primary_keys: row[field_name] = ret[field_name] - object.__setattr__(row, '_concrete', True) + object.__setattr__(row, "_concrete", True) def _unset_row_persistence(self, row): for field_name in self._fieldset_pk: row[field_name] = None - object.__setattr__(row, '_concrete', False) + object.__setattr__(row, "_concrete", False) def _build_rowclass_(self): #: build helpers for rows save_excluded_fields = ( - set( - field.name for field in self.fields if - getattr(field, "type", None) == "id" - ) | - set(self._all_rowattrs_.keys()) | - set(self._all_rowmethods_.keys()) + {field.name for field in self.fields if getattr(field, "type", None) == "id"} + | set(self._all_rowattrs_.keys()) + | set(self._all_rowmethods_.keys()) ) - self._fieldset_initable = set([ - field.name for field in self.fields - ]) - save_excluded_fields - self._fieldset_editable = set([ - field.name for field in self.fields - ]) - save_excluded_fields - self._fieldset_pk + self._fieldset_initable = {field.name for field in self.fields} - save_excluded_fields + self._fieldset_editable = {field.name for field in self.fields} - save_excluded_fields - self._fieldset_pk self._fieldset_all = self._fieldset_initable | self._fieldset_pk - self._fieldset_update = set([ - field.name for field in self.fields - if getattr(field, "update", None) is not None - ]) & self._fieldset_editable - self._relations_wrapset = ( - set(self._belongs_fks_.keys()) - - set(self._compound_relations_.keys()) - ) + self._fieldset_update = { + field.name for field in self.fields if getattr(field, "update", None) is not None + } & self._fieldset_editable + self._relations_wrapset = set(self._belongs_fks_.keys()) - set(self._compound_relations_.keys()) if not self.primary_keys: self._set_row_persistence = self._set_row_persistence_id elif len(self.primary_keys) == 1: @@ -665,22 +596,12 @@ def _build_rowclass_(self): self._set_row_persistence = self._set_row_persistence_pks #: create dynamic row class clsname = self.__class__.__name__ + "Row" - attrs = {'_model': self} + attrs = {"_model": self} attrs.update({k: RowFieldMapper(k) for k in self._fieldset_all}) - attrs.update({ - k: RowVirtualMapper(k, v) - for k, v in self._all_rowattrs_.items() - }) + attrs.update({k: RowVirtualMapper(k, v) for k, v in self._all_rowattrs_.items()}) attrs.update(self._all_rowmethods_) - attrs.update({ - k: RowRelationMapper(self.db, self._belongs_ref_[k]) - for k in self._relations_wrapset - }) - attrs.update({ - k: RowCompoundRelationMapper( - self.db, data - ) for k, data in self._compound_relations_.items() - }) + attrs.update({k: RowRelationMapper(self.db, self._belongs_ref_[k]) for k in self._relations_wrapset}) + attrs.update({k: RowCompoundRelationMapper(self.db, data) for k, data in self._compound_relations_.items()}) self._rowclass_ = type(clsname, (StructuredRow,), attrs) globals()[clsname] = self._rowclass_ @@ -714,7 +635,7 @@ def __define_validation(self): def __define_access(self): for field, value in self.fields_rw.items(): - if field == 'id' and field not in self.table: + if field == "id" and field not in self.table: continue if isinstance(value, (tuple, list)): readable, writable = value @@ -737,14 +658,12 @@ def __define_representation(self): self.table[field].represent = value def __define_computations(self): - err = 'computations should have the name of an existing field to compute!' + err = "computations should have the name of an existing field to compute!" field_names = [field.name for field in self.fields] for obj in self._all_computations_.values(): if obj.field_name not in field_names: raise RuntimeError(err) - self.table[obj.field_name].compute = ( - lambda row, obj=obj, self=self: obj.compute(self, row) - ) + self.table[obj.field_name].compute = lambda row, obj=obj, self=self: obj.compute(self, row) def __define_callbacks(self): for obj in self._all_callbacks_.values(): @@ -766,72 +685,63 @@ def __define_callbacks(self): "_after_commit_update", "_after_commit_delete", "_after_commit_save", - "_after_commit_destroy" + "_after_commit_destroy", ]: - getattr(self.table, t).append( - lambda a, obj=obj, self=self: obj.f(self, a) - ) + getattr(self.table, t).append(lambda a, obj=obj, self=self: obj.f(self, a)) else: - getattr(self.table, t).append( - lambda a, b, obj=obj, self=self: obj.f(self, a, b) - ) + getattr(self.table, t).append(lambda a, b, obj=obj, self=self: obj.f(self, a, b)) def __define_scopes(self): self._scopes_ = {} for obj in self._all_scopes_.values(): self._scopes_[obj.name] = obj if not hasattr(self.__class__, obj.name): - setattr( - self.__class__, obj.name, - classmethod(wrap_scope_on_model(obj.f)) - ) + setattr(self.__class__, obj.name, classmethod(wrap_scope_on_model(obj.f))) def __prepend_table_name(self, name, ns): - return '%s_%s__%s' % (self.tablename, ns, name) + return "%s_%s__%s" % (self.tablename, ns, name) def __create_index_name(self, *values): components = [] for value in values: - components.append(value.replace('_', '')) - return self.__prepend_table_name("_".join(components), 'widx') + components.append(value.replace("_", "")) + return self.__prepend_table_name("_".join(components), "widx") def __create_fk_contraint_name(self, *values): components = [] for value in values: - components.append(value.replace('_', '')) - return self.__prepend_table_name("fk__" + "_".join(components), 'ecnt') + components.append(value.replace("_", "")) + return self.__prepend_table_name("fk__" + "_".join(components), "ecnt") def __parse_index_dict(self, value): rv = {} - fields = value.get('fields') or [] + fields = value.get("fields") or [] if not isinstance(fields, (list, tuple)): fields = [fields] - rv['fields'] = fields + rv["fields"] = fields where_query = None - where_cond = value.get('where') + where_cond = value.get("where") if callable(where_cond): where_query = where_cond(self.__class__) if where_query: - rv['where'] = where_query + rv["where"] = where_query expressions = [] - expressions_cond = value.get('expressions') + expressions_cond = value.get("expressions") if callable(expressions_cond): expressions = expressions_cond(self.__class__) if not isinstance(expressions, (tuple, list)): expressions = [expressions] - rv['expressions'] = expressions - rv['unique'] = value.get('unique', False) + rv["expressions"] = expressions + rv["unique"] = value.get("unique", False) return rv def __define_indexes(self): self._indexes_ = {} #: auto-define indexes based on fields for field in self.fields: - if getattr(field, 'unique', False): - idx_name = self.__prepend_table_name(f'{field.name}_unique', 'widx') - idx_dict = self.__parse_index_dict( - {'fields': [field.name], 'unique': True} - ) + if getattr(field, "unique", False): + idx_name = self.__prepend_table_name(f"{field.name}_unique", "widx") + idx_dict = self.__parse_index_dict({"fields": [field.name], "unique": True}) self._indexes_[idx_name] = idx_dict #: parse user-defined fields for key, value in self.indexes.items(): @@ -841,14 +751,14 @@ def __define_indexes(self): if not isinstance(key, tuple): key = [key] if any(field not in self.table for field in key): - raise SyntaxError(f'Invalid field specified in indexes: {key}') + raise SyntaxError(f"Invalid field specified in indexes: {key}") idx_name = self.__create_index_name(*key) - idx_dict = {'fields': key, 'expressions': [], 'unique': False} + idx_dict = {"fields": key, "expressions": [], "unique": False} elif isinstance(value, dict): - idx_name = self.__prepend_table_name(key, 'widx') + idx_name = self.__prepend_table_name(key, "widx") idx_dict = self.__parse_index_dict(value) else: - raise SyntaxError('Values in indexes dict should be booleans or dicts') + raise SyntaxError("Values in indexes dict should be booleans or dicts") self._indexes_[idx_name] = idx_dict def _row_record_query_id(self, row): @@ -858,35 +768,30 @@ def _row_record_query_pk(self, row): return self.table[self.primary_keys[0]] == row[self.primary_keys[0]] def _row_record_query_pks(self, row): - return reduce( - operator.and_, [self.table[pk] == row[pk] for pk in self.primary_keys] - ) + return reduce(operator.and_, [self.table[pk] == row[pk] for pk in self.primary_keys]) def __define_query_helpers(self): if not self.primary_keys: - self._query_id = self.table.id != None # noqa + self._query_id = self.table.id != None # noqa self._query_row = self._row_record_query_id self._order_by_id_asc = self.table.id self._order_by_id_desc = ~self.table.id elif len(self.primary_keys) == 1: - self._query_id = self.table[self.primary_keys[0]] != None # noqa + self._query_id = self.table[self.primary_keys[0]] != None # noqa self._query_row = self._row_record_query_pk self._order_by_id_asc = self.table[self.primary_keys[0]] self._order_by_id_desc = ~self.table[self.primary_keys[0]] else: self._query_id = reduce( - operator.and_, [ - self.table[key] != None # noqa + operator.and_, + [ + self.table[key] != None # noqa for key in self.primary_keys - ] + ], ) self._query_row = self._row_record_query_pks - self._order_by_id_asc = reduce( - operator.or_, [self.table[key] for key in self.primary_keys] - ) - self._order_by_id_desc = reduce( - operator.or_, [~self.table[key] for key in self.primary_keys] - ) + self._order_by_id_asc = reduce(operator.or_, [self.table[key] for key in self.primary_keys]) + self._order_by_id_desc = reduce(operator.or_, [~self.table[key] for key in self.primary_keys]) def __define_form_utils(self): #: labels @@ -924,27 +829,26 @@ def new(cls, **attributes): if callable(val): val = val() rowattrs[field] = val - for field in (inst.primary_keys or ["id"]): + for field in inst.primary_keys or ["id"]: if inst.table[field].type == "id": rowattrs[field] = None for field in set(inst._compound_relations_.keys()) & attrset: reldata = inst._compound_relations_[field] for local_field, foreign_field in reldata.coupled_fields: rowattrs[local_field] = attributes[field][foreign_field] - rv = inst._rowclass_( - rowattrs, __concrete=False, - **{k: attributes[k] for k in attrset - set(rowattrs)} - ) - rv._fields.update({ - field: attributes[field] if not attributes[field] else ( - typed_row_reference_from_record( - attributes[field], inst.db[inst._belongs_fks_[field].model]._model_ - ) if isinstance(attributes[field], StructuredRow) else - typed_row_reference( - attributes[field], inst.db[inst._belongs_fks_[field].model] + rv = inst._rowclass_(rowattrs, __concrete=False, **{k: attributes[k] for k in attrset - set(rowattrs)}) + rv._fields.update( + { + field: attributes[field] + if not attributes[field] + else ( + typed_row_reference_from_record(attributes[field], inst.db[inst._belongs_fks_[field].model]._model_) + if isinstance(attributes[field], StructuredRow) + else typed_row_reference(attributes[field], inst.db[inst._belongs_fks_[field].model]) ) - ) for field in inst._relations_wrapset & attrset - }) + for field in inst._relations_wrapset & attrset + } + ) return rv @classmethod @@ -965,7 +869,7 @@ def validate(cls, row, write_values: bool = False): inst, errors = cls._instance_(), sdict() for field_name in inst._fieldset_all: field = inst.table[field_name] - default = getattr(field, 'default') + default = getattr(field, "default") if callable(default): default = default() value = row.get(field_name, default) @@ -988,17 +892,11 @@ def all(cls): @classmethod def first(cls): - return cls.all().select( - orderby=cls._instance_()._order_by_id_asc, - limitby=(0, 1) - ).first() + return cls.all().select(orderby=cls._instance_()._order_by_id_asc, limitby=(0, 1)).first() @classmethod def last(cls): - return cls.all().select( - orderby=cls._instance_()._order_by_id_desc, - limitby=(0, 1) - ).first() + return cls.all().select(orderby=cls._instance_()._order_by_id_desc, limitby=(0, 1)).first() @classmethod def get(cls, *args, **kwargs): @@ -1010,40 +908,32 @@ def get(cls, *args, **kwargs): elif isinstance(args[0], dict) and not kwargs: return cls.table(**args[0]) if len(args) != len(inst._fieldset_pk): - raise SyntaxError( - f"{cls.__name__}.get requires the same number of arguments " - "as its primary keys" - ) + raise SyntaxError(f"{cls.__name__}.get requires the same number of arguments " "as its primary keys") pks = inst.primary_keys or ["id"] - return cls.table( - **{pks[idx]: val for idx, val in enumerate(args)} - ) + return cls.table(**{pks[idx]: val for idx, val in enumerate(args)}) return cls.table(**kwargs) - @rowmethod('update_record') + @rowmethod("update_record") def _update_record(self, row, skip_callbacks=False, **fields): newfields = fields or dict(row) for field_name in set(newfields.keys()) - self._fieldset_editable: del newfields[field_name] - res = self.db( - self._query_row(row), ignore_common_filters=True - ).update(skip_callbacks=skip_callbacks, **newfields) + res = self.db(self._query_row(row), ignore_common_filters=True).update( + skip_callbacks=skip_callbacks, **newfields + ) if res: row.update(self.get(**{key: row[key] for key in self._fieldset_pk})) return row - @rowmethod('delete_record') + @rowmethod("delete_record") def _delete_record(self, row, skip_callbacks=False): return self.db(self._query_row(row)).delete(skip_callbacks=skip_callbacks) - @rowmethod('refresh') + @rowmethod("refresh") def _row_refresh(self, row) -> bool: if not row._concrete: return False - last = self.db(self._query_row(row)).select( - limitby=(0, 1), - orderby_on_limitby=False - ).first() + last = self.db(self._query_row(row)).select(limitby=(0, 1), orderby_on_limitby=False).first() if not last: return False row._fields.update(last._fields) @@ -1051,19 +941,12 @@ def _row_refresh(self, row) -> bool: row._changes.clear() return True - @rowmethod('save') - def _row_save( - self, - row, - raise_on_error: bool = False, - skip_callbacks: bool = False - ) -> bool: + @rowmethod("save") + def _row_save(self, row, raise_on_error: bool = False, skip_callbacks: bool = False) -> bool: if row._concrete: if set(row._changes.keys()) & self._fieldset_pk: if raise_on_error: - raise SaveException( - 'Cannot save a record with altered primary key(s)' - ) + raise SaveException("Cannot save a record with altered primary key(s)") return False for field_name in self._fieldset_update: val = self.table[field_name].update @@ -1076,9 +959,9 @@ def _row_save( raise ValidationError return False if row._concrete: - res = self.db( - self._query_row(row), ignore_common_filters=True - )._update_from_save(self, row, skip_callbacks=skip_callbacks) + res = self.db(self._query_row(row), ignore_common_filters=True)._update_from_save( + self, row, skip_callbacks=skip_callbacks + ) if not res: if raise_on_error: raise UpdateFailureOnSave @@ -1089,26 +972,18 @@ def _row_save( if raise_on_error: raise InsertFailureOnSave return False - extra_changes = { - key: row._changes[key] - for key in set(row._changes.keys()) & set(row.__dict__.keys()) - } + extra_changes = {key: row._changes[key] for key in set(row._changes.keys()) & set(row.__dict__.keys())} row._changes.clear() row._changes.update(extra_changes) return True - @rowmethod('destroy') - def _row_destroy( - self, - row, - raise_on_error: bool = False, - skip_callbacks: bool = False - ) -> bool: + @rowmethod("destroy") + def _row_destroy(self, row, raise_on_error: bool = False, skip_callbacks: bool = False) -> bool: if not row._concrete: return False - res = self.db( - self._query_row(row), ignore_common_filters=True - )._delete_from_destroy(self, row, skip_callbacks=skip_callbacks) + res = self.db(self._query_row(row), ignore_common_filters=True)._delete_from_destroy( + self, row, skip_callbacks=skip_callbacks + ) if not res: if raise_on_error: raise DestroyException @@ -1192,9 +1067,9 @@ def __get__(self, obj, objtype=None): pks = {fk: obj[lk] for lk, fk in self.fields} key = (self.field, *pks.values()) if key not in obj._compound_rels: - obj._compound_rels[key] = RowReferenceMulti(pks, self.table) if all( - v is not None for v in pks.values() - ) else None + obj._compound_rels[key] = ( + RowReferenceMulti(pks, self.table) if all(v is not None for v in pks.values()) else None + ) return obj._compound_rels[key] def __delete__(self, obj): diff --git a/emmett/orm/objects.py b/emmett/orm/objects.py index b85f3f28..3dcba987 100644 --- a/emmett/orm/objects.py +++ b/emmett/orm/objects.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.objects - ------------------ +emmett.orm.objects +------------------ - Provides pyDAL objects implementation for Emmett. +Provides pyDAL objects implementation for Emmett. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import copy @@ -14,23 +14,22 @@ import decimal import operator import types - from collections import OrderedDict, defaultdict from enum import Enum from functools import reduce from typing import Any, Dict, Optional +from emmett_core.utils import cachedprop from pydal.objects import ( - Table as _Table, + Expression, Field as _Field, - Set as _Set, - Row as _Row, - Rows as _Rows, IterRows as _IterRows, Query, - Expression + Row as _Row, + Rows as _Rows, + Set as _Set, + Table as _Table, ) -from emmett_core.utils import cachedprop from ..ctx import current from ..datastructures import sdict @@ -38,23 +37,21 @@ from ..serializers import xml_encode from ..validators import ValidateFromDict from .helpers import ( - RelationBuilder, GeoFieldWrapper, + RelationBuilder, RowReferenceMixin, + typed_row_reference_from_record, wrap_scope_on_set, - typed_row_reference_from_record ) + type_int = int class Table(_Table): def __init__(self, db, tablename, *fields, **kwargs): - _primary_keys, _notnulls = list(kwargs.get('primarykey', [])), {} - _notnulls = { - field.name: field.notnull - for field in fields if hasattr(field, 'notnull') - } + _primary_keys, _notnulls = list(kwargs.get("primarykey", [])), {} + _notnulls = {field.name: field.notnull for field in fields if hasattr(field, "notnull")} super(Table, self).__init__(db, tablename, *fields, **kwargs) self._before_save = [] self._after_save = [] @@ -75,57 +72,32 @@ def __init__(self, db, tablename, *fields, **kwargs): self._unique_fields_validation_ = {} self._primary_keys = _primary_keys #: avoid pyDAL mess in ops and migrations - if len(self._primary_keys) == 1 and getattr(self, '_primarykey', None): + if len(self._primary_keys) == 1 and getattr(self, "_primarykey", None): del self._primarykey for key in self._primary_keys: self[key].notnull = _notnulls[key] - if not hasattr(self, '_id'): + if not hasattr(self, "_id"): self._id = None @cachedprop def _has_commit_insert_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_insert, - self._after_commit_insert - ]) + return any([self._before_commit, self._after_commit, self._before_commit_insert, self._after_commit_insert]) @cachedprop def _has_commit_update_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_update, - self._after_commit_update - ]) + return any([self._before_commit, self._after_commit, self._before_commit_update, self._after_commit_update]) @cachedprop def _has_commit_delete_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_delete, - self._after_commit_delete - ]) + return any([self._before_commit, self._after_commit, self._before_commit_delete, self._after_commit_delete]) @cachedprop def _has_commit_save_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_save, - self._after_commit_save - ]) + return any([self._before_commit, self._after_commit, self._before_commit_save, self._after_commit_save]) @cachedprop def _has_commit_destroy_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_destroy, - self._after_commit_destroy - ]) + return any([self._before_commit, self._after_commit, self._before_commit_destroy, self._after_commit_destroy]) def _create_references(self): self._referenced_by = [] @@ -145,14 +117,7 @@ def insert(self, skip_callbacks=False, **fields): if self._has_commit_insert_callbacks: txn = self._db._adapter.top_transaction() if txn: - txn._add_op(TransactionOp( - TransactionOps.insert, - self, - TransactionOpContext( - values=row, - ret=ret - ) - )) + txn._add_op(TransactionOp(TransactionOps.insert, self, TransactionOpContext(values=row, ret=ret))) if ret: for f in self._after_insert: f(row, ret) @@ -177,16 +142,13 @@ def _insert_from_save(self, row, skip_callbacks=False): if self._has_commit_save_callbacks: txn = self._db._adapter.top_transaction() if txn: - txn._add_op(TransactionOp( - TransactionOps.save, - self, - TransactionOpContext( - values=fields, - ret=ret, - row=row.clone_changed(), - changes=row.changes + txn._add_op( + TransactionOp( + TransactionOps.save, + self, + TransactionOpContext(values=fields, ret=ret, row=row.clone_changed(), changes=row.changes), ) - )) + ) if row._concrete: for f in self._after_save: f(row) @@ -194,58 +156,48 @@ def _insert_from_save(self, row, skip_callbacks=False): class Field(_Field): - _internal_types = { - 'integer': 'int', - 'double': 'float', - 'boolean': 'bool', - 'list:integer': 'list:int' - } + _internal_types = {"integer": "int", "double": "float", "boolean": "bool", "list:integer": "list:int"} _pydal_types = { - 'int': 'integer', - 'bool': 'boolean', - 'list:int': 'list:integer', - } - _internal_delete = { - 'cascade': 'CASCADE', 'nullify': 'SET NULL', 'nothing': 'NO ACTION' + "int": "integer", + "bool": "boolean", + "list:int": "list:integer", } + _internal_delete = {"cascade": "CASCADE", "nullify": "SET NULL", "nothing": "NO ACTION"} _inst_count_ = 0 _obj_created_ = False - def __init__(self, type='string', *args, **kwargs): + def __init__(self, type="string", *args, **kwargs): self.modelname = None #: convert type self._type = self._internal_types.get(type, type) #: convert 'rw' -> 'readable', 'writeable' - if 'rw' in kwargs: - _rw = kwargs.pop('rw') + if "rw" in kwargs: + _rw = kwargs.pop("rw") if isinstance(_rw, (tuple, list)): read, write = _rw else: read = write = _rw - kwargs['readable'] = read - kwargs['writable'] = write + kwargs["readable"] = read + kwargs["writable"] = write #: convert 'info' -> 'comment' - _info = kwargs.pop('info', None) + _info = kwargs.pop("info", None) if _info: - kwargs['comment'] = _info + kwargs["comment"] = _info #: convert ondelete parameter - _ondelete = kwargs.get('ondelete') + _ondelete = kwargs.get("ondelete") if _ondelete: if _ondelete not in list(self._internal_delete): - raise SyntaxError( - 'Field ondelete should be set on %s, %s or %s' % - list(self._internal_delete) - ) - kwargs['ondelete'] = self._internal_delete[_ondelete] + raise SyntaxError("Field ondelete should be set on %s, %s or %s" % list(self._internal_delete)) + kwargs["ondelete"] = self._internal_delete[_ondelete] #: process 'refers_to' fields - self._isrefers = kwargs.pop('_isrefers', None) + self._isrefers = kwargs.pop("_isrefers", None) #: get auto validation preferences - self._auto_validation = kwargs.pop('auto_validation', True) + self._auto_validation = kwargs.pop("auto_validation", True) #: intercept validation (will be processed by `_make_field`) self._requires = {} self._custom_requires = [] - if 'validation' in kwargs: - _validation = kwargs.pop('validation') + if "validation" in kwargs: + _validation = kwargs.pop("validation") if isinstance(_validation, dict): self._requires = _validation else: @@ -255,10 +207,10 @@ def __init__(self, type='string', *args, **kwargs): self._validation = {} self._vparser = ValidateFromDict() #: ensure 'length' is an integer - if 'length' in kwargs: - kwargs['length'] = int(kwargs['length']) + if "length" in kwargs: + kwargs["length"] = int(kwargs["length"]) #: store args and kwargs for `_make_field` - self._ormkw = kwargs.pop('_kw', {}) + self._ormkw = kwargs.pop("_kw", {}) self._args = args self._kwargs = kwargs #: increase creation counter (used to keep order of fields) @@ -267,42 +219,36 @@ def __init__(self, type='string', *args, **kwargs): def _default_validation(self): rv = {} - auto_types = [ - 'int', 'float', 'date', 'time', 'datetime', 'json' - ] + auto_types = ["int", "float", "date", "time", "datetime", "json"] if self._type in auto_types: - rv['is'] = self._type - elif self._type.startswith('decimal'): - rv['is'] = 'decimal' - elif self._type == 'jsonb': - rv['is'] = 'json' - if self._type == 'bigint': - rv['is'] = 'int' - if self._type == 'bool': - rv['in'] = (False, True) - if self._type in ['string', 'text', 'password']: - rv['len'] = {'lte': self.length} - if self._type == 'password': - rv['len']['gte'] = 6 - rv['crypt'] = True - if self._type == 'list:int': - rv['is'] = 'list:int' - if ( - self.notnull or self._type.startswith('reference') or - self._type.startswith('list:reference') - ): - rv['presence'] = True + rv["is"] = self._type + elif self._type.startswith("decimal"): + rv["is"] = "decimal" + elif self._type == "jsonb": + rv["is"] = "json" + if self._type == "bigint": + rv["is"] = "int" + if self._type == "bool": + rv["in"] = (False, True) + if self._type in ["string", "text", "password"]: + rv["len"] = {"lte": self.length} + if self._type == "password": + rv["len"]["gte"] = 6 + rv["crypt"] = True + if self._type == "list:int": + rv["is"] = "list:int" + if self.notnull or self._type.startswith("reference") or self._type.startswith("list:reference"): + rv["presence"] = True if not self.notnull and self._isrefers is True: - rv['allow'] = 'empty' + rv["allow"] = "empty" if self.unique: - rv['unique'] = True + rv["unique"] = True return rv def _parse_validation(self): for key in list(self._requires): self._validation[key] = self._requires[key] - self.requires = self._vparser(self, self._validation) + \ - self._custom_requires + self.requires = self._vparser(self, self._validation) + self._custom_requires #: `_make_field` will be called by `Model` class or `Form` class # it will make internal Field class compatible with the pyDAL's one @@ -339,113 +285,102 @@ def __str__(self): return object.__str__(self) def __repr__(self): - if self.modelname and hasattr(self, 'name'): - return "<%s.%s (%s) field>" % (self.modelname, self.name, - self._type) + if self.modelname and hasattr(self, "name"): + return "<%s.%s (%s) field>" % (self.modelname, self.name, self._type) return super(Field, self).__repr__() @classmethod def string(cls, *args, **kwargs): - return cls('string', *args, **kwargs) + return cls("string", *args, **kwargs) @classmethod def int(cls, *args, **kwargs): - return cls('int', *args, **kwargs) + return cls("int", *args, **kwargs) @classmethod def bigint(cls, *args, **kwargs): - return cls('bigint', *args, **kwargs) + return cls("bigint", *args, **kwargs) @classmethod def float(cls, *args, **kwargs): - return cls('float', *args, **kwargs) + return cls("float", *args, **kwargs) @classmethod def text(cls, *args, **kwargs): - return cls('text', *args, **kwargs) + return cls("text", *args, **kwargs) @classmethod def bool(cls, *args, **kwargs): - return cls('bool', *args, **kwargs) + return cls("bool", *args, **kwargs) @classmethod def blob(cls, *args, **kwargs): - return cls('blob', *args, **kwargs) + return cls("blob", *args, **kwargs) @classmethod def date(cls, *args, **kwargs): - return cls('date', *args, **kwargs) + return cls("date", *args, **kwargs) @classmethod def time(cls, *args, **kwargs): - return cls('time', *args, **kwargs) + return cls("time", *args, **kwargs) @classmethod def datetime(cls, *args, **kwargs): - return cls('datetime', *args, **kwargs) + return cls("datetime", *args, **kwargs) @classmethod def decimal(cls, precision, scale, *args, **kwargs): - return cls('decimal({},{})'.format(precision, scale), *args, **kwargs) + return cls("decimal({},{})".format(precision, scale), *args, **kwargs) @classmethod def json(cls, *args, **kwargs): - return cls('json', *args, **kwargs) + return cls("json", *args, **kwargs) @classmethod def jsonb(cls, *args, **kwargs): - return cls('jsonb', *args, **kwargs) + return cls("jsonb", *args, **kwargs) @classmethod def password(cls, *args, **kwargs): - return cls('password', *args, **kwargs) + return cls("password", *args, **kwargs) @classmethod def upload(cls, *args, **kwargs): - return cls('upload', *args, **kwargs) + return cls("upload", *args, **kwargs) @classmethod def int_list(cls, *args, **kwargs): - return cls('list:int', *args, **kwargs) + return cls("list:int", *args, **kwargs) @classmethod def string_list(cls, *args, **kwargs): - return cls('list:string', *args, **kwargs) + return cls("list:string", *args, **kwargs) @classmethod def geography( cls, - geometry_type: str = 'GEOMETRY', + geometry_type: str = "GEOMETRY", srid: Optional[type_int] = None, dimension: Optional[type_int] = None, - **kwargs + **kwargs, ): - kwargs['_kw'] = { - "geometry_type": geometry_type, - "srid": srid, - "dimension": dimension - } + kwargs["_kw"] = {"geometry_type": geometry_type, "srid": srid, "dimension": dimension} return cls("geography", **kwargs) @classmethod def geometry( cls, - geometry_type: str = 'GEOMETRY', + geometry_type: str = "GEOMETRY", srid: Optional[type_int] = None, dimension: Optional[type_int] = None, - **kwargs + **kwargs, ): - kwargs['_kw'] = { - "geometry_type": geometry_type, - "srid": srid, - "dimension": dimension - } + kwargs["_kw"] = {"geometry_type": geometry_type, "srid": srid, "dimension": dimension} return cls("geometry", **kwargs) def cast(self, value, **kwargs): - return Expression( - self.db, self._dialect.cast, self, - self._dialect.types[value] % kwargs, value) + return Expression(self.db, self._dialect.cast, self, self._dialect.types[value] % kwargs, value) class Set(_Set): @@ -465,9 +400,10 @@ def _load_scopes_(self): def _clone(self, ignore_common_filters=None, model=None, **changes): return self.__class__( - self.db, changes.get('query', self.query), + self.db, + changes.get("query", self.query), ignore_common_filters=ignore_common_filters, - model=model or self._model_ + model=model or self._model_, ) def where(self, query, ignore_common_filters=None, model=None): @@ -478,12 +414,11 @@ def where(self, query, ignore_common_filters=None, model=None): elif isinstance(query, str): query = Expression(self.db, query) elif isinstance(query, Field): - query = query != None + query = query != None # noqa: E711 elif isinstance(query, types.LambdaType): model = model or self._model_ if not model: - raise ValueError( - "Too many models involved in the Set to use a lambda") + raise ValueError("Too many models involved in the Set to use a lambda") query = query(model) q = self.query & query if self.query else query return self._clone(ignore_common_filters, model, query=q) @@ -498,27 +433,21 @@ def _parse_paginate(self, pagination): return ((offset - 1) * limit, offset * limit) def _join_set_builder(self, obj, jdata, auto_select_tables): - return JoinedSet._from_set( - obj, jdata=jdata, auto_select_tables=auto_select_tables - ) + return JoinedSet._from_set(obj, jdata=jdata, auto_select_tables=auto_select_tables) def _left_join_set_builder(self, jdata): - return JoinedSet._from_set( - self, ljdata=jdata, auto_select_tables=[self._model_.table] - ) + return JoinedSet._from_set(self, ljdata=jdata, auto_select_tables=[self._model_.table]) def _run_select_(self, *fields, **options): tablemap = self.db._adapter.tables( self.query, - options.get('join', None), - options.get('left', None), - options.get('orderby', None), - options.get('groupby', None) + options.get("join", None), + options.get("left", None), + options.get("orderby", None), + options.get("groupby", None), ) - fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables( - fields, tablemap - ) - options['_concrete_tables'] = concrete_tables + fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables(fields, tablemap) + options["_concrete_tables"] = concrete_tables return self.db._adapter.select(self.query, fields, options) def _get_table_from_query(self) -> Table: @@ -528,32 +457,27 @@ def _get_table_from_query(self) -> Table: def select(self, *fields, **options): obj = self - pagination, including = ( - options.pop('paginate', None), - options.pop('including', None) - ) + pagination, including = (options.pop("paginate", None), options.pop("including", None)) if pagination: - options['limitby'] = self._parse_paginate(pagination) + options["limitby"] = self._parse_paginate(pagination) if including and self._model_ is not None: - options['left'], jdata = self._parse_left_rjoins(including) + options["left"], jdata = self._parse_left_rjoins(including) obj = self._left_join_set_builder(jdata) return obj._run_select_(*fields, **options) def iterselect(self, *fields, **options): - pagination = options.pop('paginate', None) + pagination = options.pop("paginate", None) if pagination: - options['limitby'] = self._parse_paginate(pagination) + options["limitby"] = self._parse_paginate(pagination) tablemap = self.db._adapter.tables( self.query, - options.get('join', None), - options.get('left', None), - options.get('orderby', None), - options.get('groupby', None) + options.get("join", None), + options.get("left", None), + options.get("orderby", None), + options.get("groupby", None), ) - fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables( - fields, tablemap - ) - options['_concrete_tables'] = concrete_tables + fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables(fields, tablemap) + options["_concrete_tables"] = concrete_tables return self.db._adapter.iterselect(self.query, fields, options) def update(self, skip_callbacks=False, **update_fields): @@ -568,15 +492,11 @@ def update(self, skip_callbacks=False, **update_fields): if table._has_commit_update_callbacks: txn = self._db._adapter.top_transaction() if txn: - txn._add_op(TransactionOp( - TransactionOps.update, - table, - TransactionOpContext( - values=row, - dbset=self, - ret=ret + txn._add_op( + TransactionOp( + TransactionOps.update, table, TransactionOpContext(values=row, dbset=self, ret=ret) ) - )) + ) ret and [f(self, row) for f in table._after_update] return ret @@ -589,14 +509,7 @@ def delete(self, skip_callbacks=False): if table._has_commit_delete_callbacks: txn = self._db._adapter.top_transaction() if txn: - txn._add_op(TransactionOp( - TransactionOps.delete, - table, - TransactionOpContext( - dbset=self, - ret=ret - ) - )) + txn._add_op(TransactionOp(TransactionOps.delete, table, TransactionOpContext(dbset=self, ret=ret))) ret and [f(self) for f in table._after_delete] return ret @@ -604,19 +517,15 @@ def validate_and_update(self, skip_callbacks=False, **update_fields): table = self._get_table_from_query() current._dbvalidation_record_id_ = None if table._unique_fields_validation_ and self.count() == 1: - if any( - table._unique_fields_validation_.get(fieldname) - for fieldname in update_fields.keys() - ): - current._dbvalidation_record_id_ = \ - self.select(table.id).first().id + if any(table._unique_fields_validation_.get(fieldname) for fieldname in update_fields.keys()): + current._dbvalidation_record_id_ = self.select(table.id).first().id response = Row() response.errors = Row() new_fields = copy.copy(update_fields) for key, value in update_fields.items(): value, error = table[key].validate(value) if error: - response.errors[key] = '%s' % error + response.errors[key] = "%s" % error else: new_fields[key] = value del current._dbvalidation_record_id_ @@ -629,9 +538,7 @@ def validate_and_update(self, skip_callbacks=False, **update_fields): if not skip_callbacks and any(f(self, row) for f in table._before_update): ret = 0 else: - ret = self.db._adapter.update( - table, self.query, row.op_values() - ) + ret = self.db._adapter.update(table, self.query, row.op_values()) if not skip_callbacks and ret: for f in table._after_update: f(self, row) @@ -642,9 +549,7 @@ def _update_from_save(self, model, row, skip_callbacks=False): table: Table = model.table if not skip_callbacks and any(f(row) for f in table._before_save): return False - fields = table._fields_and_values_for_save( - row, model._fieldset_editable, table._fields_and_values_for_update - ) + fields = table._fields_and_values_for_save(row, model._fieldset_editable, table._fields_and_values_for_update) if not skip_callbacks and any(f(self, fields) for f in table._before_update): return False ret = self.db._adapter.update(table, self.query, fields.op_values()) @@ -652,27 +557,21 @@ def _update_from_save(self, model, row, skip_callbacks=False): if table._has_commit_update_callbacks or table._has_commit_save_callbacks: txn = self._db._adapter.top_transaction() if txn and table._has_commit_update_callbacks: - txn._add_op(TransactionOp( - TransactionOps.update, - table, - TransactionOpContext( - values=fields, - dbset=self, - ret=ret + txn._add_op( + TransactionOp( + TransactionOps.update, table, TransactionOpContext(values=fields, dbset=self, ret=ret) ) - )) + ) if txn and table._has_commit_save_callbacks: - txn._add_op(TransactionOp( - TransactionOps.save, - table, - TransactionOpContext( - values=fields, - dbset=self, - ret=ret, - row=row.clone_changed(), - changes=row.changes + txn._add_op( + TransactionOp( + TransactionOps.save, + table, + TransactionOpContext( + values=fields, dbset=self, ret=ret, row=row.clone_changed(), changes=row.changes + ), ) - )) + ) ret and [f(self, fields) for f in table._after_update] ret and [f(row) for f in table._after_save] return bool(ret) @@ -687,31 +586,18 @@ def _delete_from_destroy(self, model, row, skip_callbacks=False): if ret: model._unset_row_persistence(row) if not skip_callbacks: - if ( - table._has_commit_delete_callbacks or - table._has_commit_destroy_callbacks - ): + if table._has_commit_delete_callbacks or table._has_commit_destroy_callbacks: txn = self._db._adapter.top_transaction() if txn and table._has_commit_delete_callbacks: - txn._add_op(TransactionOp( - TransactionOps.delete, - table, - TransactionOpContext( - dbset=self, - ret=ret - ) - )) + txn._add_op(TransactionOp(TransactionOps.delete, table, TransactionOpContext(dbset=self, ret=ret))) if txn and table._has_commit_destroy_callbacks: - txn._add_op(TransactionOp( - TransactionOps.destroy, - table, - TransactionOpContext( - dbset=self, - ret=ret, - row=row.clone_changed(), - changes=row.changes + txn._add_op( + TransactionOp( + TransactionOps.destroy, + table, + TransactionOpContext(dbset=self, ret=ret, row=row.clone_changed(), changes=row.changes), ) - )) + ) ret and [f(self) for f in table._after_delete] ret and [f(row) for f in table._after_destroy] return bool(ret) @@ -745,25 +631,23 @@ def _parse_rjoin(self, arg): if rel: if rel.via: r = RelationBuilder(rel, self._model_._instance_()).via() - return r[0], r[1]._table, 'many' + return r[0], r[1]._table, "many" r = RelationBuilder(rel, self._model_._instance_()) - return r.many(), rel.table, 'many' + return r.many(), rel.table, "many" #: match belongs_to and refers_to rel = self._model_._belongs_fks_.get(arg) if rel: r = RelationBuilder(rel, self._model_._instance_()).belongs_query() - return r, self._model_.db[rel.model], 'belongs' + return r, self._model_.db[rel.model], "belongs" #: match has_one rel = self._model_._hasone_ref_.get(arg) if rel: if rel.via: r = RelationBuilder(rel, self._model_._instance_()).via() - return r[0], r[1]._table, 'one' + return r[0], r[1]._table, "one" r = RelationBuilder(rel, self._model_._instance_()) - return r.many(), rel.table, 'one' - raise RuntimeError( - f'Unable to find {arg} relation of {self._model_.__name__} model' - ) + return r.many(), rel.table, "one" + raise RuntimeError(f"Unable to find {arg} relation of {self._model_.__name__} model") def _parse_left_rjoins(self, args): if not isinstance(args, (list, tuple)): @@ -798,7 +682,7 @@ def __getattr__(self, name): class RelationSet(object): - _relation_method_ = 'many' + _relation_method_ = "many" def __init__(self, db, relation_builder, row): self.db = db @@ -821,8 +705,7 @@ def _model_(self): def _fields_(self): pks = self._relation_.model.primary_keys or ["id"] return [ - (relation_field.name, pks[idx]) - for idx, relation_field in enumerate(self._relation_.ref.fields_instances) + (relation_field.name, pks[idx]) for idx, relation_field in enumerate(self._relation_.ref.fields_instances) ] @cachedprop @@ -845,12 +728,12 @@ def __call__(self, query, ignore_common_filters=False): return self._set.where(query, ignore_common_filters) def _last_resultset(self, refresh=False): - if refresh or not hasattr(self, '_cached_resultset'): + if refresh or not hasattr(self, "_cached_resultset"): self._cached_resultset = self._cache_resultset() return self._cached_resultset def _filter_reload(self, kwargs): - return kwargs.pop('reload', False) + return kwargs.pop("reload", False) def new(self, **kwargs): attrs = self._get_fields_from_scopes(self._scopes_, self._model_.tablename) @@ -879,10 +762,7 @@ def _get_fields_from_scopes(scopes, table_name): components.append(component.second) components.append(component.first) else: - if ( - isinstance(component, Field) and - component._tablename == table_name - ): + if isinstance(component, Field) and component._tablename == table_name: current_kv.append(component) else: if current_kv: @@ -903,7 +783,7 @@ def __call__(self, *args, **kwargs): refresh = self._filter_reload(kwargs) if not args and not kwargs: return self._last_resultset(refresh) - kwargs['limitby'] = (0, 1) + kwargs["limitby"] = (0, 1) return self.select(*args, **kwargs).first() @@ -920,24 +800,18 @@ def __call__(self, *args, **kwargs): def add(self, obj, skip_callbacks=False): if not isinstance(obj, (StructuredRow, RowReferenceMixin)): raise RuntimeError(f"Unsupported parameter {obj}") - attrs = self._get_fields_from_scopes( - self._scopes_, self._model_.tablename - ) + attrs = self._get_fields_from_scopes(self._scopes_, self._model_.tablename) rev_attrs = {} for ref, local in self._fields_: attrs[ref] = self._row_[local] rev_attrs[local] = attrs[ref] - rv = self.db(self._model_._query_row(obj)).validate_and_update( - skip_callbacks=skip_callbacks, **attrs - ) + rv = self.db(self._model_._query_row(obj)).validate_and_update(skip_callbacks=skip_callbacks, **attrs) if rv: for key, val in attrs.items(): obj[key] = val if len(rev_attrs) > 1: comprel_key = (self._relation_.ref.reverse, *rev_attrs.values()) - obj._compound_rels[comprel_key] = typed_row_reference_from_record( - self._row_, self._row_._model - ) + obj._compound_rels[comprel_key] = typed_row_reference_from_record(self._row_, self._row_._model) return rv def remove(self, obj, skip_callbacks=False): @@ -945,15 +819,10 @@ def remove(self, obj, skip_callbacks=False): raise RuntimeError(f"Unsupported parameter {obj}") attrs, is_delete = {ref: None for ref, _ in self._fields_}, False if self._model_._belongs_fks_[self._relation_.ref.reverse].is_refers: - rv = self.db(self._model_._query_row(obj)).validate_and_update( - skip_callbacks=skip_callbacks, - **attrs - ) + rv = self.db(self._model_._query_row(obj)).validate_and_update(skip_callbacks=skip_callbacks, **attrs) else: is_delete = True - rv = self.db(self._model_._query_row(obj)).delete( - skip_callbacks=skip_callbacks - ) + rv = self.db(self._model_._query_row(obj)).delete(skip_callbacks=skip_callbacks) if rv: for key, val in attrs.items(): obj[key] = val @@ -965,19 +834,12 @@ def remove(self, obj, skip_callbacks=False): class ViaSet(RelationSet): - _relation_method_ = 'via' + _relation_method_ = "via" @cachedprop def _viadata(self): query, rfield, model_name, rid, via, viadata = super()._get_query_() - return sdict( - query=query, - rfield=rfield, - model_name=model_name, - rid=rid, - via=via, - data=viadata - ) + return sdict(query=query, rfield=rfield, model_name=model_name, rid=rid, via=via, data=viadata) def _get_query_(self): return self._viadata.query @@ -997,11 +859,11 @@ def __call__(self, *args, **kwargs): if not kwargs: return self._last_resultset(refresh) args = [self._viadata.rfield] - kwargs['limitby'] = (0, 1) + kwargs["limitby"] = (0, 1) return self.select(*args, **kwargs).first() def create(self, **kwargs): - raise RuntimeError('Cannot create third objects for one via relations') + raise RuntimeError("Cannot create third objects for one via relations") class HasManyViaSet(ViaSet): @@ -1037,12 +899,12 @@ def _fields_from_scopes(self): return self._get_fields_from_scopes(scopes, rel.table_name) def create(self, **kwargs): - raise RuntimeError('Cannot create third objects for many via relations') + raise RuntimeError("Cannot create third objects for many via relations") def add(self, obj, skip_callbacks=False, **kwargs): # works on join tables only! if self._viadata.via is None: - raise RuntimeError(self._via_error % 'add') + raise RuntimeError(self._via_error % "add") nrow = self._fields_from_scopes() nrow.update(**kwargs) #: get belongs references @@ -1052,25 +914,22 @@ def add(self, obj, skip_callbacks=False, **kwargs): for local_field, foreign_field in rel_fields: nrow[local_field] = obj[foreign_field] #: validate and insert - return self.db[self._viadata.via]._model_.create( - nrow, skip_callbacks=skip_callbacks - ) + return self.db[self._viadata.via]._model_.create(nrow, skip_callbacks=skip_callbacks) def remove(self, obj, skip_callbacks=False): # works on join tables only! if self._viadata.via is None: - raise RuntimeError(self._via_error % 'remove') + raise RuntimeError(self._via_error % "remove") #: get belongs references self_fields, rel_fields = self._get_relation_fields() #: delete query = reduce( - operator.and_, [ - self.db[self._viadata.via][field] == self._viadata.rid[idx] - for idx, field in enumerate(self_fields) - ] + [ + operator.and_, + [self.db[self._viadata.via][field] == self._viadata.rid[idx] for idx, field in enumerate(self_fields)] + + [ self.db[self._viadata.via][local_field] == obj[foreign_field] for local_field, foreign_field in rel_fields - ] + ], ) return self.db(query).delete(skip_callbacks=skip_callbacks) @@ -1097,46 +956,40 @@ def _clone(self, ignore_common_filters=None, model=None, **changes): def _join_set_builder(self, obj, jdata, auto_select_tables): return JoinedSet._from_set( - obj, jdata=self._jdata_ + jdata, ljdata=self._ljdata_, - auto_select_tables=self._auto_select_tables_ + auto_select_tables + obj, + jdata=self._jdata_ + jdata, + ljdata=self._ljdata_, + auto_select_tables=self._auto_select_tables_ + auto_select_tables, ) def _left_join_set_builder(self, jdata): return JoinedSet._from_set( - self, jdata=self._jdata_, ljdata=self._ljdata_ + jdata, - auto_select_tables=self._auto_select_tables_ + self, jdata=self._jdata_, ljdata=self._ljdata_ + jdata, auto_select_tables=self._auto_select_tables_ ) def _iterselect_rows(self, *fields, **options): tablemap = self.db._adapter.tables( self.query, - options.get('join', None), - options.get('left', None), - options.get('orderby', None), - options.get('groupby', None) - ) - fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables( - fields, tablemap + options.get("join", None), + options.get("left", None), + options.get("orderby", None), + options.get("groupby", None), ) + fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables(fields, tablemap) colnames, sql = self.db._adapter._select_wcols(self.query, fields, **options) return JoinIterRows(self.db, sql, fields, concrete_tables, colnames) def _split_joins(self, joins): - rv = {'belongs': [], 'one': [], 'many': []} + rv = {"belongs": [], "one": [], "many": []} for jname, jtable, rel_type in joins: rv[rel_type].append((jname, jtable)) - return rv['belongs'], rv['one'], rv['many'] + return rv["belongs"], rv["one"], rv["many"] def _build_records_from_joined(self, rowmap, inclusions, colnames): for rid, many_data in inclusions.items(): for jname, included in many_data.items(): - rowmap[rid][jname]._cached_resultset = Rows( - self.db, list(included.values()), [] - ) - return JoinRows( - self.db, list(rowmap.values()), colnames, - _jdata=self._jdata_ + self._ljdata_ - ) + rowmap[rid][jname]._cached_resultset = Rows(self.db, list(included.values()), []) + return JoinRows(self.db, list(rowmap.values()), colnames, _jdata=self._jdata_ + self._ljdata_) def _select_rowpks_extractor(self, row): if not set(row.keys()).issuperset(self._pks_): @@ -1149,27 +1002,20 @@ def _run_select_(self, *fields, **options): #: build parsers belongs_j, one_j, many_j = self._split_joins(self._jdata_) belongs_l, one_l, many_l = self._split_joins(self._ljdata_) - parsers = ( - self._build_jparsers(belongs_j, one_j, many_j) + - self._build_lparsers(belongs_l, one_l, many_l) - ) + parsers = self._build_jparsers(belongs_j, one_j, many_j) + self._build_lparsers(belongs_l, one_l, many_l) #: auto add selection field for left joins if self._ljdata_: fields = list(fields) if not fields: fields = [v.ALL for v in self._auto_select_tables_] - for join in options['left']: + for join in options["left"]: fields.append(join.first.ALL) #: use iterselect for performance rows = self._iterselect_rows(*fields, **options) #: rebuild rowset using nested objects plainrows = [] rowmap = OrderedDict() - inclusions = defaultdict( - lambda: { - jname: OrderedDict() for jname, _ in (many_j + many_l) - } - ) + inclusions = defaultdict(lambda: {jname: OrderedDict() for jname, _ in (many_j + many_l)}) for row in rows: if self._stable_ not in row: plainrows.append(row) @@ -1183,9 +1029,7 @@ def _run_select_(self, *fields, **options): parser(rowmap, inclusions, row, rid) if not rowmap and plainrows: return Rows(self.db, plainrows, rows.colnames) - return self._build_records_from_joined( - rowmap, inclusions, rows.colnames - ) + return self._build_records_from_joined(rowmap, inclusions, rows.colnames) def _build_jparsers(self, belongs, one, many): rv = [] @@ -1212,15 +1056,15 @@ def _jbelong_parser(db, fieldname, tablename): rmodel = db[tablename]._model_ def parser(rowmap, inclusions, row, rid): - rowmap[rid][fieldname] = typed_row_reference_from_record( - row[tablename], rmodel - ) + rowmap[rid][fieldname] = typed_row_reference_from_record(row[tablename], rmodel) + return parser @staticmethod def _jone_parser(db, fieldname, tablename): def parser(rowmap, inclusions, row, rid): rowmap[rid][fieldname]._cached_resultset = row[tablename] + return parser @staticmethod @@ -1230,10 +1074,10 @@ def _jmany_parser(db, fieldname, tablename): ext = lambda row: tuple(row[pk] for pk in pks) if len(pks) > 1 else row[pks[0]] def parser(rowmap, inclusions, row, rid): - inclusions[rid][fieldname][ext(row[tablename])] = \ - inclusions[rid][fieldname].get( - ext(row[tablename]), row[tablename] - ) + inclusions[rid][fieldname][ext(row[tablename])] = inclusions[rid][fieldname].get( + ext(row[tablename]), row[tablename] + ) + return parser @staticmethod @@ -1245,9 +1089,8 @@ def _lbelong_parser(db, fieldname, tablename): def parser(rowmap, inclusions, row, rid): if not check(row[tablename]): return - rowmap[rid][fieldname] = typed_row_reference_from_record( - row[tablename], rmodel - ) + rowmap[rid][fieldname] = typed_row_reference_from_record(row[tablename], rmodel) + return parser @staticmethod @@ -1260,6 +1103,7 @@ def parser(rowmap, inclusions, row, rid): if not check(row[tablename]): return rowmap[rid][fieldname]._cached_resultset = row[tablename] + return parser @staticmethod @@ -1272,17 +1116,16 @@ def _lmany_parser(db, fieldname, tablename): def parser(rowmap, inclusions, row, rid): if not check(row[tablename]): return - inclusions[rid][fieldname][ext(row[tablename])] = \ - inclusions[rid][fieldname].get( - ext(row[tablename]), row[tablename] - ) + inclusions[rid][fieldname][ext(row[tablename])] = inclusions[rid][fieldname].get( + ext(row[tablename]), row[tablename] + ) + return parser class Row(_Row): _as_dict_types_ = tuple( - [type(None)] + [int, float, bool, list, dict, str] + - [datetime.datetime, datetime.date, datetime.time] + [type(None)] + [int, float, bool, list, dict, str] + [datetime.datetime, datetime.date, datetime.time] ) @classmethod @@ -1312,10 +1155,10 @@ def __json__(self): return self.as_dict() def __xml__(self, key=None, quote=True): - return xml_encode(self.as_dict(), key or 'row', quote) + return xml_encode(self.as_dict(), key or "row", quote) def __str__(self): - return ''.format(self.as_dict(geo_coordinates=False)) + return "".format(self.as_dict(geo_coordinates=False)) def __repr__(self): return str(self) @@ -1353,11 +1196,7 @@ def _from_engine(cls, data: Dict[str, Any]): rv._fields.update(data) return rv - def __init__( - self, - fields: Optional[Dict[str, Any]] = None, - **extras: Any - ): + def __init__(self, fields: Optional[Dict[str, Any]] = None, **extras: Any): object.__setattr__(self, "_changes", {}) object.__setattr__(self, "_compound_rels", {}) object.__setattr__(self, "_concrete", extras.pop("__concrete", False)) @@ -1374,10 +1213,7 @@ def __getitem__(self, name): def __setattr__(self, key, value): if key in self.__slots__: return - oldv = ( - self._changes[key][0] if key in self._changes else - getattr(self, key, None) - ) + oldv = self._changes[key][0] if key in self._changes else getattr(self, key, None) object.__setattr__(self, key, value) newv = getattr(self, key, None) if (oldv is None and value is not None) or oldv != newv: @@ -1392,12 +1228,7 @@ def __getstate__(self): return { "__fields": self._fields, "__extras": self.__dict__, - "__struct": { - "_concrete": self._concrete, - "_changes": {}, - "_compound_rels": {}, - "_virtuals": {} - } + "__struct": {"_concrete": self._concrete, "_changes": {}, "_compound_rels": {}, "_virtuals": {}}, } def __setstate__(self, state): @@ -1464,9 +1295,7 @@ def clone(self): return self.__class__(fields, __concrete=self._concrete, **self.__dict__) def clone_changed(self): - return self.__class__( - {**self._fields}, __concrete=self._concrete, **self.__dict__ - ) + return self.__class__({**self._fields}, __concrete=self._concrete, **self.__dict__) @property def validation_errors(self): @@ -1478,9 +1307,7 @@ def is_valid(self): class Rows(_Rows): - def __init__( - self, db=None, records=[], colnames=[], compact=True, rawrows=None - ): + def __init__(self, db=None, records=[], colnames=[], compact=True, rawrows=None): self.db = db self.records = records self.colnames = colnames @@ -1491,7 +1318,7 @@ def __init__( def compact(self): if not self.records: return False - return len(self._rowkeys_) == 1 and self._rowkeys_[0] != '_extra' + return len(self._rowkeys_) == 1 and self._rowkeys_[0] != "_extra" @cachedprop def compact_tablename(self): @@ -1519,7 +1346,7 @@ def sorted(self, f, reverse=False): keyf = lambda r: f(r[self.compact_tablename]) else: keyf = f - return [r for r in sorted(self.records, key=keyf, reverse=reverse)] + return sorted(self.records, key=keyf, reverse=reverse) def sort(self, f, reverse=False): self.records = self.sorted(f, reverse) @@ -1538,9 +1365,9 @@ def render(self, *args, **kwargs): def as_list(self, datetime_to_str=False, custom_types=None): return [item.as_dict(datetime_to_str, custom_types) for item in self] - def as_dict(self, key='id', datetime_to_str=False, custom_types=None): - if '.' in key: - splitted_key = key.split('.') + def as_dict(self, key="id", datetime_to_str=False, custom_types=None): + if "." in key: + splitted_key = key.split(".") keyf = lambda row: row[splitted_key[0]][splitted_key[1]] else: keyf = lambda row: row[key] @@ -1550,7 +1377,7 @@ def __json__(self): return [item.__json__() for item in self] def __xml__(self, key=None, quote=True): - key = key or 'rows' + key = key or "rows" return tag[key](*[item.__xml__(quote=quote) for item in self]) def __str__(self): @@ -1580,17 +1407,11 @@ def __next__(self): if db_row is None: raise StopIteration row = self.db._adapter._parse( - db_row, - self.fdata, - self.tables, - self.concrete_tables, - self.fields, - self.colnames, - self.blob_decode + db_row, self.fdata, self.tables, self.concrete_tables, self.fields, self.colnames, self.blob_decode ) if self.compact: keys = list(row.keys()) - if len(keys) == 1 and keys[0] != '_extra': + if len(keys) == 1 and keys[0] != "_extra": row = row[keys[0]] return row @@ -1609,13 +1430,10 @@ def __iter__(self): class JoinRows(Rows): def __init__(self, *args, **kwargs): - self._joins_ = kwargs.pop('_jdata') + self._joins_ = kwargs.pop("_jdata") super(JoinRows, self).__init__(*args, **kwargs) - def as_list( - self, compact=True, storage_to_dict=True, datetime_to_str=False, - custom_types=None - ): + def as_list(self, compact=True, storage_to_dict=True, datetime_to_str=False, custom_types=None): if storage_to_dict: items = [] for row in self: @@ -1625,7 +1443,7 @@ def as_list( item[jdata[0]] = row[jdata[0]].as_list() items.append(item) else: - items = [item for item in self] + items = list(self) return items @@ -1654,7 +1472,7 @@ def __init__( dbset: Optional[Set] = None, ret: Any = None, row: Optional[Row] = None, - changes: Optional[sdict] = None + changes: Optional[sdict] = None, ): self.values = values self.dbset = dbset @@ -1666,12 +1484,7 @@ def __init__( class TransactionOp: __slots__ = ["op_type", "table", "context"] - def __init__( - self, - op_type: TransactionOps, - table: Table, - context: TransactionOpContext - ): + def __init__(self, op_type: TransactionOps, table: Table, context: TransactionOpContext): self.op_type = op_type self.table = table self.context = context diff --git a/emmett/orm/transactions.py b/emmett/orm/transactions.py index d7e7160d..7dfed161 100644 --- a/emmett/orm/transactions.py +++ b/emmett/orm/transactions.py @@ -1,20 +1,19 @@ # -*- coding: utf-8 -*- """ - emmett.orm.transactions - ----------------------- +emmett.orm.transactions +----------------------- - Provides pyDAL advanced transactions implementation for Emmett. +Provides pyDAL advanced transactions implementation for Emmett. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Parts of this code are inspired to peewee - :copyright: (c) 2010 by Charles Leifer +Parts of this code are inspired to peewee +:copyright: (c) 2010 by Charles Leifer - :license: BSD-3-Clause +:license: BSD-3-Clause """ import uuid - from functools import wraps @@ -24,6 +23,7 @@ def __call__(self, fn): def inner(*args, **kwargs): with self: return fn(*args, **kwargs) + return inner @@ -105,7 +105,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class _savepoint(callable_context_manager): def __init__(self, adapter, sid=None): self.adapter = adapter - self.sid = sid or 's' + uuid.uuid4().hex + self.sid = sid or "s" + uuid.uuid4().hex self.quoted_sid = self.adapter.dialect.quote(self.sid) self._ops = [] self._parent = None @@ -117,16 +117,16 @@ def _add_ops(self, ops): self._ops.extend(ops) def _begin(self): - self.adapter.execute('SAVEPOINT %s;' % self.quoted_sid) + self.adapter.execute("SAVEPOINT %s;" % self.quoted_sid) def commit(self, begin=True): - self.adapter.execute('RELEASE SAVEPOINT %s;' % self.quoted_sid) + self.adapter.execute("RELEASE SAVEPOINT %s;" % self.quoted_sid) if begin: self._begin() def rollback(self): self._ops.clear() - self.adapter.execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) + self.adapter.execute("ROLLBACK TO SAVEPOINT %s;" % self.quoted_sid) def __enter__(self): self._parent = self.adapter.top_transaction() diff --git a/emmett/orm/wrappers.py b/emmett/orm/wrappers.py index a588c2fa..658fe1cf 100644 --- a/emmett/orm/wrappers.py +++ b/emmett/orm/wrappers.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.orm.wrappers - ------------------- +emmett.orm.wrappers +------------------- - Provides ORM wrappers utilities. +Provides ORM wrappers utilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from .helpers import RelationBuilder -from .objects import HasOneSet, HasOneViaSet, HasManySet, HasManyViaSet +from .objects import HasManySet, HasManyViaSet, HasOneSet, HasOneViaSet class Wrapper(object): diff --git a/emmett/parsers.py b/emmett/parsers.py index d38e4dae..6f0ec3bc 100644 --- a/emmett/parsers.py +++ b/emmett/parsers.py @@ -1,10 +1,10 @@ # -*- coding: utf-8 -*- """ - emmett.parsers - -------------- +emmett.parsers +-------------- - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from emmett_core.parsers import Parsers as Parsers diff --git a/emmett/pipeline.py b/emmett/pipeline.py index dcf7d792..16613c97 100644 --- a/emmett/pipeline.py +++ b/emmett/pipeline.py @@ -1,19 +1,19 @@ # -*- coding: utf-8 -*- """ - emmett.pipeline - --------------- +emmett.pipeline +--------------- - Provides the pipeline classes. +Provides the pipeline classes. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import types from emmett_core.http.helpers import redirect -from emmett_core.pipeline.pipe import Pipe as Pipe from emmett_core.pipeline.extras import RequirePipe as _RequirePipe +from emmett_core.pipeline.pipe import Pipe as Pipe from .ctx import current from .helpers import flash @@ -36,26 +36,22 @@ async def pipe_request(self, next_pipe, **kwargs): redirect(self.__class__._current, self.otherwise) else: if self.flash: - flash('Insufficient privileges') + flash("Insufficient privileges") redirect(self.__class__._current, "/") return await next_pipe(**kwargs) class Injector(Pipe): - namespace: str = '__global__' + namespace: str = "__global__" def __init__(self): self._injections_ = {} - if self.namespace != '__global__': + if self.namespace != "__global__": self._inject = self._inject_local return self._inject = self._inject_global - for attr_name in ( - set(dir(self)) - - self.__class__._pipeline_methods_ - - {'output', 'namespace'} - ): - if attr_name.startswith('_'): + for attr_name in set(dir(self)) - self.__class__._pipeline_methods_ - {"output", "namespace"}: + if attr_name.startswith("_"): continue attr = getattr(self, attr_name) if isinstance(attr, types.MethodType): @@ -67,6 +63,7 @@ def __init__(self): def _wrapped_method(method): def wrap(*args, **kwargs): return method(*args, **kwargs) + return wrap def _inject_local(self, ctx): diff --git a/emmett/routing/response.py b/emmett/routing/response.py index 14404d0c..d9e7b365 100644 --- a/emmett/routing/response.py +++ b/emmett/routing/response.py @@ -1,17 +1,17 @@ # -*- coding: utf-8 -*- """ - emmett.routing.response - ----------------------- +emmett.routing.response +----------------------- - Provides response builders for http routes. +Provides response builders for http routes. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -from typing import Any, Dict, Union, Tuple +from typing import Any, Dict, Tuple, Union from emmett_core.http.response import HTTPResponse, HTTPStringResponse from emmett_core.routing.response import ResponseProcessor @@ -22,53 +22,28 @@ from ..html import asis from .urls import url -_html_content_type = 'text/html; charset=utf-8' + +_html_content_type = "text/html; charset=utf-8" class TemplateResponseBuilder(ResponseProcessor): - def process( - self, - output: Union[Dict[str, Any], None], - response - ) -> str: - response.headers._data['content-type'] = _html_content_type - base_ctx = { - 'current': current, - 'url': url, - 'asis': asis, - 'load_component': load_component - } + def process(self, output: Union[Dict[str, Any], None], response) -> str: + response.headers._data["content-type"] = _html_content_type + base_ctx = {"current": current, "url": url, "asis": asis, "load_component": load_component} output = base_ctx if output is None else {**base_ctx, **output} try: - return self.route.app.templater.render( - self.route.template, output - ) + return self.route.app.templater.render(self.route.template, output) except TemplateMissingError as exc: - raise HTTPStringResponse( - 404, - body="{}\n".format(exc.message), - cookies=response.cookies - ) + raise HTTPStringResponse(404, body="{}\n".format(exc.message), cookies=response.cookies) class SnippetResponseBuilder(ResponseProcessor): - def process( - self, - output: Tuple[str, Union[Dict[str, Any], None]], - response - ) -> str: - response.headers._data['content-type'] = _html_content_type + def process(self, output: Tuple[str, Union[Dict[str, Any], None]], response) -> str: + response.headers._data["content-type"] = _html_content_type template, output = output - base_ctx = { - 'current': current, - 'url': url, - 'asis': asis, - 'load_component': load_component - } + base_ctx = {"current": current, "url": url, "asis": asis, "load_component": load_component} output = base_ctx if output is None else {**base_ctx, **output} - return self.route.app.templater._render( - template, current.request.name, output - ) + return self.route.app.templater._render(template, current.request.name, output) class AutoResponseBuilder(ResponseProcessor): @@ -78,39 +53,18 @@ def process(self, output: Any, response) -> str: snippet, output = output if isinstance(output, dict): is_template = True - output = { - **{ - 'current': current, - 'url': url, - 'asis': asis, - 'load_component': load_component - }, - **output - } + output = {**{"current": current, "url": url, "asis": asis, "load_component": load_component}, **output} elif output is None: is_template = True - output = { - 'current': current, - 'url': url, - 'asis': asis, - 'load_component': load_component - } + output = {"current": current, "url": url, "asis": asis, "load_component": load_component} if is_template: - response.headers._data['content-type'] = _html_content_type + response.headers._data["content-type"] = _html_content_type if snippet: - return self.route.app.templater._render( - snippet, current.request.name, output - ) + return self.route.app.templater._render(snippet, current.request.name, output) try: - return self.route.app.templater.render( - self.route.template, output - ) + return self.route.app.templater.render(self.route.template, output) except TemplateMissingError as exc: - raise HTTPStringResponse( - 404, - body="{}\n".format(exc.message), - cookies=response.cookies - ) + raise HTTPStringResponse(404, body="{}\n".format(exc.message), cookies=response.cookies) elif isinstance(output, str): return output elif isinstance(output, HTTPResponse): diff --git a/emmett/routing/router.py b/emmett/routing/router.py index 0edc5865..5fa1d913 100644 --- a/emmett/routing/router.py +++ b/emmett/routing/router.py @@ -1,33 +1,38 @@ # -*- coding: utf-8 -*- """ - emmett.routing.router - --------------------- +emmett.routing.router +--------------------- - Provides router implementations. +Provides router implementations. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -from emmett_core.routing.router import HTTPRouter as _HTTPRouter, WebsocketRouter as WebsocketRouter, RoutingCtx as RoutingCtx, RoutingCtxGroup as RoutingCtxGroup +from emmett_core.routing.router import ( + HTTPRouter as _HTTPRouter, + RoutingCtx as RoutingCtx, + RoutingCtxGroup as RoutingCtxGroup, + WebsocketRouter as WebsocketRouter, +) -from .response import AutoResponseBuilder, TemplateResponseBuilder, SnippetResponseBuilder +from .response import AutoResponseBuilder, SnippetResponseBuilder, TemplateResponseBuilder from .rules import HTTPRoutingRule class HTTPRouter(_HTTPRouter): - __slots__ = ['injectors'] + __slots__ = ["injectors"] _routing_rule_cls = HTTPRoutingRule _outputs = { **_HTTPRouter._outputs, **{ - 'auto': AutoResponseBuilder, - 'template': TemplateResponseBuilder, - 'snippet': SnippetResponseBuilder, - } + "auto": AutoResponseBuilder, + "template": TemplateResponseBuilder, + "snippet": SnippetResponseBuilder, + }, } def __init__(self, *args, **kwargs): diff --git a/emmett/routing/routes.py b/emmett/routing/routes.py index 5f7d5adc..e1201230 100644 --- a/emmett/routing/routes.py +++ b/emmett/routing/routes.py @@ -1,20 +1,18 @@ # -*- coding: utf-8 -*- """ - emmett.routing.routes - --------------------- +emmett.routing.routes +--------------------- - Provides routes objects. +Provides routes objects. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import re - from functools import wraps import pendulum - from emmett_core.http.response import HTTPResponse from emmett_core.routing.routes import HTTPRoute as _HTTPRoute @@ -27,16 +25,14 @@ def __init__(self, rule, path, idx): self.build_argparser() def build_argparser(self): - parsers = {'date': self._parse_date_reqarg} - opt_parsers = {'date': self._parse_date_reqarg_opt} + parsers = {"date": self._parse_date_reqarg} + opt_parsers = {"date": self._parse_date_reqarg_opt} pipeline = [] for key in parsers.keys(): optionals = [] - for element in re.compile( - r'\(([^<]+)?<{}\:(\w+)>\)\?'.format(key) - ).findall(self.path): + for element in re.compile(r"\(([^<]+)?<{}\:(\w+)>\)\?".format(key)).findall(self.path): optionals.append(element[1]) - elements = set(re.compile(r'<{}\:(\w+)>'.format(key)).findall(self.path)) + elements = set(re.compile(r"<{}\:(\w+)>".format(key)).findall(self.path)) args = elements - set(optionals) if args: parser = self._wrap_reqargs_parser(parsers[key], args) @@ -73,6 +69,7 @@ def _wrap_reqargs_parser(parser, args): @wraps(parser) def wrapped(route_args): return parser(args, route_args) + return wrapped diff --git a/emmett/routing/rules.py b/emmett/routing/rules.py index 7d8edcc2..3891ad72 100644 --- a/emmett/routing/rules.py +++ b/emmett/routing/rules.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- """ - emmett.routing.rules - -------------------- +emmett.routing.rules +-------------------- - Provides routing rules definition apis. +Provides routing rules definition apis. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations import os - from typing import Any, Callable from emmett_core.routing.rules import HTTPRoutingRule as _HTTPRoutingRule @@ -22,19 +21,26 @@ class HTTPRoutingRule(_HTTPRoutingRule): - __slots__ = [ - 'injectors', - 'template_folder', - 'template_path', - 'template' - ] + __slots__ = ["injectors", "template_folder", "template_path", "template"] current = current route_cls = HTTPRoute def __init__( - self, router, paths=None, name=None, template=None, pipeline=None, - injectors=None, schemes=None, hostname=None, methods=None, prefix=None, - template_folder=None, template_path=None, cache=None, output='auto' + self, + router, + paths=None, + name=None, + template=None, + pipeline=None, + injectors=None, + schemes=None, + hostname=None, + methods=None, + prefix=None, + template_folder=None, + template_path=None, + cache=None, + output="auto", ): super().__init__( router, @@ -46,19 +52,16 @@ def __init__( methods=methods, prefix=prefix, cache=cache, - output=output + output=output, ) self.template = template self.template_folder = template_folder self.template_path = template_path or self.app.template_path - self.pipeline = ( - self.pipeline + - self.router.injectors + (injectors or []) - ) + self.pipeline = self.pipeline + self.router.injectors + (injectors or []) def _make_builders(self, output_type): builder_cls = self.router._outputs[output_type] - return builder_cls(self), self.router._outputs['empty'](self) + return builder_cls(self), self.router._outputs["empty"](self) def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: if not self.template: diff --git a/emmett/routing/urls.py b/emmett/routing/urls.py index 151de7ea..f0604fa1 100644 --- a/emmett/routing/urls.py +++ b/emmett/routing/urls.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.routing.urls - ------------------- +emmett.routing.urls +------------------- - Provides url builder apis. +Provides url builder apis. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from emmett_core.routing.urls import Url diff --git a/emmett/rsgi/handlers.py b/emmett/rsgi/handlers.py index 0e6a147e..de7373ec 100644 --- a/emmett/rsgi/handlers.py +++ b/emmett/rsgi/handlers.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- """ - emmett.rsgi.handlers - -------------------- +emmett.rsgi.handlers +-------------------- - Provides RSGI handlers. +Provides RSGI handlers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations import os - from typing import Awaitable, Callable from emmett_core.http.response import HTTPResponse @@ -24,9 +23,8 @@ ) from ..ctx import current -from ..debug import smart_traceback, debug_handler +from ..debug import debug_handler, smart_traceback from ..wrappers.response import Response - from .wrappers import Request, Websocket @@ -37,25 +35,16 @@ class HTTPHandler(_HTTPHandler): @cachedprop def error_handler(self) -> Callable[[], Awaitable[str]]: - return ( - self._debug_handler if self.app.debug else self.exception_handler - ) + return self._debug_handler if self.app.debug else self.exception_handler - def _static_handler( - self, - scope: Scope, - protocol: HTTPProtocol, - path: str - ) -> Awaitable[HTTPResponse]: + def _static_handler(self, scope: Scope, protocol: HTTPProtocol, path: str) -> Awaitable[HTTPResponse]: #: handle internal assets - if path.startswith('/__emmett__'): + if path.startswith("/__emmett__"): file_name = path[12:] if not file_name: return self._http_response(404) - static_file = os.path.join( - os.path.dirname(__file__), '..', 'assets', file_name - ) - if os.path.splitext(static_file)[1] == 'html': + static_file = os.path.join(os.path.dirname(__file__), "..", "assets", file_name) + if os.path.splitext(static_file)[1] == "html": return self._http_response(404) return self._static_response(static_file) #: handle app assets @@ -65,9 +54,7 @@ def _static_handler( return self.dynamic_handler(scope, protocol, path) async def _debug_handler(self) -> str: - current.response.headers._data['content-type'] = ( - 'text/html; charset=utf-8' - ) + current.response.headers._data["content-type"] = "text/html; charset=utf-8" return debug_handler(smart_traceback(self.app)) diff --git a/emmett/rsgi/wrappers.py b/emmett/rsgi/wrappers.py index c1a7b57a..9996a295 100644 --- a/emmett/rsgi/wrappers.py +++ b/emmett/rsgi/wrappers.py @@ -1,16 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.rsgi.wrappers - -------------------- +emmett.rsgi.wrappers +-------------------- - Provides RSGI request and websocket wrappers +Provides RSGI request and websocket wrappers - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import pendulum - from emmett_core.protocols.rsgi.wrappers import Request as _Request, Websocket as Websocket from emmett_core.utils import cachedprop diff --git a/emmett/security.py b/emmett/security.py index 5d8abdee..01cb4d84 100644 --- a/emmett/security.py +++ b/emmett/security.py @@ -1,18 +1,17 @@ # -*- coding: utf-8 -*- """ - emmett.security - --------------- +emmett.security +--------------- - Miscellaneous security helpers. +Miscellaneous security helpers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import hashlib import hmac import time - from collections import OrderedDict from uuid import uuid4 @@ -38,11 +37,11 @@ def gen_token(self): def md5_hash(text): - """ Generate a md5 hash with the given text """ + """Generate a md5 hash with the given text""" return hashlib.md5(text).hexdigest() -def simple_hash(text, key='', salt='', digest_alg='md5'): +def simple_hash(text, key="", salt="", digest_alg="md5"): """ Generates hash with the given text using the specified digest hashing algorithm @@ -51,14 +50,10 @@ def simple_hash(text, key='', salt='', digest_alg='md5'): raise RuntimeError("simple_hash with digest_alg=None") elif not isinstance(digest_alg, str): # manual approach h = digest_alg(text + key + salt) - elif digest_alg.startswith('pbkdf2'): # latest and coolest! - iterations, keylen, alg = digest_alg[7:-1].split(',') + elif digest_alg.startswith("pbkdf2"): # latest and coolest! + iterations, keylen, alg = digest_alg[7:-1].split(",") return kdf.pbkdf2_hex( - text, - salt, - iterations=int(iterations), - keylen=int(keylen), - hash_algorithm=kdf.PBKDF2_HMAC[alg] + text, salt, iterations=int(iterations), keylen=int(keylen), hash_algorithm=kdf.PBKDF2_HMAC[alg] ) elif key: # use hmac digest_alg = get_digest(digest_alg) @@ -93,10 +88,10 @@ def get_digest(value): DIGEST_ALG_BY_SIZE = { - 128 / 4: 'md5', - 160 / 4: 'sha1', - 224 / 4: 'sha224', - 256 / 4: 'sha256', - 384 / 4: 'sha384', - 512 / 4: 'sha512', + 128 / 4: "md5", + 160 / 4: "sha1", + 224 / 4: "sha224", + 256 / 4: "sha256", + 384 / 4: "sha384", + 512 / 4: "sha512", } diff --git a/emmett/serializers.py b/emmett/serializers.py index bf5e55b6..e5a73bb0 100644 --- a/emmett/serializers.py +++ b/emmett/serializers.py @@ -1,37 +1,28 @@ # -*- coding: utf-8 -*- """ - emmett.serializers - ------------------ +emmett.serializers +------------------ - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from emmett_core.serializers import Serializers as Serializers -from .html import tag, htmlescape +from .html import htmlescape, tag def xml_encode(value, key=None, quote=True): - if hasattr(value, '__xml__'): + if hasattr(value, "__xml__"): return value.__xml__(key, quote) if isinstance(value, dict): - return tag[key]( - *[ - tag[k](xml_encode(v, None, quote)) - for k, v in value.items() - ]) + return tag[key](*[tag[k](xml_encode(v, None, quote)) for k, v in value.items()]) if isinstance(value, list): - return tag[key]( - *[ - tag[item](xml_encode(item, None, quote)) - for item in value - ]) + return tag[key](*[tag[item](xml_encode(item, None, quote)) for item in value]) return htmlescape(value) -@Serializers.register_for('xml') -def xml(value, encoding='UTF-8', key='document', quote=True): - rv = ('' % encoding) + \ - str(xml_encode(value, key, quote)) +@Serializers.register_for("xml") +def xml(value, encoding="UTF-8", key="document", quote=True): + rv = ('' % encoding) + str(xml_encode(value, key, quote)) return rv diff --git a/emmett/sessions.py b/emmett/sessions.py index 5f1606fc..7a62c697 100644 --- a/emmett/sessions.py +++ b/emmett/sessions.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.sessions - --------------- +emmett.sessions +--------------- - Provides session managers for applications. +Provides session managers for applications. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations diff --git a/emmett/templating/lexers.py b/emmett/templating/lexers.py index d7b5c5b6..801807f6 100644 --- a/emmett/templating/lexers.py +++ b/emmett/templating/lexers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.templating.lexers - ------------------------ +emmett.templating.lexers +------------------------ - Provides the Emmett lexers for Renoir engine. +Provides the Emmett lexers for Renoir engine. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from renoir import Lexer @@ -17,10 +17,8 @@ class HelpersLexer(Lexer): helpers = [ - '', - '' + '', + '', ] def process(self, ctx, value): @@ -30,17 +28,12 @@ def process(self, ctx, value): class MetaLexer(Lexer): def process(self, ctx, value): - ctx.python_node('for name, value in current.response._meta_tmpl():') - ctx.variable( - "'' % (name, value)", - escape=False) - ctx.python_node('pass') - ctx.python_node( - 'for name, value in current.response._meta_tmpl_prop():') - ctx.variable( - "'' % (name, value)", - escape=False) - ctx.python_node('pass') + ctx.python_node("for name, value in current.response._meta_tmpl():") + ctx.variable('\'\' % (name, value)', escape=False) + ctx.python_node("pass") + ctx.python_node("for name, value in current.response._meta_tmpl_prop():") + ctx.variable('\'\' % (name, value)', escape=False) + ctx.python_node("pass") class StaticLexer(Lexer): @@ -48,20 +41,16 @@ class StaticLexer(Lexer): def process(self, ctx, value): file_name = value.split("?")[0] - surl = url('static', file_name) + surl = url("static", file_name) file_ext = file_name.rsplit(".", 1)[-1] - if file_ext == 'js': - s = u'' % surl + if file_ext == "js": + s = '' % surl elif file_ext == "css": - s = u'' % surl + s = '' % surl else: s = None if s: ctx.html(s) -lexers = { - 'include_helpers': HelpersLexer(), - 'include_meta': MetaLexer(), - 'include_static': StaticLexer() -} +lexers = {"include_helpers": HelpersLexer(), "include_meta": MetaLexer(), "include_static": StaticLexer()} diff --git a/emmett/templating/templater.py b/emmett/templating/templater.py index cf088534..a4ce70ad 100644 --- a/emmett/templating/templater.py +++ b/emmett/templating/templater.py @@ -1,16 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.templating.templater - --------------------------- +emmett.templating.templater +--------------------------- - Provides the Emmett implementation for Renoir engine. +Provides the Emmett implementation for Renoir engine. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import os - from functools import reduce from typing import Optional, Tuple @@ -21,7 +20,7 @@ class Templater(Renoir): def __init__(self, **kwargs): - kwargs['lexers'] = lexers + kwargs["lexers"] = lexers super().__init__(**kwargs) self._namespaces = {} @@ -46,9 +45,7 @@ def register_namespace(self, namespace: str, path: Optional[str] = None): path = path or self.path self._namespaces[namespace] = path - def _get_namespace_path_elements( - self, file_name: str, path: Optional[str] - ) -> Tuple[str, str]: + def _get_namespace_path_elements(self, file_name: str, path: Optional[str]) -> Tuple[str, str]: if ":" in file_name: namespace, file_name = file_name.split(":") path = self._namespaces.get(namespace, self.path) @@ -60,9 +57,7 @@ def _preload(self, file_name: str, path: Optional[str] = None): path, file_name = self._get_namespace_path_elements(file_name, path) file_extension = os.path.splitext(file_name)[1] return reduce( - lambda args, loader: loader(args[0], args[1]), - self.loaders.get(file_extension, []), - (path, file_name) + lambda args, loader: loader(args[0], args[1]), self.loaders.get(file_extension, []), (path, file_name) ) def _no_preload(self, file_name: str, path: Optional[str] = None): diff --git a/emmett/testing.py b/emmett/testing.py index dac5ffb5..5fc9746c 100644 --- a/emmett/testing.py +++ b/emmett/testing.py @@ -1,4 +1,8 @@ -from emmett_core.protocols.rsgi.test_client.client import EmmettTestClient as _EmmettTestClient, ClientContext as _ClientContext, ClientHTTPHandlerMixin +from emmett_core.protocols.rsgi.test_client.client import ( + ClientContext as _ClientContext, + ClientHTTPHandlerMixin, + EmmettTestClient as _EmmettTestClient, +) from .ctx import current from .rsgi.handlers import HTTPHandler diff --git a/emmett/tools/__init__.py b/emmett/tools/__init__.py index eb311265..c4deab59 100644 --- a/emmett/tools/__init__.py +++ b/emmett/tools/__init__.py @@ -1,4 +1,4 @@ -from .service import ServicePipe from .auth import Auth -from .mailer import Mailer from .decorators import requires, service +from .mailer import Mailer +from .service import ServicePipe diff --git a/emmett/tools/auth/apis.py b/emmett/tools/auth/apis.py index 480c055f..66e61e99 100644 --- a/emmett/tools/auth/apis.py +++ b/emmett/tools/auth/apis.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.apis - ---------------------- +emmett.tools.auth.apis +---------------------- - Provides the interface for the auth system. +Provides the interface for the auth system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -14,52 +14,42 @@ from datetime import timedelta from typing import Any, Callable, Dict, List, Optional, Type, Union -from pydal.helpers.classes import Reference as _RecordReference from emmett_core.routing.cache import RouteCacheRule +from pydal.helpers.classes import Reference as _RecordReference from ...datastructures import sdict from ...locals import now, session -from ...orm.objects import Row from ...orm.models import Model -from ...pipeline import Pipe, Injector -from .ext import AuthExtension +from ...orm.objects import Row +from ...pipeline import Injector, Pipe from .exposer import AuthModule +from .ext import AuthExtension class Auth: - def __init__( - self, - app, - db, - user_model=None, - group_model=None, - membership_model=None, - permission_model=None - ): + def __init__(self, app, db, user_model=None, group_model=None, membership_model=None, permission_model=None): self.ext = app.use_extension(AuthExtension) self.ext.bind_auth(self) - self.ext.use_database( - db, user_model, group_model, membership_model, permission_model - ) + self.ext.use_database(db, user_model, group_model, membership_model, permission_model) self.ext.init_forms() self.pipe = AuthPipe(self) def module( self, import_name: str, - name: str = 'auth', - template_folder: str = 'auth', + name: str = "auth", + template_folder: str = "auth", template_path: Optional[str] = None, static_folder: Optional[str] = None, static_path: Optional[str] = None, - url_prefix: Optional[str] = 'auth', + url_prefix: Optional[str] = "auth", hostname: Optional[str] = None, cache: Optional[RouteCacheRule] = None, root_path: Optional[str] = None, pipeline: Optional[List[Pipe]] = None, injectors: Optional[List[Injector]] = None, module_class: Type[AuthModule] = AuthModule, - **kwargs: Any + **kwargs: Any, ) -> AuthModule: return module_class.from_app( self.ext.app, @@ -75,7 +65,7 @@ def module( root_path=root_path, pipeline=pipeline or [], injectors=injectors or [], - opts=kwargs + opts=kwargs, ) @property @@ -83,7 +73,7 @@ def models(self) -> Dict[str, Model]: return self.ext.config.models def group_for_role(self, role: str) -> Row: - return self.models['group'].get(role=role) + return self.models["group"].get(role=role) #: context @property @@ -106,7 +96,7 @@ def has_membership( self, group: Optional[Union[str, int, Row]] = None, user: Optional[Union[Row, int]] = None, - role: Optional[str] = None + role: Optional[str] = None, ) -> bool: rv = False if not group and role: @@ -116,25 +106,28 @@ def has_membership( if not user and self.user: user = self.user.id if group and user: - if self.models['membership'].where( - lambda m: - (m.table[self.ext.relation_names['user']] == user) & - (m.table[self.ext.relation_names['group']] == group) - ).count(): + if ( + self.models["membership"] + .where( + lambda m: (m.table[self.ext.relation_names["user"]] == user) + & (m.table[self.ext.relation_names["group"]] == group) + ) + .count() + ): rv = True return rv def has_permission( self, - name: str = 'any', + name: str = "any", table_name: Optional[str] = None, record_id: Optional[int] = None, user: Optional[Union[int, Row]] = None, - group: Optional[Union[str, int, Row]] = None + group: Optional[Union[str, int, Row]] = None, ) -> bool: - permission = self.models['permission'] + permission = self.models["permission"] parent = None - query = (permission.name == name) + query = permission.name == name if table_name: query = query & (permission.table_name == table_name) if record_id: @@ -144,98 +137,72 @@ def has_permission( if not user and not group: return False if user is not None: - parent = self.models['user'].get(id=user) + parent = self.models["user"].get(id=user) elif group is not None: if isinstance(group, str): group = self.group_for_role(group) - parent = self.models['group'].get(id=group) + parent = self.models["group"].get(id=group) if not parent: return False - return ( - parent[self.ext.relation_names['permission'] + 's'].where( - query - ).count() > 0 - ) + return parent[self.ext.relation_names["permission"] + "s"].where(query).count() > 0 #: operations - def create_group(self, role: str, description: str = '') -> _RecordReference: - res = self.models['group'].create( - role=role, description=description - ) + def create_group(self, role: str, description: str = "") -> _RecordReference: + res = self.models["group"].create(role=role, description=description) return res.id def delete_group(self, group: Union[str, int, Row]) -> int: if isinstance(group, str): group = self.group_for_role(group) - return self.ext.db(self.models['group'].id == group).delete() + return self.ext.db(self.models["group"].id == group).delete() - def add_membership( - self, - group: Union[str, int, Row], - user: Optional[Row] = None - ) -> _RecordReference: + def add_membership(self, group: Union[str, int, Row], user: Optional[Row] = None) -> _RecordReference: if isinstance(group, int): - group = self.models['group'].get(group) + group = self.models["group"].get(group) elif isinstance(group, str): group = self.group_for_role(group) if not user and self.user: user = self.user.id - res = getattr( - group, self.ext.relation_names['user'] + 's' - ).add(user) + res = getattr(group, self.ext.relation_names["user"] + "s").add(user) return res.id - def remove_membership( - self, - group: Union[str, int, Row], - user: Optional[Row] = None - ): + def remove_membership(self, group: Union[str, int, Row], user: Optional[Row] = None): if isinstance(group, int): - group = self.models['group'].get(group) + group = self.models["group"].get(group) elif isinstance(group, str): group = self.group_for_role(group) if not user and self.user: user = self.user.id - return getattr( - group, self.ext.relation_names['user'] + 's' - ).remove(user) + return getattr(group, self.ext.relation_names["user"] + "s").remove(user) def login(self, email: str, password: str): - user = self.models['user'].get(email=email) - if user and user.get('password', False): - password = self.models['user'].password.validate(password)[0] + user = self.models["user"].get(email=email) + if user and user.get("password", False): + password = self.models["user"].password.validate(password)[0] if not user.registration_key and password == user.password: self.ext.login_user(user) return user return None def change_user_status(self, user: Union[int, Row], status: str) -> int: - return self.ext.db(self.models['user'].id == user).update( - registration_key=status - ) + return self.ext.db(self.models["user"].id == user).update(registration_key=status) def disable_user(self, user: Union[int, Row]) -> int: - return self.change_user_status(user, 'disabled') + return self.change_user_status(user, "disabled") def block_user(self, user: Union[int, Row]) -> int: - return self.change_user_status(user, 'blocked') + return self.change_user_status(user, "blocked") def allow_user(self, user: Union[int, Row]) -> int: - return self.change_user_status(user, '') + return self.change_user_status(user, "") #: emails decorators - def registration_mail( - self, - f: Callable[[Row, Dict[str, Any]], bool] - ) -> Callable[[Row, Dict[str, Any]], bool]: - self.ext.mails['registration'] = f + def registration_mail(self, f: Callable[[Row, Dict[str, Any]], bool]) -> Callable[[Row, Dict[str, Any]], bool]: + self.ext.mails["registration"] = f return f - def reset_password_mail( - self, - f: Callable[[Row, Dict[str, Any]], bool] - ) -> Callable[[Row, Dict[str, Any]], bool]: - self.ext.mails['reset_password'] = f + def reset_password_mail(self, f: Callable[[Row, Dict[str, Any]], bool]) -> Callable[[Row, Dict[str, Any]], bool]: + self.ext.mails["reset_password"] = f return f @@ -255,26 +222,20 @@ def session_open(self): return #: is session expired? visit_dt = now().as_naive_datetime() - if ( - authsess.last_visit + timedelta(seconds=authsess.expiration) < - visit_dt - ): + if authsess.last_visit + timedelta(seconds=authsess.expiration) < visit_dt: del session.auth #: does session need re-sync with db? elif authsess.last_dbcheck + timedelta(seconds=300) < visit_dt: if self.auth.user: #: is user still valid? - dbrow = self.auth.models['user'].get(self.auth.user.id) + dbrow = self.auth.models["user"].get(self.auth.user.id) if dbrow and not dbrow.registration_key: self.auth.ext.login_user(dbrow, authsess.remember) else: del session.auth else: #: set last_visit if make sense - if ( - (visit_dt - authsess.last_visit).seconds > - min(authsess.expiration / 10, 600) - ): + if (visit_dt - authsess.last_visit).seconds > min(authsess.expiration / 10, 600): authsess.last_visit = visit_dt def session_close(self): diff --git a/emmett/tools/auth/exposer.py b/emmett/tools/auth/exposer.py index 9206bc51..54d73fbf 100644 --- a/emmett/tools/auth/exposer.py +++ b/emmett/tools/auth/exposer.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.exposer - ------------------------- +emmett.tools.auth.exposer +------------------------- - Provides the routes layer for the auth system. +Provides the routes layer for the auth system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -39,7 +39,7 @@ def __init__( root_path: Optional[str] = None, pipeline: Optional[List[Pipe]] = None, injectors: Optional[List[Injector]] = None, - **kwargs: Any + **kwargs: Any, ): super().__init__( app=app, @@ -55,7 +55,7 @@ def __init__( root_path=root_path, pipeline=pipeline, injectors=injectors, - **kwargs + **kwargs, ) self.init() @@ -64,30 +64,27 @@ def init(self): self.auth = self.ext.auth self.config = self.ext.config self._callbacks = { - 'after_login': self._after_login, - 'after_logout': self._after_logout, - 'after_registration': self._after_registration, - 'after_profile': self._after_profile, - 'after_email_verification': self._after_email_verification, - 'after_password_retrieval': self._after_password_retrieval, - 'after_password_reset': self._after_password_reset, - 'after_password_change': self._after_password_change + "after_login": self._after_login, + "after_logout": self._after_logout, + "after_registration": self._after_registration, + "after_profile": self._after_profile, + "after_email_verification": self._after_email_verification, + "after_password_retrieval": self._after_password_retrieval, + "after_password_reset": self._after_password_reset, + "after_password_change": self._after_password_change, } auth_pipe = [] if not self.config.inject_pipe else [self.auth.pipe] - requires_login = [ - RequirePipe( - lambda: self.auth.is_logged(), - lambda: redirect(self.url('login')))] + requires_login = [RequirePipe(lambda: self.auth.is_logged(), lambda: redirect(self.url("login")))] self._methods_pipelines = { - 'login': [], - 'logout': auth_pipe + requires_login, - 'registration': [], - 'profile': auth_pipe + requires_login, - 'email_verification': [], - 'password_retrieval': [], - 'password_reset': [], - 'password_change': auth_pipe + requires_login, - 'download': [] + "login": [], + "logout": auth_pipe + requires_login, + "registration": [], + "profile": auth_pipe + requires_login, + "email_verification": [], + "password_retrieval": [], + "password_reset": [], + "password_change": auth_pipe + requires_login, + "download": [], # 'not_authorized': [] } self.enabled_routes = list(self.config.enabled_routes) @@ -106,63 +103,57 @@ def init(self): self.ext.bind_exposer(self) def _template_for(self, key): - tname = 'auth' if self.config.single_template else key - return '{}{}'.format(tname, self.app.template_default_extension) + tname = "auth" if self.config.single_template else key + return "{}{}".format(tname, self.app.template_default_extension) def url(self, path, *args, **kwargs): path = "{}.{}".format(self.name, path) return url(path, *args, **kwargs) def _flash(self, message): - return flash(message, 'auth') + return flash(message, "auth") #: routes async def _login(self): def _validate_form(form): - row = self.config.models['user'].get(email=form.params.email) + row = self.config.models["user"].get(email=form.params.email) if row: #: verify password if form.params.password == row.password: - res['user'] = row + res["user"] = row return - form.errors.email = self.config.messages['login_invalid'] + form.errors.email = self.config.messages["login_invalid"] - rv = {'message': None} + rv = {"message": None} res = {} - rv['form'] = await self.ext.forms.login(onvalidation=_validate_form) - if rv['form'].accepted: + rv["form"] = await self.ext.forms.login(onvalidation=_validate_form) + if rv["form"].accepted: messages = self.config.messages - if res['user'].registration_key == 'pending': - rv['message'] = messages['approval_pending'] - elif res['user'].registration_key in ('disabled', 'blocked'): - rv['message'] = messages['login_disabled'] - elif ( - res['user'].registration_key is not None and - res['user'].registration_key.strip() - ): - rv['message'] = messages['verification_pending'] - if rv['message']: - self.flash(rv['message']) + if res["user"].registration_key == "pending": + rv["message"] = messages["approval_pending"] + elif res["user"].registration_key in ("disabled", "blocked"): + rv["message"] = messages["login_disabled"] + elif res["user"].registration_key is not None and res["user"].registration_key.strip(): + rv["message"] = messages["verification_pending"] + if rv["message"]: + self.flash(rv["message"]) else: - self.ext.login_user( - res['user'], rv['form'].params.get('remember', False)) - self.ext.log_event( - self.config.messages['login_log'], {'id': res['user'].id}) + self.ext.login_user(res["user"], rv["form"].params.get("remember", False)) + self.ext.log_event(self.config.messages["login_log"], {"id": res["user"].id}) redirect_after = (await request.body_params)._after if redirect_after: redirect(redirect_after) - self._callbacks['after_login'](rv['form']) + self._callbacks["after_login"](rv["form"]) return rv async def _logout(self): - self.ext.log_event( - self.config.messages['logout_log'], {'id': self.auth.user.id}) + self.ext.log_event(self.config.messages["logout_log"], {"id": self.auth.user.id}) session.auth = None - self.flash(self.config.messages['logged_out']) + self.flash(self.config.messages["logged_out"]) redirect_after = request.query_params._after if redirect_after: redirect(redirect_after) - self._callbacks['after_logout']() + self._callbacks["after_logout"]() async def _registration(self): def _validate_form(form): @@ -171,122 +162,105 @@ def _validate_form(form): form.errors.password2 = "password mismatch" return del form.params.password2 - res['id'] = self.config.models['user'].table.insert( - **form.params) + res["id"] = self.config.models["user"].table.insert(**form.params) - rv = {'message': None} + rv = {"message": None} res = {} - rv['form'] = await self.ext.forms.registration( - onvalidation=_validate_form) - if rv['form'].accepted: + rv["form"] = await self.ext.forms.registration(onvalidation=_validate_form) + if rv["form"].accepted: logged_in = False - row = self.config.models['user'].get(res['id']) + row = self.config.models["user"].get(res["id"]) if self.config.registration_verification: - email_data = { - 'link': self.url( - 'email_verification', row.registration_key, - scheme=True)} - if not self.ext.mails['registration'](row, email_data): - rv['message'] = self.config.messages['mail_failure'] + email_data = {"link": self.url("email_verification", row.registration_key, scheme=True)} + if not self.ext.mails["registration"](row, email_data): + rv["message"] = self.config.messages["mail_failure"] self.ext.db.rollback() - self.flash(rv['message']) + self.flash(rv["message"]) return rv - rv['message'] = self.config.messages['mail_success'] - self.flash(rv['message']) + rv["message"] = self.config.messages["mail_success"] + self.flash(rv["message"]) elif self.config.registration_approval: - rv['message'] = self.config.messages['approval_pending'] - self.flash(rv['message']) + rv["message"] = self.config.messages["approval_pending"] + self.flash(rv["message"]) else: - rv['message'] = self.config.messages['registration_success'] - self.flash(rv['message']) + rv["message"] = self.config.messages["registration_success"] + self.flash(rv["message"]) self.ext.login_user(row) logged_in = True - self.ext.log_event( - self.config.messages['registration_log'], - {'id': res['id']}) + self.ext.log_event(self.config.messages["registration_log"], {"id": res["id"]}) redirect_after = (await request.body_params)._after if redirect_after: redirect(redirect_after) - self._callbacks['after_registration'](rv['form'], row, logged_in) + self._callbacks["after_registration"](rv["form"], row, logged_in) return rv async def _profile(self): - rv = {'message': None, 'form': await self.ext.forms.profile()} - if rv['form'].accepted: - self.auth.user.update( - self.config.models['user'].table._filter_fields( - rv['form'].params)) - rv['message'] = self.config.messages['profile_updated'] - self.flash(rv['message']) - self.ext.log_event( - self.config.messages['profile_log'], {'id': self.auth.user.id}) + rv = {"message": None, "form": await self.ext.forms.profile()} + if rv["form"].accepted: + self.auth.user.update(self.config.models["user"].table._filter_fields(rv["form"].params)) + rv["message"] = self.config.messages["profile_updated"] + self.flash(rv["message"]) + self.ext.log_event(self.config.messages["profile_log"], {"id": self.auth.user.id}) redirect_after = (await request.body_params)._after if redirect_after: redirect(redirect_after) - self._callbacks['after_profile'](rv['form']) + self._callbacks["after_profile"](rv["form"]) return rv async def _email_verification(self, key): - rv = {'message': None} - user = self.config.models['user'].get(registration_key=key) + rv = {"message": None} + user = self.config.models["user"].get(registration_key=key) if not user: - redirect(self.url('login')) + redirect(self.url("login")) if self.config.registration_approval: - user.update_record(registration_key='pending') - rv['message'] = self.config.messages['approval_pending'] - self.flash(rv['message']) + user.update_record(registration_key="pending") + rv["message"] = self.config.messages["approval_pending"] + self.flash(rv["message"]) else: - user.update_record(registration_key='') - rv['message'] = self.config.messages['verification_success'] - self.flash(rv['message']) + user.update_record(registration_key="") + rv["message"] = self.config.messages["verification_success"] + self.flash(rv["message"]) #: make sure session has same user.registration_key as db record if self.auth.user: self.auth.user.registration_key = user.registration_key - self.ext.log_event( - self.config.messages['email_verification_log'], {'id': user.id}) + self.ext.log_event(self.config.messages["email_verification_log"], {"id": user.id}) redirect_after = request.query_params._after if redirect_after: redirect(redirect_after) - self._callbacks['after_email_verification'](user) + self._callbacks["after_email_verification"](user) return rv async def _password_retrieval(self): def _validate_form(form): messages = self.config.messages - row = self.config.models['user'].get(email=form.params.email) + row = self.config.models["user"].get(email=form.params.email) if not row: form.errors.email = "invalid email" return - if row.registration_key == 'pending': - form.errors.email = messages['approval_pending'] + if row.registration_key == "pending": + form.errors.email = messages["approval_pending"] return - if row.registration_key == 'blocked': - form.errors.email = messages['login_disabled'] + if row.registration_key == "blocked": + form.errors.email = messages["login_disabled"] return - res['user'] = row + res["user"] = row - rv = {'message': None} + rv = {"message": None} res = {} - rv['form'] = await self.ext.forms.password_retrieval( - onvalidation=_validate_form) - if rv['form'].accepted: - user = res['user'] + rv["form"] = await self.ext.forms.password_retrieval(onvalidation=_validate_form) + if rv["form"].accepted: + user = res["user"] reset_key = self.ext.generate_reset_key(user) - email_data = { - 'link': self.url( - 'password_reset', reset_key, scheme=True)} - if not self.ext.mails['reset_password'](user, email_data): - rv['message'] = self.config.messages['mail_failure'] - rv['message'] = self.config.messages['mail_success'] - self.flash(rv['message']) - self.ext.log_event( - self.config.messages['password_retrieval_log'], - {'id': user.id}, - user=user) + email_data = {"link": self.url("password_reset", reset_key, scheme=True)} + if not self.ext.mails["reset_password"](user, email_data): + rv["message"] = self.config.messages["mail_failure"] + rv["message"] = self.config.messages["mail_success"] + self.flash(rv["message"]) + self.ext.log_event(self.config.messages["password_retrieval_log"], {"id": user.id}, user=user) redirect_after = (await request.body_params)._after if redirect_after: redirect(redirect_after) - self._callbacks['after_password_retrieval'](user) + self._callbacks["after_password_retrieval"](user) return rv async def _password_reset(self, key): @@ -295,63 +269,49 @@ def _validate_form(form): form.errors.password = "password mismatch" form.errors.password2 = "password mismatch" - rv = {'message': None} + rv = {"message": None} redirect_after = request.query_params._after user = self.ext.get_user_by_reset_key(key) if not user: - rv['message'] = self.config.messages['reset_key_invalid'] - self.flash(rv['message']) + rv["message"] = self.config.messages["reset_key_invalid"] + self.flash(rv["message"]) if redirect_after: redirect(redirect_after) - self._callbacks['after_password_reset'](user) + self._callbacks["after_password_reset"](user) return rv - rv['form'] = await self.ext.forms.password_reset( - onvalidation=_validate_form) - if rv['form'].accepted: - user.update_record( - password=str(rv['form'].params.password), - registration_key='', - reset_password_key='' - ) - rv['message'] = self.config.messages['password_changed'] - self.flash(rv['message']) - self.ext.log_event( - self.config.messages['password_reset_log'], - {'id': user.id}, - user=user) + rv["form"] = await self.ext.forms.password_reset(onvalidation=_validate_form) + if rv["form"].accepted: + user.update_record(password=str(rv["form"].params.password), registration_key="", reset_password_key="") + rv["message"] = self.config.messages["password_changed"] + self.flash(rv["message"]) + self.ext.log_event(self.config.messages["password_reset_log"], {"id": user.id}, user=user) if redirect_after: redirect(redirect_after) - self._callbacks['after_password_reset'](user) + self._callbacks["after_password_reset"](user) return rv async def _password_change(self): def _validate_form(form): messages = self.config.messages if form.params.old_password != row.password: - form.errors.old_password = messages['invalid_password'] + form.errors.old_password = messages["invalid_password"] return - if ( - form.params.new_password.password != - form.params.new_password2.password - ): + if form.params.new_password.password != form.params.new_password2.password: form.errors.new_password = "password mismatch" form.errors.new_password2 = "password mismatch" - rv = {'message': None} - row = self.config.models['user'].get(self.auth.user.id) - rv['form'] = await self.ext.forms.password_change( - onvalidation=_validate_form) - if rv['form'].accepted: - row.update_record(password=str(rv['form'].params.new_password)) - rv['message'] = self.config.messages['password_changed'] - self.flash(rv['message']) - self.ext.log_event( - self.config.messages['password_change_log'], - {'id': row.id}) + rv = {"message": None} + row = self.config.models["user"].get(self.auth.user.id) + rv["form"] = await self.ext.forms.password_change(onvalidation=_validate_form) + if rv["form"].accepted: + row.update_record(password=str(rv["form"].params.new_password)) + rv["message"] = self.config.messages["password_changed"] + self.flash(rv["message"]) + self.ext.log_event(self.config.messages["password_change_log"], {"id": row.id}) redirect_after = request.query_params._after if redirect_after: redirect(redirect_after) - self._callbacks['after_password_change']() + self._callbacks["after_password_change"]() return rv def _download(self, file_name): @@ -359,70 +319,96 @@ def _download(self, file_name): #: routes decorators def login(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['login'] + (pipeline or []) + pipeline = self._methods_pipelines["login"] + (pipeline or []) return self.route( - self.config['routes_paths']['login'], name='login', - template=template or self._template_for('login'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["login"], + name="login", + template=template or self._template_for("login"), + pipeline=pipeline, + injectors=injectors or [], + ) def logout(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['logout'] + (pipeline or []) + pipeline = self._methods_pipelines["logout"] + (pipeline or []) return self.route( - self.config['routes_paths']['logout'], name='logout', - template=template or self._template_for('logout'), - pipeline=pipeline, injectors=injectors or [], methods='get') + self.config["routes_paths"]["logout"], + name="logout", + template=template or self._template_for("logout"), + pipeline=pipeline, + injectors=injectors or [], + methods="get", + ) def registration(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['registration'] + (pipeline or []) + pipeline = self._methods_pipelines["registration"] + (pipeline or []) return self.route( - self.config['routes_paths']['registration'], name='registration', - template=template or self._template_for('registration'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["registration"], + name="registration", + template=template or self._template_for("registration"), + pipeline=pipeline, + injectors=injectors or [], + ) def profile(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['profile'] + (pipeline or []) + pipeline = self._methods_pipelines["profile"] + (pipeline or []) return self.route( - self.config['routes_paths']['profile'], name='profile', - template=template or self._template_for('profile'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["profile"], + name="profile", + template=template or self._template_for("profile"), + pipeline=pipeline, + injectors=injectors or [], + ) def email_verification(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['email_verification'] + (pipeline or []) + pipeline = self._methods_pipelines["email_verification"] + (pipeline or []) return self.route( - self.config['routes_paths']['email_verification'], - name='email_verification', - template=template or self._template_for('email_verification'), - pipeline=pipeline, injectors=injectors or [], methods='get') + self.config["routes_paths"]["email_verification"], + name="email_verification", + template=template or self._template_for("email_verification"), + pipeline=pipeline, + injectors=injectors or [], + methods="get", + ) def password_retrieval(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['password_retrieval'] + (pipeline or []) + pipeline = self._methods_pipelines["password_retrieval"] + (pipeline or []) return self.route( - self.config['routes_paths']['password_retrieval'], - name='password_retrieval', - template=template or self._template_for('password_retrieval'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["password_retrieval"], + name="password_retrieval", + template=template or self._template_for("password_retrieval"), + pipeline=pipeline, + injectors=injectors or [], + ) def password_reset(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['password_reset'] + (pipeline or []) + pipeline = self._methods_pipelines["password_reset"] + (pipeline or []) return self.route( - self.config['routes_paths']['password_reset'], - name='password_reset', - template=template or self._template_for('password_reset'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["password_reset"], + name="password_reset", + template=template or self._template_for("password_reset"), + pipeline=pipeline, + injectors=injectors or [], + ) def password_change(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['password_change'] + (pipeline or []) + pipeline = self._methods_pipelines["password_change"] + (pipeline or []) return self.route( - self.config['routes_paths']['password_change'], - name='password_change', - template=template or self._template_for('password_change'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["password_change"], + name="password_change", + template=template or self._template_for("password_change"), + pipeline=pipeline, + injectors=injectors or [], + ) def download(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['download'] + (pipeline or []) + pipeline = self._methods_pipelines["download"] + (pipeline or []) return self.route( - self.config['routes_paths']['download'], name='download', - pipeline=pipeline, injectors=injectors or [], methods='get') + self.config["routes_paths"]["download"], + name="download", + pipeline=pipeline, + injectors=injectors or [], + methods="get", + ) #: callbacks def _after_login(self, form): @@ -453,33 +439,33 @@ def _after_password_change(self): #: callbacks decorators def after_login(self, f): - self._callbacks['after_login'] = f + self._callbacks["after_login"] = f return f def after_logout(self, f): - self._callbacks['after_logout'] = f + self._callbacks["after_logout"] = f return f def after_registration(self, f): - self._callbacks['after_registration'] = f + self._callbacks["after_registration"] = f return f def after_profile(self, f): - self._callbacks['after_profile'] = f + self._callbacks["after_profile"] = f return f def after_email_verification(self, f): - self._callbacks['after_email_verification'] = f + self._callbacks["after_email_verification"] = f return f def after_password_retrieval(self, f): - self._callbacks['after_password_retrieval'] = f + self._callbacks["after_password_retrieval"] = f return f def after_password_reset(self, f): - self._callbacks['after_password_reset'] = f + self._callbacks["after_password_reset"] = f return f def after_password_change(self, f): - self._callbacks['after_password_change'] = f + self._callbacks["after_password_change"] = f return f diff --git a/emmett/tools/auth/ext.py b/emmett/tools/auth/ext.py index 5c68961b..fd9dcd36 100644 --- a/emmett/tools/auth/ext.py +++ b/emmett/tools/auth/ext.py @@ -1,16 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.ext - --------------------- +emmett.tools.auth.ext +--------------------- - Provides the main auth layer. +Provides the main auth layer. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import time - from functools import wraps import click @@ -23,90 +22,93 @@ from ...locals import T, now, session from ...orm.helpers import decamelize from .forms import AuthForms -from .models import ( - AuthModel, AuthUser, AuthGroup, AuthMembership, AuthPermission, AuthEvent -) +from .models import AuthEvent, AuthGroup, AuthMembership, AuthModel, AuthPermission, AuthUser class AuthExtension(Extension): - namespace = 'auth' + namespace = "auth" default_config = { - 'models': { - 'user': AuthUser, - 'group': AuthGroup, - 'membership': AuthMembership, - 'permission': AuthPermission, - 'event': AuthEvent + "models": { + "user": AuthUser, + "group": AuthGroup, + "membership": AuthMembership, + "permission": AuthPermission, + "event": AuthEvent, }, - 'hmac_key': None, - 'hmac_alg': 'pbkdf2(2000,20,sha512)', - 'inject_pipe': False, - 'log_events': True, - 'flash_messages': True, - 'csrf': True, - 'enabled_routes': [ - 'login', 'logout', 'registration', 'profile', 'email_verification', - 'password_retrieval', 'password_reset', 'password_change', - 'download'], - 'disabled_routes': [], - 'routes_paths': { - 'login': '/login', - 'logout': '/logout', - 'registration': '/registration', - 'profile': '/profile', - 'email_verification': '/email_verification/', - 'password_retrieval': '/password_retrieval', - 'password_reset': '/password_reset/', - 'password_change': '/password_change', - 'download': '/download/' + "hmac_key": None, + "hmac_alg": "pbkdf2(2000,20,sha512)", + "inject_pipe": False, + "log_events": True, + "flash_messages": True, + "csrf": True, + "enabled_routes": [ + "login", + "logout", + "registration", + "profile", + "email_verification", + "password_retrieval", + "password_reset", + "password_change", + "download", + ], + "disabled_routes": [], + "routes_paths": { + "login": "/login", + "logout": "/logout", + "registration": "/registration", + "profile": "/profile", + "email_verification": "/email_verification/", + "password_retrieval": "/password_retrieval", + "password_reset": "/password_reset/", + "password_change": "/password_change", + "download": "/download/", }, - 'single_template': False, - 'password_min_length': 6, - 'remember_option': True, - 'session_expiration': 3600, - 'session_long_expiration': 3600 * 24 * 30, - 'registration_verification': True, - 'registration_approval': False + "single_template": False, + "password_min_length": 6, + "remember_option": True, + "session_expiration": 3600, + "session_long_expiration": 3600 * 24 * 30, + "registration_verification": True, + "registration_approval": False, } default_messages = { - 'approval_pending': 'Registration is pending approval', - 'verification_pending': 'Registration needs verification', - 'login_disabled': 'The account is locked', - 'login_invalid': 'Invalid credentials', - 'logged_out': 'Logged out successfully', - 'registration_success': 'Registration completed', - 'profile_updated': 'Profile updated successfully', - 'verification_success': 'Account verification completed', - 'password_changed': 'Password changed successfully', - 'mail_failure': 'Something went wrong with the email, try again later', - 'mail_success': 'We sent you an email, check your inbox', - 'reset_key_invalid': 'The reset link was invalid or expired', - 'login_button': 'Sign in', - 'registration_button': 'Register', - 'profile_button': 'Save', - 'remember_button': 'Remember me', - 'password_retrieval_button': 'Retrieve password', - 'password_reset_button': 'Reset password', - 'password_change_button': 'Change password', - 'login_log': 'User {id} logged in', - 'logout_log': 'User {id} logged out', - 'registration_log': 'User {id} registered', - 'profile_log': 'User {id} updated profile', - 'email_verification_log': 'Verification email sent to user {id}', - 'password_retrieval_log': 'User {id} asked for password retrieval', - 'password_reset_log': 'User {id} reset the password', - 'password_change_log': 'User {id} changed the password', - 'old_password': 'Current password', - 'new_password': 'New password', - 'verify_password': 'Confirm password', - 'registration_email_subject': 'Email verification', - 'registration_email_text': - 'Hello {email}! Click on the link {link} to verify your email', - 'reset_password_email_subject': 'Password reset requested', - 'reset_password_email_text': - 'A password reset was requested for your account, ' - 'click on the link {link} to proceed' + "approval_pending": "Registration is pending approval", + "verification_pending": "Registration needs verification", + "login_disabled": "The account is locked", + "login_invalid": "Invalid credentials", + "logged_out": "Logged out successfully", + "registration_success": "Registration completed", + "profile_updated": "Profile updated successfully", + "verification_success": "Account verification completed", + "password_changed": "Password changed successfully", + "mail_failure": "Something went wrong with the email, try again later", + "mail_success": "We sent you an email, check your inbox", + "reset_key_invalid": "The reset link was invalid or expired", + "login_button": "Sign in", + "registration_button": "Register", + "profile_button": "Save", + "remember_button": "Remember me", + "password_retrieval_button": "Retrieve password", + "password_reset_button": "Reset password", + "password_change_button": "Change password", + "login_log": "User {id} logged in", + "logout_log": "User {id} logged out", + "registration_log": "User {id} registered", + "profile_log": "User {id} updated profile", + "email_verification_log": "Verification email sent to user {id}", + "password_retrieval_log": "User {id} asked for password retrieval", + "password_reset_log": "User {id} reset the password", + "password_change_log": "User {id} changed the password", + "old_password": "Current password", + "new_password": "New password", + "verify_password": "Confirm password", + "registration_email_subject": "Email verification", + "registration_email_text": "Hello {email}! Click on the link {link} to verify your email", + "reset_password_email_subject": "Password reset requested", + "reset_password_email_text": "A password reset was requested for your account, " + "click on the link {link} to proceed", } def __init__(self, app, env, config): @@ -117,45 +119,38 @@ def __init__(self, app, env, config): AuthModel.auth = self def __init_messages(self): - self.config.messages = self.config.get('messages', sdict()) + self.config.messages = self.config.get("messages", sdict()) for key, dval in self.default_messages.items(): self.config.messages[key] = T(self.config.messages.get(key, dval)) def __init_mails(self): - self.mails = { - 'registration': self._registration_email, - 'reset_password': self._reset_password_email - } + self.mails = {"registration": self._registration_email, "reset_password": self._reset_password_email} def __register_commands(self): - @self.app.cli.group('auth', short_help='Auth commands') + @self.app.cli.group("auth", short_help="Auth commands") def cli_group(): pass - @cli_group.command('generate_key', short_help='Generate an auth key') + @cli_group.command("generate_key", short_help="Generate an auth key") @pass_script_info def generate_key(info): click.echo(uuid()) def __ensure_config(self): - for key in ( - set(self.default_config['routes_paths'].keys()) - - set(self.config['routes_paths'].keys()) - ): - self.config['routes_paths'][key] = \ - self.default_config['routes_paths'][key] + for key in set(self.default_config["routes_paths"].keys()) - set(self.config["routes_paths"].keys()): + self.config["routes_paths"][key] = self.default_config["routes_paths"][key] def __get_relnames(self): rv = {} def_names = { - 'user': 'user', - 'group': 'auth_group', - 'membership': 'auth_membership', - 'permission': 'auth_permission', - 'event': 'auth_event' + "user": "user", + "group": "auth_group", + "membership": "auth_membership", + "permission": "auth_permission", + "event": "auth_event", } - for m in ['user', 'group', 'membership', 'permission', 'event']: - if self.config.models[m] == self.default_config['models'][m]: + for m in ["user", "group", "membership", "permission", "event"]: + if self.config.models[m] == self.default_config["models"][m]: rv[m] = def_names[m] else: rv[m] = decamelize(self.config.models[m].__name__) @@ -168,13 +163,12 @@ def on_load(self): "An auto-generated 'hmac_key' was added to the auth " "configuration.\nPlase add your own key to the configuration. " "You can generate a key using the auth command.\n" - "> emmett -a {your_app_name} auth generate_key") + "> emmett -a {your_app_name} auth generate_key" + ) self.config.hmac_key = uuid() - self._hmac_key = self.config.hmac_alg + ':' + self.config.hmac_key - if 'MailExtension' not in self.app.ext: - self.app.log.warn( - "No mailer seems to be configured. The auth features " - "requiring mailer won't work.") + self._hmac_key = self.config.hmac_alg + ":" + self.config.hmac_key + if "MailExtension" not in self.app.ext: + self.app.log.warn("No mailer seems to be configured. The auth features " "requiring mailer won't work.") self.__ensure_config() self.relation_names = self.__get_relnames() @@ -193,23 +187,16 @@ def __set_model_for_key(self, key, model): if not model: return _model_bases = { - 'user': AuthModel, - 'group': AuthGroup, - 'membership': AuthMembership, - 'permission': AuthPermission + "user": AuthModel, + "group": AuthGroup, + "membership": AuthMembership, + "permission": AuthPermission, } if not issubclass(model, _model_bases[key]): - raise RuntimeError(f'{model.__name__} is an invalid {key} auth model') + raise RuntimeError(f"{model.__name__} is an invalid {key} auth model") self.config.models[key] = model - def use_database( - self, - db, - user_model=None, - group_model=None, - membership_model=None, - permission_model=None - ): + def use_database(self, db, user_model=None, group_model=None, membership_model=None, permission_model=None): self.db = db self.__set_model_for_key("user", user_model) self.__set_model_for_key("group", group_model) @@ -218,13 +205,11 @@ def use_database( self.define_models() def __set_models_labels(self): - for model in self.default_config['models'].values(): + for model in self.default_config["models"].values(): for supmodel in list(reversed(model.__mro__))[1:]: - if not supmodel.__module__.startswith( - 'emmett.tools.auth.models' - ): + if not supmodel.__module__.startswith("emmett.tools.auth.models"): continue - if not hasattr(supmodel, 'form_labels'): + if not hasattr(supmodel, "form_labels"): continue current_labels = {} for key, val in supmodel.form_labels.items(): @@ -236,66 +221,58 @@ def define_models(self): names = self.relation_names models = self.config.models #: AuthUser - user_model = models['user'] + user_model = models["user"] many_refs = [ - {names['membership'] + 's': models['membership'].__name__}, - {names['event'] + 's': models['event'].__name__}, - {names['group'] + 's': {'via': names['membership'] + 's'}}, - {names['permission'] + 's': {'via': names['group'] + 's'}} + {names["membership"] + "s": models["membership"].__name__}, + {names["event"] + "s": models["event"].__name__}, + {names["group"] + "s": {"via": names["membership"] + "s"}}, + {names["permission"] + "s": {"via": names["group"] + "s"}}, ] - if getattr(user_model, '_auto_relations', True): + if getattr(user_model, "_auto_relations", True): for el in many_refs: key = list(el)[0] user_model._all_hasmany_ref_[key] = el - if user_model.validation.get('password') is None: - user_model.validation['password'] = { - 'len': {'gte': self.config.password_min_length}, - 'crypt': {'key': self._hmac_key} + if user_model.validation.get("password") is None: + user_model.validation["password"] = { + "len": {"gte": self.config.password_min_length}, + "crypt": {"key": self._hmac_key}, } #: AuthGroup - group_model = models['group'] - if not hasattr(group_model, 'format'): - setattr(group_model, 'format', '%(role)s (%(id)s)') + group_model = models["group"] + if not hasattr(group_model, "format"): + group_model.format = "%(role)s (%(id)s)" many_refs = [ - {names['membership'] + 's': models['membership'].__name__}, - {names['permission'] + 's': models['permission'].__name__}, - {names['user'] + 's': {'via': names['membership'] + 's'}} + {names["membership"] + "s": models["membership"].__name__}, + {names["permission"] + "s": models["permission"].__name__}, + {names["user"] + "s": {"via": names["membership"] + "s"}}, ] - if getattr(group_model, '_auto_relations', True): + if getattr(group_model, "_auto_relations", True): for el in many_refs: key = list(el)[0] group_model._all_hasmany_ref_[key] = el #: AuthMembership - membership_model = models['membership'] - belongs_refs = [ - {names['user']: models['user'].__name__}, - {names['group']: models['group'].__name__} - ] - if getattr(membership_model, '_auto_relations', True): + membership_model = models["membership"] + belongs_refs = [{names["user"]: models["user"].__name__}, {names["group"]: models["group"].__name__}] + if getattr(membership_model, "_auto_relations", True): for el in belongs_refs: key = list(el)[0] membership_model._all_belongs_ref_[key] = el #: AuthPermission - permission_model = models['permission'] - belongs_refs = [ - {names['group']: models['group'].__name__} - ] - if getattr(permission_model, '_auto_relations', True): + permission_model = models["permission"] + belongs_refs = [{names["group"]: models["group"].__name__}] + if getattr(permission_model, "_auto_relations", True): for el in belongs_refs: key = list(el)[0] permission_model._all_belongs_ref_[key] = el #: AuthEvent - event_model = models['event'] - belongs_refs = [ - {names['user']: models['user'].__name__} - ] - if getattr(event_model, '_auto_relations', True): + event_model = models["event"] + belongs_refs = [{names["user"]: models["user"].__name__}] + if getattr(event_model, "_auto_relations", True): for el in belongs_refs: key = list(el)[0] event_model._all_belongs_ref_[key] = el self.db.define_models( - models['user'], models['group'], models['membership'], - models['permission'], models['event'] + models["user"], models["group"], models["membership"], models["permission"], models["event"] ) self.model_names = sdict() for key, value in models.items(): @@ -304,25 +281,23 @@ def define_models(self): def init_forms(self): self.forms = sdict() for key, (method, fields_method) in AuthForms.map().items(): - self.forms[key] = _wrap_form( - method, fields_method(self.auth), self.auth) + self.forms[key] = _wrap_form(method, fields_method(self.auth), self.auth) def login_user(self, user, remember=False): try: del user.password except Exception: pass - expiration = remember and self.config.session_long_expiration or \ - self.config.session_expiration + expiration = remember and self.config.session_long_expiration or self.config.session_expiration session.auth = sdict( user=user, last_visit=now().as_naive_datetime(), last_dbcheck=now().as_naive_datetime(), expiration=expiration, - remember=remember + remember=remember, ) - def log_event(self, description, data={}, origin='auth', user=None): + def log_event(self, description, data={}, origin="auth", user=None): if not self.config.log_events or not description: return try: @@ -332,42 +307,43 @@ def log_event(self, description, data={}, origin='auth', user=None): # log messages should not be translated if isinstance(description, Tstr): description = description.text - self.config.models['event'].table.insert( - description=str(description % data), - origin=origin, user=user_id) + self.config.models["event"].table.insert(description=str(description % data), origin=origin, user=user_id) def generate_reset_key(self, user): - key = str(int(time.time())) + '-' + uuid() + key = str(int(time.time())) + "-" + uuid() user.update_record(reset_password_key=key) return key def get_user_by_reset_key(self, key): try: - generated_at = int(key.split('-')[0]) + generated_at = int(key.split("-")[0]) if time.time() - generated_at > 60 * 60 * 24: raise ValueError - user = self.config.models['user'].get(reset_password_key=key) + user = self.config.models["user"].get(reset_password_key=key) except ValueError: user = None return user def _registration_email(self, user, data): - data['email'] = user.email + data["email"] = user.email return self.app.ext.MailExtension.send_mail( recipients=user.email, - subject=str(self.config.messages['registration_email_subject']), - body=str(self.config.messages['registration_email_text'] % data)) + subject=str(self.config.messages["registration_email_subject"]), + body=str(self.config.messages["registration_email_text"] % data), + ) def _reset_password_email(self, user, data): - data['email'] = user.email + data["email"] = user.email return self.app.ext.MailExtension.send_mail( recipients=user.email, - subject=str(self.config.messages['reset_password_email_subject']), - body=str(self.config.messages['reset_password_email_text'] % data)) + subject=str(self.config.messages["reset_password_email_subject"]), + body=str(self.config.messages["reset_password_email_text"] % data), + ) def _wrap_form(f, fields, auth): @wraps(f) def wrapped(*args, **kwargs): return f(auth, fields, *args, **kwargs) + return wrapped diff --git a/emmett/tools/auth/forms.py b/emmett/tools/auth/forms.py index 5a161b5d..d77ce600 100644 --- a/emmett/tools/auth/forms.py +++ b/emmett/tools/auth/forms.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.forms - ----------------------- +emmett.tools.auth.forms +----------------------- - Provides the forms for the authorization system. +Provides the forms for the authorization system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from ...forms import Form, ModelForm @@ -23,6 +23,7 @@ def wrap(f): cls._registry_[target] = f cls._fields_registry_[target] = fields return f + return wrap @classmethod @@ -38,37 +39,25 @@ def map(cls): def login_fields(auth): - model = auth.models['user'] + model = auth.models["user"] rv = { - 'email': Field( - validation={'is': 'email', 'presence': True}, - label=model.email.label), - 'password': Field( - 'password', validation=model.password._requires, - label=model.password.label) + "email": Field(validation={"is": "email", "presence": True}, label=model.email.label), + "password": Field("password", validation=model.password._requires, label=model.password.label), } if auth.ext.config.remember_option: - rv['remember'] = Field( - 'bool', default=True, - label=auth.ext.config.messages['remember_button']) + rv["remember"] = Field("bool", default=True, label=auth.ext.config.messages["remember_button"]) return rv def registration_fields(auth): - rw_data = auth.models['user']._instance_()._merged_form_rw_ - user_table = auth.models['user'].table - all_fields = [ - (field_name, user_table[field_name].clone()) - for field_name in rw_data['registration']['writable'] - ] - for i, (field_name, field) in enumerate(all_fields): - if field_name == 'password': + rw_data = auth.models["user"]._instance_()._merged_form_rw_ + user_table = auth.models["user"].table + all_fields = [(field_name, user_table[field_name].clone()) for field_name in rw_data["registration"]["writable"]] + for i, (field_name, _) in enumerate(all_fields): + if field_name == "password": all_fields.insert( - i + 1, ( - 'password2', - Field( - 'password', - label=auth.ext.config.messages['verify_password']))) + i + 1, ("password2", Field("password", label=auth.ext.config.messages["verify_password"])) + ) break rv = {} for i, (field_name, field) in enumerate(all_fields): @@ -79,112 +68,81 @@ def registration_fields(auth): def profile_fields(auth): - rw_data = auth.models['user']._instance_()._merged_form_rw_ - return rw_data['profile'] + rw_data = auth.models["user"]._instance_()._merged_form_rw_ + return rw_data["profile"] def password_retrieval_fields(auth): rv = { - 'email': Field( - validation={'is': 'email', 'presence': True, 'lower': True}), + "email": Field(validation={"is": "email", "presence": True, "lower": True}), } return rv def password_reset_fields(auth): - password_field = auth.ext.config.models['user'].password + password_field = auth.ext.config.models["user"].password rv = { - 'password': Field( - 'password', validation=password_field._requires, - label=auth.ext.config.messages['new_password']), - 'password2': Field( - 'password', label=auth.ext.config.messages['verify_password']) + "password": Field( + "password", validation=password_field._requires, label=auth.ext.config.messages["new_password"] + ), + "password2": Field("password", label=auth.ext.config.messages["verify_password"]), } return rv def password_change_fields(auth): - password_validation = auth.ext.config.models['user'].password._requires + password_validation = auth.ext.config.models["user"].password._requires rv = { - 'old_password': Field( - 'password', validation=password_validation, - label=auth.ext.config.messages['old_password']), - 'new_password': Field( - 'password', validation=password_validation, - label=auth.ext.config.messages['new_password']), - 'new_password2': Field( - 'password', label=auth.ext.config.messages['verify_password']) + "old_password": Field( + "password", validation=password_validation, label=auth.ext.config.messages["old_password"] + ), + "new_password": Field( + "password", validation=password_validation, label=auth.ext.config.messages["new_password"] + ), + "new_password2": Field("password", label=auth.ext.config.messages["verify_password"]), } return rv -@AuthForms.register_for('login', fields=login_fields) +@AuthForms.register_for("login", fields=login_fields) def login_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['login_button'], 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["login_button"], "keepvalues": True} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) -@AuthForms.register_for('registration', fields=registration_fields) +@AuthForms.register_for("registration", fields=registration_fields) def registration_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['registration_button'], - 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["registration_button"], "keepvalues": True} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) -@AuthForms.register_for('profile', fields=profile_fields) +@AuthForms.register_for("profile", fields=profile_fields) def profile_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['profile_button'], - 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["profile_button"], "keepvalues": True} opts.update(**kwargs) return ModelForm( - auth.models['user'], - record_id=auth.user.id, - fields=fields, - upload=auth.ext.exposer.url('download'), - **opts + auth.models["user"], record_id=auth.user.id, fields=fields, upload=auth.ext.exposer.url("download"), **opts ) -@AuthForms.register_for('password_retrieval', fields=password_retrieval_fields) +@AuthForms.register_for("password_retrieval", fields=password_retrieval_fields) def password_retrieval_form(auth, fields, **kwargs): - opts = {'submit': auth.ext.config.messages['password_retrieval_button']} + opts = {"submit": auth.ext.config.messages["password_retrieval_button"]} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) -@AuthForms.register_for('password_reset', fields=password_reset_fields) +@AuthForms.register_for("password_reset", fields=password_reset_fields) def password_reset_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['password_reset_button'], - 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["password_reset_button"], "keepvalues": True} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) -@AuthForms.register_for('password_change', fields=password_change_fields) +@AuthForms.register_for("password_change", fields=password_change_fields) def password_change_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['password_change_button'], - 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["password_change_button"], "keepvalues": True} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) diff --git a/emmett/tools/auth/models.py b/emmett/tools/auth/models.py index 7e790ae0..868f9a83 100644 --- a/emmett/tools/auth/models.py +++ b/emmett/tools/auth/models.py @@ -1,18 +1,18 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.models - ------------------------ +emmett.tools.auth.models +------------------------ - Provides models for the authorization system. +Provides models for the authorization system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from ..._shortcuts import uuid from ...ctx import current from ...locals import now, request -from ...orm import Model, Field, before_insert, rowmethod +from ...orm import Field, Model, before_insert, rowmethod class TimestampedModel(Model): @@ -21,8 +21,7 @@ class TimestampedModel(Model): class AuthModel(Model): - _additional_inheritable_dict_attrs_ = [ - 'form_registration_rw', 'form_profile_rw'] + _additional_inheritable_dict_attrs_ = ["form_registration_rw", "form_profile_rw"] auth = None @classmethod @@ -33,8 +32,7 @@ def _init_inheritable_dicts_(cls): else: attr_name, default = attr, {} if not isinstance(default, dict): - raise SyntaxError( - "{} is not a dictionary".format(attr_name)) + raise SyntaxError("{} is not a dictionary".format(attr_name)) setattr(cls, attr_name, default) @classmethod @@ -57,33 +55,32 @@ def _merge_inheritable_dicts_(cls, models): setattr(cls, attr_name, attrs) def __super_method(self, name): - return getattr(super(AuthModel, self), '_Model__' + name) + return getattr(super(AuthModel, self), "_Model__" + name) def _define_(self): self.__hide_all() - self.__super_method('define_indexes')() - self.__super_method('define_validation')() - self.__super_method('define_access')() - self.__super_method('define_defaults')() - self.__super_method('define_updates')() - self.__super_method('define_representation')() - self.__super_method('define_computations')() - self.__super_method('define_callbacks')() - self.__super_method('define_scopes')() - self.__super_method('define_query_helpers')() - self.__super_method('define_form_utils')() + self.__super_method("define_indexes")() + self.__super_method("define_validation")() + self.__super_method("define_access")() + self.__super_method("define_defaults")() + self.__super_method("define_updates")() + self.__super_method("define_representation")() + self.__super_method("define_computations")() + self.__super_method("define_callbacks")() + self.__super_method("define_scopes")() + self.__super_method("define_query_helpers")() + self.__super_method("define_form_utils")() self.__define_authform_utils() self.setup() - #def __define_extra_fields(self): + # def __define_extra_fields(self): # self.auth.settings.extra_fields['auth_user'] = self.fields def __hide_all(self): - alwaysvisible = ['first_name', 'last_name', 'password', 'email'] + alwaysvisible = ["first_name", "last_name", "password", "email"] for field in self.table.fields: if field not in alwaysvisible: - self.table[field].writable = self.table[field].readable = \ - False + self.table[field].writable = self.table[field].readable = False def __base_visibility(self): rv = {} @@ -93,8 +90,9 @@ def __base_visibility(self): def __define_authform_utils(self): settings = { - 'form_registration_rw': {'writable': [], 'readable': []}, - 'form_profile_rw': {'writable': [], 'readable': []}} + "form_registration_rw": {"writable": [], "readable": []}, + "form_profile_rw": {"writable": [], "readable": []}, + } for config_dict in settings.keys(): rw_data = self.__base_visibility() rw_data.update(**self.fields_rw) @@ -105,17 +103,18 @@ def __define_authform_utils(self): else: readable = writable = value if readable: - settings[config_dict]['readable'].append(key) + settings[config_dict]["readable"].append(key) if writable: - settings[config_dict]['writable'].append(key) - setattr(self, '_merged_form_rw_', { - 'registration': settings['form_registration_rw'], - 'profile': settings['form_profile_rw']}) + settings[config_dict]["writable"].append(key) + self._merged_form_rw_ = { + "registration": settings["form_registration_rw"], + "profile": settings["form_profile_rw"], + } class AuthUserBasic(AuthModel, TimestampedModel): tablename = "auth_users" - format = '%(email)s (%(id)s)' + format = "%(email)s (%(id)s)" #: injected by Auth # has_many( # {'auth_memberships': 'AuthMembership'}, @@ -126,55 +125,48 @@ class AuthUserBasic(AuthModel, TimestampedModel): email = Field(length=255, unique=True) password = Field.password(length=512) - registration_key = Field(length=512, rw=False, default='') - reset_password_key = Field(length=512, rw=False, default='') - registration_id = Field(length=512, rw=False, default='') + registration_key = Field(length=512, rw=False, default="") + reset_password_key = Field(length=512, rw=False, default="") + registration_id = Field(length=512, rw=False, default="") - form_labels = { - 'email': 'E-mail', - 'password': 'Password' - } + form_labels = {"email": "E-mail", "password": "Password"} - form_profile_rw = { - 'email': (True, False), - 'password': False - } + form_profile_rw = {"email": (True, False), "password": False} @before_insert def set_registration_key(self, fields): - if self.auth.config.registration_verification and not \ - fields.get('registration_key'): - fields['registration_key'] = uuid() + if self.auth.config.registration_verification and not fields.get("registration_key"): + fields["registration_key"] = uuid() elif self.auth.config.registration_approval: - fields['registration_key'] = 'pending' + fields["registration_key"] = "pending" - @rowmethod('disable') + @rowmethod("disable") def _set_disabled(self, row): - return row.update_record(registration_key='disabled') + return row.update_record(registration_key="disabled") - @rowmethod('block') + @rowmethod("block") def _set_blocked(self, row): - return row.update_record(registration_key='blocked') + return row.update_record(registration_key="blocked") - @rowmethod('allow') + @rowmethod("allow") def _set_allowed(self, row): - return row.update_record(registration_key='') + return row.update_record(registration_key="") class AuthUser(AuthUserBasic): - format = '%(first_name)s %(last_name)s (%(id)s)' + format = "%(first_name)s %(last_name)s (%(id)s)" first_name = Field(length=128, notnull=True) last_name = Field(length=128, notnull=True) form_labels = { - 'first_name': 'First name', - 'last_name': 'Last name', + "first_name": "First name", + "last_name": "Last name", } class AuthGroup(TimestampedModel): - format = '%(role)s (%(id)s)' + format = "%(role)s (%(id)s)" #: injected by Auth # has_many( # {'auth_memberships': 'AuthMembership'}, @@ -182,13 +174,10 @@ class AuthGroup(TimestampedModel): # {'users': {'via': 'memberships'}} # ) - role = Field(length=255, default='', unique=True) + role = Field(length=255, default="", unique=True) description = Field.text() - form_labels = { - 'role': 'Role', - 'description': 'Description' - } + form_labels = {"role": "Role", "description": "Description"} class AuthMembership(TimestampedModel): @@ -201,19 +190,13 @@ class AuthPermission(TimestampedModel): #: injected by Auth # belongs_to({'auth_group': 'AuthGroup'}) - name = Field(length=512, default='default', notnull=True) + name = Field(length=512, default="default", notnull=True) table_name = Field(length=512) record_id = Field.int(default=0) - validation = { - 'record_id': {'in': {'range': (0, 10**9)}} - } + validation = {"record_id": {"in": {"range": (0, 10**9)}}} - form_labels = { - 'name': 'Name', - 'table_name': 'Object or table name', - 'record_id': 'Record ID' - } + form_labels = {"name": "Name", "table_name": "Object or table name", "record_id": "Record ID"} class AuthEvent(TimestampedModel): @@ -225,15 +208,10 @@ class AuthEvent(TimestampedModel): description = Field.text(notnull=True) default_values = { - 'client_ip': lambda: - request.client if hasattr(current, 'request') else 'unavailable', - 'origin': 'auth', - 'description': '' + "client_ip": lambda: request.client if hasattr(current, "request") else "unavailable", + "origin": "auth", + "description": "", } #: labels injected by Auth - form_labels = { - 'client_ip': 'Client IP', - 'origin': 'Origin', - 'description': 'Description' - } + form_labels = {"client_ip": "Client IP", "origin": "Origin", "description": "Description"} diff --git a/emmett/tools/decorators.py b/emmett/tools/decorators.py index 241679ca..589b5e13 100644 --- a/emmett/tools/decorators.py +++ b/emmett/tools/decorators.py @@ -1,26 +1,27 @@ # -*- coding: utf-8 -*- """ - emmett.tools.decorators - ----------------------- +emmett.tools.decorators +----------------------- - Provides requires and service decorators. +Provides requires and service decorators. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ # from ..routing.router import Router -from emmett_core.pipeline.dyn import ServicePipeBuilder as _ServicePipeBuilder, requires as _requires, service as _service +from emmett_core.pipeline.dyn import ( + ServicePipeBuilder as _ServicePipeBuilder, + requires as _requires, + service as _service, +) from ..pipeline import RequirePipe from .service import XMLServicePipe class ServicePipeBuilder(_ServicePipeBuilder): - _pipe_cls = { - **_ServicePipeBuilder._pipe_cls, - **{"xml": XMLServicePipe} - } + _pipe_cls = {**_ServicePipeBuilder._pipe_cls, **{"xml": XMLServicePipe}} class requires(_requires): @@ -32,4 +33,4 @@ class service(_service): @staticmethod def xml(f): - return service('xml')(f) + return service("xml")(f) diff --git a/emmett/tools/mailer.py b/emmett/tools/mailer.py index b4cf602e..3c478fbc 100644 --- a/emmett/tools/mailer.py +++ b/emmett/tools/mailer.py @@ -1,29 +1,28 @@ # -*- coding: utf-8 -*- """ - emmett.tools.mailer - ------------------- +emmett.tools.mailer +------------------- - Provides mail facilities for Emmett. +Provides mail facilities for Emmett. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of flask-mail - :copyright: (c) 2010 by Dan Jacob. +Based on the code of flask-mail +:copyright: (c) 2010 by Dan Jacob. - :license: BSD-3-Clause +:license: BSD-3-Clause """ import smtplib import time - from contextlib import contextmanager from email import charset as _charsetreg, policy from email.encoders import encode_base64 +from email.header import Header from email.mime.base import MIMEBase from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from email.header import Header -from email.utils import formatdate, formataddr, make_msgid, parseaddr +from email.utils import formataddr, formatdate, make_msgid, parseaddr from functools import wraps from emmett_core.utils import cachedprop @@ -31,25 +30,26 @@ from ..extensions import Extension from ..libs.contenttype import contenttype -_charsetreg.add_charset('utf-8', _charsetreg.SHORTEST, None, 'utf-8') + +_charsetreg.add_charset("utf-8", _charsetreg.SHORTEST, None, "utf-8") message_policy = policy.SMTP def _has_newline(line): - if line and ('\r' in line or '\n' in line): + if line and ("\r" in line or "\n" in line): return True return False -def sanitize_subject(subject, encoding='utf-8'): +def sanitize_subject(subject, encoding="utf-8"): try: - subject.encode('ascii') + subject.encode("ascii") except UnicodeEncodeError: subject = Header(subject, encoding).encode() return subject -def sanitize_address(address, encoding='utf-8'): +def sanitize_address(address, encoding="utf-8"): if isinstance(address, str): try: address = parseaddr(address, strict=False) @@ -58,20 +58,20 @@ def sanitize_address(address, encoding='utf-8'): name, address = address name = Header(name, encoding).encode() try: - address.encode('ascii') + address.encode("ascii") except UnicodeEncodeError: - if '@' in address: - localpart, domain = address.split('@', 1) + if "@" in address: + localpart, domain = address.split("@", 1) localpart = str(Header(localpart, encoding)) - domain = domain.encode('idna').decode('ascii') - address = '@'.join([localpart, domain]) + domain = domain.encode("idna").decode("ascii") + address = "@".join([localpart, domain]) else: address = Header(address, encoding).encode() return formataddr((name, address)) -def sanitize_addresses(addresses, encoding='utf-8'): - return map(lambda address: sanitize_address(address, encoding), addresses) +def sanitize_addresses(addresses, encoding="utf-8"): + return map(lambda address: sanitize_address(address, encoding), addresses) # noqa: C417 class MailServer(object): @@ -102,32 +102,43 @@ def send(self, message): self.host.sendmail( sanitize_address(message.sender), list(sanitize_addresses(message.all_recipients)), - str(message).encode('utf8'), + str(message).encode("utf8"), message.mail_options, - message.rcpt_options) + message.rcpt_options, + ) return True class Attachment(object): - def __init__( - self, filename=None, data=None, content_type=None, disposition=None, - headers=None - ): + def __init__(self, filename=None, data=None, content_type=None, disposition=None, headers=None): if not content_type and filename: content_type = contenttype(filename).split(";")[0] self.filename = filename self.content_type = content_type or contenttype(filename).split(";")[0] self.data = data - self.disposition = disposition or 'attachment' + self.disposition = disposition or "attachment" self.headers = headers or {} class Mail(object): def __init__( - self, ext, subject='', recipients=None, body=None, html=None, - alts=None, sender=None, cc=None, bcc=None, attachments=None, - reply_to=None, date=None, charset='utf-8', extra_headers=None, - mail_options=None, rcpt_options=None + self, + ext, + subject="", + recipients=None, + body=None, + html=None, + alts=None, + sender=None, + cc=None, + bcc=None, + attachments=None, + reply_to=None, + date=None, + charset="utf-8", + extra_headers=None, + mail_options=None, + rcpt_options=None, ): sender = sender or ext.config.sender if isinstance(sender, tuple): @@ -149,7 +160,7 @@ def __init__( self.mail_options = mail_options or [] self.rcpt_options = rcpt_options or [] self.attachments = attachments or [] - for attr in ['recipients', 'cc', 'bcc']: + for attr in ["recipients", "cc", "bcc"]: if not isinstance(getattr(self, attr), list): setattr(self, attr, [getattr(self, attr)]) @@ -159,14 +170,14 @@ def all_recipients(self): @property def html(self): - return self.alts.get('html') + return self.alts.get("html") @html.setter def html(self, value): if value is None: - self.alts.pop('html', None) + self.alts.pop("html", None) else: - self.alts['html'] = value + self.alts["html"] = value def has_bad_headers(self): headers = [self.sender, self.reply_to] + self.recipients @@ -175,10 +186,10 @@ def has_bad_headers(self): return True if self.subject: if _has_newline(self.subject): - for linenum, line in enumerate(self.subject.split('\r\n')): + for linenum, line in enumerate(self.subject.split("\r\n")): if not line: return True - if linenum > 0 and line[0] not in '\t ': + if linenum > 0 and line[0] not in "\t ": return True if _has_newline(line): return True @@ -186,7 +197,7 @@ def has_bad_headers(self): return True return False - def _mimetext(self, text, subtype='plain'): + def _mimetext(self, text, subtype="plain"): return MIMEText(text, _subtype=subtype, _charset=self.charset) @cachedprop @@ -202,38 +213,34 @@ def message(self): else: # Anything else msg = MIMEMultipart() - alternative = MIMEMultipart('alternative') - alternative.attach(self._mimetext(self.body, 'plain')) + alternative = MIMEMultipart("alternative") + alternative.attach(self._mimetext(self.body, "plain")) for mimetype, content in self.alts.items(): alternative.attach(self._mimetext(content, mimetype)) msg.attach(alternative) if self.subject: - msg['Subject'] = sanitize_subject(self.subject, self.charset) - msg['From'] = sanitize_address(self.sender, self.charset) - msg['To'] = ', '.join( - list(set(sanitize_addresses(self.recipients, self.charset)))) - msg['Date'] = formatdate(self.date, localtime=True) - msg['Message-ID'] = self.msgId + msg["Subject"] = sanitize_subject(self.subject, self.charset) + msg["From"] = sanitize_address(self.sender, self.charset) + msg["To"] = ", ".join(list(set(sanitize_addresses(self.recipients, self.charset)))) + msg["Date"] = formatdate(self.date, localtime=True) + msg["Message-ID"] = self.msgId if self.cc: - msg['Cc'] = ', '.join( - list(set(sanitize_addresses(self.cc, self.charset)))) + msg["Cc"] = ", ".join(list(set(sanitize_addresses(self.cc, self.charset)))) if self.reply_to: - msg['Reply-To'] = sanitize_address(self.reply_to, self.charset) + msg["Reply-To"] = sanitize_address(self.reply_to, self.charset) if self.extra_headers: for k, v in self.extra_headers.items(): msg[k] = v for attachment in attachments: - f = MIMEBase(*attachment.content_type.split('/')) + f = MIMEBase(*attachment.content_type.split("/")) f.set_payload(attachment.data) encode_base64(f) filename = attachment.filename try: - filename and filename.encode('ascii') + filename and filename.encode("ascii") except UnicodeEncodeError: - filename = ('UTF8', '', filename) - f.add_header( - 'Content-Disposition', attachment.disposition, - filename=filename) + filename = ("UTF8", "", filename) + f.add_header("Content-Disposition", attachment.disposition, filename=filename) for key, value in attachment.headers.items(): f.add_header(key, value) msg.attach(f) @@ -250,27 +257,23 @@ def send(self): def add_recipient(self, recipient): self.recipients.append(recipient) - def attach( - self, filename=None, data=None, content_type=None, disposition=None, - headers=None - ): - self.attachments.append( - Attachment(filename, data, content_type, disposition, headers)) + def attach(self, filename=None, data=None, content_type=None, disposition=None, headers=None): + self.attachments.append(Attachment(filename, data, content_type, disposition, headers)) class MailExtension(Extension): - namespace = 'mailer' + namespace = "mailer" default_config = { - 'server': '127.0.0.1', - 'username': None, - 'password': None, - 'port': 25, - 'use_tls': False, - 'use_ssl': False, - 'sender': None, - 'debug': False, - 'suppress': False + "server": "127.0.0.1", + "username": None, + "password": None, + "port": 25, + "use_tls": False, + "use_ssl": False, + "sender": None, + "debug": False, + "suppress": False, } def on_load(self): @@ -349,4 +352,5 @@ def _wrap_dispatcher(dispatcher): @wraps(dispatcher) def wrapped(ext, *args, **kwargs): return dispatcher(*args, **kwargs) + return wrapped diff --git a/emmett/tools/service.py b/emmett/tools/service.py index b3ce91e9..16090efe 100644 --- a/emmett/tools/service.py +++ b/emmett/tools/service.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.tools.service - -------------------- +emmett.tools.service +-------------------- - Provides the services handler. +Provides the services handler. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from emmett_core.pipeline.extras import JSONPipe @@ -22,14 +22,14 @@ class JSONServicePipe(JSONPipe): class XMLServicePipe(Pipe): - __slots__ = ['encoder'] - output = 'str' + __slots__ = ["encoder"] + output = "str" def __init__(self): - self.encoder = Serializers.get_for('xml') + self.encoder = Serializers.get_for("xml") async def pipe_request(self, next_pipe, **kwargs): - current.response.headers._data['content-type'] = 'text/xml' + current.response.headers._data["content-type"] = "text/xml" return self.encoder(await next_pipe(**kwargs)) def on_send(self, data): @@ -37,13 +37,7 @@ def on_send(self, data): def ServicePipe(procedure: str) -> Pipe: - pipe_cls = { - 'json': JSONServicePipe, - 'xml': XMLServicePipe - }.get(procedure) + pipe_cls = {"json": JSONServicePipe, "xml": XMLServicePipe}.get(procedure) if not pipe_cls: - raise RuntimeError( - 'Emmett cannot handle the service you requested: %s' % - procedure - ) + raise RuntimeError("Emmett cannot handle the service you requested: %s" % procedure) return pipe_cls() diff --git a/emmett/utils.py b/emmett/utils.py index 4623cca0..4b9a716e 100644 --- a/emmett/utils.py +++ b/emmett/utils.py @@ -1,74 +1,65 @@ # -*- coding: utf-8 -*- """ - emmett.utils - ------------ +emmett.utils +------------ - Provides some utilities for Emmett. +Provides some utilities for Emmett. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations import re import socket - -from datetime import datetime, date, time +from datetime import date, datetime, time import pendulum - from pendulum.parsing import _parse as _pendulum_parse from .datastructures import sdict -_pendulum_parsing_opts = { - 'day_first': False, - 'year_first': True, - 'strict': True, - 'exact': False, - 'now': None -} +_pendulum_parsing_opts = {"day_first": False, "year_first": True, "strict": True, "exact": False, "now": None} def _pendulum_normalize(obj): if isinstance(obj, time): now = datetime.utcnow() - obj = datetime( - now.year, now.month, now.day, - obj.hour, obj.minute, obj.second, obj.microsecond - ) + obj = datetime(now.year, now.month, now.day, obj.hour, obj.minute, obj.second, obj.microsecond) elif isinstance(obj, date) and not isinstance(obj, datetime): obj = datetime(obj.year, obj.month, obj.day) return obj def parse_datetime(text): - parsed = _pendulum_normalize( - _pendulum_parse(text, **_pendulum_parsing_opts)) + parsed = _pendulum_normalize(_pendulum_parse(text, **_pendulum_parsing_opts)) return pendulum.datetime( - parsed.year, parsed.month, parsed.day, - parsed.hour, parsed.minute, parsed.second, parsed.microsecond, - tz=parsed.tzinfo or pendulum.UTC + parsed.year, + parsed.month, + parsed.day, + parsed.hour, + parsed.minute, + parsed.second, + parsed.microsecond, + tz=parsed.tzinfo or pendulum.UTC, ) -_re_ipv4 = re.compile(r'(\d+)\.(\d+)\.(\d+)\.(\d+)') +_re_ipv4 = re.compile(r"(\d+)\.(\d+)\.(\d+)\.(\d+)") def is_valid_ip_address(address): # deal with special cases - if address.lower() in [ - '127.0.0.1', 'localhost', '::1', '::ffff:127.0.0.1' - ]: + if address.lower() in ["127.0.0.1", "localhost", "::1", "::ffff:127.0.0.1"]: return True - elif address.lower() in ('unknown', ''): + elif address.lower() in ("unknown", ""): return False - elif address.count('.') == 3: # assume IPv4 - if address.startswith('::ffff:'): + elif address.count(".") == 3: # assume IPv4 + if address.startswith("::ffff:"): address = address[7:] - if hasattr(socket, 'inet_aton'): # try validate using the OS + if hasattr(socket, "inet_aton"): # try validate using the OS try: socket.inet_aton(address) return True @@ -76,12 +67,10 @@ def is_valid_ip_address(address): return False else: # try validate using Regex match = _re_ipv4.match(address) - if match and all( - 0 <= int(match.group(i)) < 256 for i in (1, 2, 3, 4) - ): + if match and all(0 <= int(match.group(i)) < 256 for i in (1, 2, 3, 4)): return True return False - elif hasattr(socket, 'inet_pton'): # assume IPv6, try using the OS + elif hasattr(socket, "inet_pton"): # assume IPv6, try using the OS try: socket.inet_pton(socket.AF_INET6, address) return True @@ -91,7 +80,7 @@ def is_valid_ip_address(address): return True -def read_file(filename, mode='r'): +def read_file(filename, mode="r"): # returns content from filename, making sure to close the file on exit. f = open(filename, mode) try: @@ -100,7 +89,7 @@ def read_file(filename, mode='r'): f.close() -def write_file(filename, value, mode='w'): +def write_file(filename, value, mode="w"): # writes to filename, making sure to close the file on exit. f = open(filename, mode) try: diff --git a/emmett/validators/__init__.py b/emmett/validators/__init__.py index 2e6eca14..95574ae4 100644 --- a/emmett/validators/__init__.py +++ b/emmett/validators/__init__.py @@ -1,26 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.validators - ----------------- +emmett.validators +----------------- - Implements validators for pyDAL. +Implements validators for pyDAL. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from .basic import ( - Allow, - Any, - Equals, - hasLength, - isEmpty, - isEmptyOr, - isntEmpty, - Matches, - Not, - Validator -) +from .basic import Allow, Any, Equals, Matches, Not, Validator, hasLength, isEmpty, isEmptyOr, isntEmpty from .consist import ( isAlphanumeric, isDate, @@ -34,10 +23,10 @@ isJSON, isList, isTime, - isUrl + isUrl, ) -from .inside import inRange, inSet, inDB, notInDB -from .process import Cleanup, Crypt, Lower, Urlify, Upper +from .inside import inDB, inRange, inSet, notInDB +from .process import Cleanup, Crypt, Lower, Upper, Urlify class ValidateFromDict(object): @@ -56,15 +45,9 @@ def __init__(self): "ip": isIP, "json": isJSON, "time": isTime, - "url": isUrl - } - self.proc_validators = { - "clean": Cleanup, - "crypt": Crypt, - "lower": Lower, - "upper": Upper, - "urlify": Urlify + "url": isUrl, } + self.proc_validators = {"clean": Cleanup, "crypt": Crypt, "lower": Lower, "upper": Upper, "urlify": Urlify} def parse_num_comparisons(self, data, minv=None, maxv=None): inclusions = [True, False] @@ -96,10 +79,10 @@ def parse_is(self, data, message=None): def parse_is_list(self, data, message=None): #: map types with 'list' fields - key = '' + key = "" options = {} suboptions = {} - lopts = ['splitter'] + lopts = ["splitter"] if isinstance(data, str): #: map {'is': 'list:int'} key = data @@ -114,26 +97,22 @@ def parse_is_list(self, data, message=None): else: raise SyntaxError("'is' validator accepts only string or dict") try: - keyspecs = key.split(':') + keyspecs = key.split(":") subkey = keyspecs[1].strip() - assert keyspecs[0].strip() == 'list' + assert keyspecs[0].strip() == "list" except Exception: - subkey = '_missing_' + subkey = "_missing_" validator = self.is_validators.get(subkey) - return isList( - [validator(message=message, **suboptions)], - message=message, - **options - ) if validator else None + return isList([validator(message=message, **suboptions)], message=message, **options) if validator else None def parse_reference(self, field): ref_table, ref_field, multiple = None, None, None - if field.type.startswith('reference'): + if field.type.startswith("reference"): multiple = False - elif field.type.startswith('list:reference'): + elif field.type.startswith("list:reference"): multiple = True if multiple is not None: - ref_table = field.type.split(' ')[1] + ref_table = field.type.split(" ")[1] model = field.table._model_ #: can't support (yet?) multi pks if model._belongs_ref_[field.name].compound: @@ -146,40 +125,36 @@ def __call__(self, field, data): validators = [] message = data.pop("message", None) #: parse 'presence' and 'empty' - presence = data.get('presence') - empty = data.get('empty') + presence = data.get("presence") + empty = data.get("empty") if presence is None and empty is not None: presence = not empty #: parse 'is' - _is = data.get('is') + _is = data.get("is") if _is is not None: validator = self.parse_is(_is, message) or self.parse_is_list(_is, message) if validator is None: - raise SyntaxError( - "Unknown type %s for 'is' validator" % data - ) + raise SyntaxError("Unknown type %s for 'is' validator" % data) validators.append(validator) #: parse 'len' - _len = data.get('len') + _len = data.get("len") if _len is not None: if isinstance(_len, int): #: allows {'len': 2} - validators.append( - hasLength(_len + 1, _len, message='Enter {min} characters') - ) + validators.append(hasLength(_len + 1, _len, message="Enter {min} characters")) else: #: allows # {'len': {'gt': 1, 'gte': 2, 'lt': 5, 'lte' 6}} # {'len': {'range': (2, 6)}} - if _len.get('range') is not None: - minv, maxv = _len['range'] + if _len.get("range") is not None: + minv, maxv = _len["range"] inc = (True, False) else: minv, maxv, inc = self.parse_num_comparisons(_len, 0, 256) validators.append(hasLength(maxv, minv, inc, message=message)) #: parse 'in' _dbset = None - _in = data.get('in', []) + _in = data.get("in", []) if _in: if isinstance(_in, (list, tuple, set)): #: allows {'in': [1, 2]} @@ -187,22 +162,22 @@ def __call__(self, field, data): elif isinstance(_in, dict): options = {} #: allows {'in': {'range': (1, 5)}} - _range = _in.get('range') + _range = _in.get("range") if isinstance(_range, (tuple, list)): validators.append(inRange(_range[0], _range[1], message=message)) #: allows {'in': {'set': [1, 5]}} with options - _set = _in.get('set') + _set = _in.get("set") if isinstance(_set, (list, tuple, set)): - opt_keys = [key for key in list(_in) if key != 'set'] + opt_keys = [key for key in list(_in) if key != "set"] for key in opt_keys: options[key] = _in[key] validators.append(inSet(_set, message=message, **options)) #: allows {'in': {'dbset': lambda db: db.where(query)}} - _dbset = _in.get('dbset') + _dbset = _in.get("dbset") if callable(_dbset): ref_table, ref_field, multiple = self.parse_reference(field) if ref_table: - opt_keys = [key for key in list(_in) if key != 'dbset'] + opt_keys = [key for key in list(_in) if key != "dbset"] for key in opt_keys: options[key] = _in[key] validators.append( @@ -213,108 +188,86 @@ def __call__(self, field, data): dbset=_dbset, multiple=multiple, message=message, - **options + **options, ) ) else: - raise SyntaxError( - "'in:dbset' validator needs a reference field" - ) + raise SyntaxError("'in:dbset' validator needs a reference field") else: - raise SyntaxError( - "'in' validator accepts only a set or a dict" - ) + raise SyntaxError("'in' validator accepts only a set or a dict") #: parse 'gt', 'gte', 'lt', 'lte' minv, maxv, inc = self.parse_num_comparisons(data) if minv is not None or maxv is not None: validators.append(inRange(minv, maxv, inc, message=message)) #: parse 'equals' - if 'equals' in data: - validators.append(Equals(data['equals'], message=message)) + if "equals" in data: + validators.append(Equals(data["equals"], message=message)) #: parse 'match' - if 'match' in data: - if isinstance(data['match'], dict): - validators.append(Matches(**data['match'], message=message)) + if "match" in data: + if isinstance(data["match"], dict): + validators.append(Matches(**data["match"], message=message)) else: - validators.append(Matches(data['match'], message=message)) + validators.append(Matches(data["match"], message=message)) #: parse transforming validators for key, vclass in self.proc_validators.items(): if key in data: options = {} if isinstance(data[key], dict): options = data[key] - elif data[key] != True: - if key == 'crypt' and isinstance(data[key], str): - options = {'algorithm': data[key]} + elif data[key] is not True: + if key == "crypt" and isinstance(data[key], str): + options = {"algorithm": data[key]} else: - raise SyntaxError( - key + " validator accepts only dict or True" - ) + raise SyntaxError(key + " validator accepts only dict or True") validators.append(vclass(message=message, **options)) #: parse 'unique' - _unique = data.get('unique', False) + _unique = data.get("unique", False) if _unique: _udbset = None if isinstance(_unique, dict): - whr = _unique.get('where', None) + whr = _unique.get("where", None) if callable(whr): _dbset = whr - validators.append( - notInDB( - field.db, - field.table, - field.name, - dbset=_udbset, - message=message - ) - ) + validators.append(notInDB(field.db, field.table, field.name, dbset=_udbset, message=message)) table = field.db[field._tablename] table._unique_fields_validation_[field.name] = 1 #: apply 'format' option - if 'format' in data: + if "format" in data: for validator in validators: children = [validator] - if hasattr(validator, 'children'): + if hasattr(validator, "children"): children += validator.children for child in children: - if hasattr(child, 'format'): - child.format = data['format'] + if hasattr(child, "format"): + child.format = data["format"] break #: parse 'custom' - if 'custom' in data: - if isinstance(data['custom'], list): - for element in data['custom']: + if "custom" in data: + if isinstance(data["custom"], list): + for element in data["custom"]: validators.append(element) else: - validators.append(data['custom']) + validators.append(data["custom"]) #: parse 'any' - if 'any' in data: - validators.append(Any(self(field, data['any']), message=message)) + if "any" in data: + validators.append(Any(self(field, data["any"]), message=message)) #: parse 'not' - if 'not' in data: - validators.append(Not(self(field, data['not']), message=message)) + if "not" in data: + validators.append(Not(self(field, data["not"]), message=message)) #: insert presence/empty validation if needed if presence: ref_table, ref_field, multiple = self.parse_reference(field) if ref_table: if not _dbset: - validators.append( - inDB( - field.db, - ref_table, - ref_field, - multiple=multiple, - message=message - ) - ) + validators.append(inDB(field.db, ref_table, ref_field, multiple=multiple, message=message)) else: validators.insert(0, isntEmpty(message=message)) if empty: validators.insert(0, isEmpty(message=message)) #: parse 'allow' - if 'allow' in data: - if data['allow'] in ['empty', 'blank']: + if "allow" in data: + if data["allow"] in ["empty", "blank"]: validators = [isEmptyOr(validators, message=message)] else: - validators = [Allow(data['allow'], validators, message=message)] + validators = [Allow(data["allow"], validators, message=message)] return validators diff --git a/emmett/validators/basic.py b/emmett/validators/basic.py index 2ce5d0c0..2fbb58b9 100644 --- a/emmett/validators/basic.py +++ b/emmett/validators/basic.py @@ -1,26 +1,25 @@ # -*- coding: utf-8 -*- """ - emmett.validators.basic - ----------------------- +emmett.validators.basic +----------------------- - Provide basic validators. +Provide basic validators. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the web2py's validators (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro +Based on the web2py's validators (http://www.web2py.com) +:copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ import re - from functools import reduce from os import SEEK_END, SEEK_SET # TODO: check unicode conversions from .._shortcuts import to_unicode -from .helpers import translate, is_empty +from .helpers import is_empty, translate class Validator: @@ -44,11 +43,7 @@ def __init__(self, children, message=None): self.children = children def formatter(self, value): - return reduce( - lambda formatted_val, child: child.formatter(formatted_val), - self.children, - value - ) + return reduce(lambda formatted_val, child: child.formatter(formatted_val), self.children, value) def __call__(self, value): raise NotImplementedError @@ -59,12 +54,7 @@ class _is(Validator): rule = None def __call__(self, value): - if ( - self.rule is None or ( - self.rule is not None and - self.rule.match(to_unicode(value) or '') - ) - ): + if self.rule is None or (self.rule is not None and self.rule.match(to_unicode(value) or "")): return self.check(value) return value, translate(self.message) @@ -138,19 +128,19 @@ def __init__(self, children, empty_regex=None, message=None): super().__init__(children, message=message) self.empty_regex = re.compile(empty_regex) if empty_regex is not None else None for child in self.children: - if hasattr(child, 'multiple'): + if hasattr(child, "multiple"): self.multiple = child.multiple break for child in self.children: - if hasattr(child, 'options'): + if hasattr(child, "options"): self._options_ = child.options self.options = self._get_options_ break def _get_options_(self): options = self._options_() - if (not options or options[0][0] != '') and not self.multiple: - options.insert(0, ('', '')) + if (not options or options[0][0] != "") and not self.multiple: + options.insert(0, ("", "")) return options def __call__(self, value): @@ -181,21 +171,19 @@ def __call__(self, value): class Matches(Validator): message = "Invalid expression" - def __init__( - self, expression, strict=False, search=False, extract=False, message=None - ): + def __init__(self, expression, strict=False, search=False, extract=False, message=None): super().__init__(message=message) if strict or not search: - if not expression.startswith('^'): - expression = '^(%s)' % expression + if not expression.startswith("^"): + expression = "^(%s)" % expression if strict: - if not expression.endswith('$'): - expression = '(%s)$' % expression + if not expression.endswith("$"): + expression = "(%s)$" % expression self.regex = re.compile(expression) self.extract = extract def __call__(self, value): - match = self.regex.search(to_unicode(value) or '') + match = self.regex.search(to_unicode(value) or "") if match is not None: return self.extract and match.group() or value, None return value, translate(self.message) @@ -204,9 +192,7 @@ def __call__(self, value): class hasLength(Validator): message = "Enter from {min} to {max} characters" - def __init__( - self, maxsize=256, minsize=0, include=(True, False), message=None - ): + def __init__(self, maxsize=256, minsize=0, include=(True, False), message=None): super().__init__(message=message) self.maxsize = maxsize self.minsize = minsize @@ -228,16 +214,16 @@ def __call__(self, value): length = 0 if self._between(length): return value, None - elif getattr(value, '_emt_field_hashed_contents_', False): + elif getattr(value, "_emt_field_hashed_contents_", False): return value, None - elif hasattr(value, 'file'): + elif hasattr(value, "file"): if value.file: value.file.seek(0, SEEK_END) length = value.file.tell() value.file.seek(0, SEEK_SET) if self._between(length): return value, None - elif hasattr(value, 'value'): + elif hasattr(value, "value"): val = value.value if val: length = len(val) @@ -247,7 +233,7 @@ def __call__(self, value): return value, None elif isinstance(value, bytes): try: - lvalue = len(value.decode('utf8')) + lvalue = len(value.decode("utf8")) except Exception: lvalue = len(value) if self._between(lvalue): diff --git a/emmett/validators/consist.py b/emmett/validators/consist.py index 245a0d24..db13c8b2 100644 --- a/emmett/validators/consist.py +++ b/emmett/validators/consist.py @@ -1,23 +1,22 @@ # -*- coding: utf-8 -*- """ - emmett.validators.consist - ------------------------- +emmett.validators.consist +------------------------- - Validators that check the value is of a certain type. +Validators that check the value is of a certain type. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the web2py's validators (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro +Based on the web2py's validators (http://www.web2py.com) +:copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ import decimal import re import struct - -from datetime import date, time, datetime, timedelta +from datetime import date, datetime, time, timedelta from time import strptime from urllib.parse import unquote as url_unquote @@ -26,7 +25,7 @@ from ..parsers import Parsers from ..serializers import Serializers from ..utils import parse_datetime -from .basic import Validator, ParentValidator, _is, Matches +from .basic import Matches, ParentValidator, Validator, _is from .helpers import ( _DEFAULT, _UTC, @@ -34,9 +33,10 @@ official_url_schemes, translate, unofficial_url_schemes, - url_split_regex + url_split_regex, ) + try: import ipaddress except ImportError: @@ -59,7 +59,7 @@ def __init__(self, dot=".", message=None): def check(self, value): try: - v = float(str(value).replace(self.dot, '.')) + v = float(str(value).replace(self.dot, ".")) return v, None except (ValueError, TypeError): pass @@ -69,11 +69,11 @@ def formatter(self, value): if value is None: return None val = str(value) - if '.' not in val: - val += '.00' + if "." not in val: + val += ".00" else: - val += '0' * (2 - len(val.split('.')[1])) - return val.replace('.', self.dot) + val += "0" * (2 - len(val.split(".")[1])) + return val.replace(".", self.dot) class isDecimal(isFloat): @@ -82,17 +82,14 @@ def check(self, value): if isinstance(value, decimal.Decimal): v = value else: - v = decimal.Decimal(str(value).replace(self.dot, '.')) + v = decimal.Decimal(str(value).replace(self.dot, ".")) return v, None except (ValueError, TypeError, decimal.InvalidOperation): return value, translate(self.message) class isTime(_is): - rule = re.compile( - r"((?P[0-9]+))([^0-9 ]+(?P[0-9 ]+))?" - r"([^0-9ap ]+(?P[0-9]*))?((?P[ap]m))?" - ) + rule = re.compile(r"((?P[0-9]+))([^0-9 ]+(?P[0-9 ]+))?" r"([^0-9ap ]+(?P[0-9]*))?((?P[ap]m))?") def __call__(self, value): return super().__call__(value.lower() if value else value) @@ -102,20 +99,17 @@ def check(self, value): return value, None val = self.rule.match(value) try: - (h, m, s) = (int(val.group('h')), 0, 0) - if not val.group('m') is None: - m = int(val.group('m')) - if not val.group('s') is None: - s = int(val.group('s')) - if val.group('d') == 'pm' and 0 < h < 12: + (h, m, s) = (int(val.group("h")), 0, 0) + if val.group("m") is not None: + m = int(val.group("m")) + if val.group("s") is not None: + s = int(val.group("s")) + if val.group("d") == "pm" and 0 < h < 12: h = h + 12 - if val.group('d') == 'am' and h == 12: + if val.group("d") == "am" and h == 12: h = 0 - if not (h in range(24) and m in range(60) and s - in range(60)): - raise ValueError( - 'Hours or minutes or seconds are outside of allowed range' - ) + if not (h in range(24) and m in range(60) and s in range(60)): + raise ValueError("Hours or minutes or seconds are outside of allowed range") val = time(h, m, s) return val, None except AttributeError: @@ -126,7 +120,7 @@ def check(self, value): class isDate(_is): - def __init__(self, format='%Y-%m-%d', timezone=None, message=None): + def __init__(self, format="%Y-%m-%d", timezone=None, message=None): super().__init__(message=message) self.format = translate(format) self.timezone = timezone @@ -167,13 +161,21 @@ def formatter(self, value): @staticmethod def nice(format): codes = ( - ('%Y', '1963'), ('%y', '63'), ('%d', '28'), ('%m', '08'), - ('%b', 'Aug'), ('%B', 'August'), ('%H', '14'), ('%I', '02'), - ('%p', 'PM'), ('%M', '30'), ('%S', '59') + ("%Y", "1963"), + ("%y", "63"), + ("%d", "28"), + ("%m", "08"), + ("%b", "Aug"), + ("%B", "August"), + ("%H", "14"), + ("%I", "02"), + ("%p", "PM"), + ("%M", "30"), + ("%S", "59"), ) - for (a, b) in codes: + for a, b in codes: format = format.replace(a, b) - return dict(format=format) + return {"format": format} class isDatetime(isDate): @@ -183,7 +185,7 @@ def __init__(self, format=_DEFAULT, **kwargs): def _get_parser(self, format): if format is _DEFAULT: - return self._parse_pendulum, '%Y-%m-%dT%H:%M:%S' + return self._parse_pendulum, "%Y-%m-%dT%H:%M:%S" return self._parse_strptime, format def _parse_strptime(self, value): @@ -191,27 +193,20 @@ def _parse_strptime(self, value): return datetime(y, m, d, hh, mm, ss) def _parse_pendulum(self, value): - return parse_datetime(value).in_timezone('UTC') + return parse_datetime(value).in_timezone("UTC") def _check_instance(self, value): return isinstance(value, datetime) def _formatter_obj(self, value): - return datetime( - value.year, - value.month, - value.day, - value.hour, - value.minute, - value.second - ) + return datetime(value.year, value.month, value.day, value.hour, value.minute, value.second) class isEmail(_is): rule = re.compile( r"^(?!\.)([-a-z0-9!\#$%&'*+/=?^_`{|}~]|(?= 0 - extension = value.filename[extension + 1:].lower() - if extension == 'jpg': - extension = 'jpeg' + extension = value.filename[extension + 1 :].lower() + if extension == "jpg": + extension = "jpeg" assert extension in self.extensions - if extension == 'bmp': + if extension == "bmp": width, height = self.__bmp(value.file) - elif extension == 'gif': + elif extension == "gif": width, height = self.__gif(value.file) - elif extension == 'jpeg': + elif extension == "jpeg": width, height = self.__jpeg(value.file) - elif extension == 'png': + elif extension == "png": width, height = self.__png(value.file) else: width = -1 height = -1 - assert self.minsize[0] <= width <= self.maxsize[0] \ - and self.minsize[1] <= height <= self.maxsize[1] + assert self.minsize[0] <= width <= self.maxsize[0] and self.minsize[1] <= height <= self.maxsize[1] value.file.seek(0) return value, None - except: + except Exception: return value, translate(self.message) def __bmp(self, stream): - if stream.read(2) == 'BM': + if stream.read(2) == "BM": stream.read(16) return struct.unpack("= 0xC0 and code <= 0xC3: - return tuple(reversed( - struct.unpack("!xHH", stream.read(5)))) + return tuple(reversed(struct.unpack("!xHH", stream.read(5)))) else: stream.read(length - 2) return -1, -1 def __png(self, stream): - if stream.read(8) == '\211PNG\r\n\032\n': + if stream.read(8) == "\211PNG\r\n\032\n": stream.read(4) if stream.read(4) == "IHDR": return struct.unpack("!LL", stream.read(8)) @@ -381,6 +371,7 @@ class _isGenericUrl(Validator): Based on RFC 2396: http://www.faqs.org/rfcs/rfc2396.html @author: Jonathan Benn """ + message = "Invalid URL" all_url_schemes = [None] + official_url_schemes + unofficial_url_schemes @@ -390,14 +381,11 @@ def __init__(self, schemes=None, prepend_scheme=None, message=None): self.prepend_scheme = prepend_scheme if self.prepend_scheme not in self.allowed_schemes: raise SyntaxError( - "prepend_scheme='{}' is not in allowed_schemes={}".format( - self.prepend_scheme, self.allowed_schemes - ) + "prepend_scheme='{}' is not in allowed_schemes={}".format(self.prepend_scheme, self.allowed_schemes) ) GENERIC_URL = re.compile( - r"%[^0-9A-Fa-f]{2}|%[^0-9A-Fa-f][0-9A-Fa-f]|%[0-9A-Fa-f][^0-9A-Fa-f]|" - r"%$|%[0-9A-Fa-f]$|%[^0-9A-Fa-f]$" + r"%[^0-9A-Fa-f]{2}|%[^0-9A-Fa-f][0-9A-Fa-f]|%[0-9A-Fa-f][^0-9A-Fa-f]|" r"%$|%[0-9A-Fa-f]$|%[^0-9A-Fa-f]$" ) GENERIC_URL_VALID = re.compile(r"[A-Za-z0-9;/?:@&=+$,\-_\.!~*'\(\)%#]+$") @@ -422,11 +410,9 @@ def __call__(self, value): # ports, check to see if adding a valid scheme fixes # the problem (but only do this if it doesn't have # one already!) - if value.find('://') < 0 and \ - None in self.allowed_schemes: - schemeToUse = self.prepend_scheme or 'http' - prependTest = self.__call__( - schemeToUse + '://' + value) + if value.find("://") < 0 and None in self.allowed_schemes: + schemeToUse = self.prepend_scheme or "http" + prependTest = self.__call__(schemeToUse + "://" + value) # if the prepend test succeeded if prependTest[1] is None: # if prepending in the output is enabled @@ -457,42 +443,30 @@ class _isHTTPUrl(Validator): """ message = "Invalid URL" - http_schemes = [None, 'http', 'https'] - GENERIC_VALID_IP = re.compile( - r"([\w.!~*'|;:&=+$,-]+@)?\d+\.\d+\.\d+\.\d+(:\d*)*$" - ) + http_schemes = [None, "http", "https"] + GENERIC_VALID_IP = re.compile(r"([\w.!~*'|;:&=+$,-]+@)?\d+\.\d+\.\d+\.\d+(:\d*)*$") GENERIC_VALID_DOMAIN = re.compile( r"([\w.!~*'|;:&=+$,-]+@)?(([A-Za-z0-9]+[A-Za-z0-9\-]*[A-Za-z0-9]+\.)" r"*([A-Za-z0-9]+\.)*)*([A-Za-z]+[A-Za-z0-9\-]*[A-Za-z0-9]+)\.?(:\d*)*$" ) - def __init__(self, schemes=None, prepend_scheme='http', tlds=None, message=None): + def __init__(self, schemes=None, prepend_scheme="http", tlds=None, message=None): super().__init__(message=message) self.allowed_schemes = schemes or self.http_schemes self.allowed_tlds = tlds or official_top_level_domains self.prepend_scheme = prepend_scheme for i in self.allowed_schemes: if i not in self.http_schemes: - raise SyntaxError( - "allowed_scheme value '{}' is not in {}".format( - i, self.http_schemes - ) - ) + raise SyntaxError("allowed_scheme value '{}' is not in {}".format(i, self.http_schemes)) if self.prepend_scheme not in self.allowed_schemes: raise SyntaxError( - "prepend_scheme='{}' is not in allowed_schemes={}".format( - self.prepend_scheme, self.allowed_schemes - ) + "prepend_scheme='{}' is not in allowed_schemes={}".format(self.prepend_scheme, self.allowed_schemes) ) def __call__(self, value): try: # if the URL passes generic validation - x = _isGenericUrl( - schemes=self.allowed_schemes, - prepend_scheme=self.prepend_scheme, - message=self.message - ) + x = _isGenericUrl(schemes=self.allowed_schemes, prepend_scheme=self.prepend_scheme, message=self.message) if x(value)[1] is None: componentsMatch = url_split_regex.match(value) authority = componentsMatch.group(4) @@ -504,12 +478,10 @@ def __call__(self, value): return value, None else: # else if authority is a valid domain name - domainMatch = self.GENERIC_VALID_DOMAIN.match( - authority) + domainMatch = self.GENERIC_VALID_DOMAIN.match(authority) if domainMatch: # if the top-level domain really exists - if domainMatch.group(5).lower()\ - in self.allowed_tlds: + if domainMatch.group(5).lower() in self.allowed_tlds: # Then this HTTP URL is valid return value, None else: @@ -518,15 +490,15 @@ def __call__(self, value): path = componentsMatch.group(5) # relative case: if this is a valid path (if it starts with # a slash) - if path.startswith('/'): + if path.startswith("/"): # Then this HTTP URL is valid return value, None else: # abbreviated case: if we haven't already, prepend a # scheme and see if it fixes the problem - if value.find('://') < 0: - schemeToUse = self.prepend_scheme or 'http' - prependTest = self(schemeToUse + '://' + value) + if value.find("://") < 0: + schemeToUse = self.prepend_scheme or "http" + prependTest = self(schemeToUse + "://" + value) # if the prepend test succeeded if prependTest[1] is None: # if prepending in the output is enabled @@ -536,7 +508,7 @@ def __call__(self, value): # else return the original, non-prepended # value return value, None - except: + except Exception: pass # else the HTTP URL is not valid return value, translate(self.message) @@ -546,39 +518,33 @@ class isUrl(Validator): #: use `_isGenericUrl` and `_isHTTPUrl` depending on `mode` parameter message = "Invalid URL" - def __init__( - self, mode='http', schemes=None, prepend_scheme='http', tlds=None, message=None - ): + def __init__(self, mode="http", schemes=None, prepend_scheme="http", tlds=None, message=None): super().__init__(message=message) self.mode = mode.lower() - if self.mode not in ['generic', 'http']: + if self.mode not in ["generic", "http"]: raise SyntaxError("invalid mode '{}' in isUrl".format(self.mode)) self.allowed_tlds = tlds self.allowed_schemes = schemes if self.allowed_schemes: if prepend_scheme not in self.allowed_schemes: raise SyntaxError( - "prepend_scheme='{}' is not in allowed_schemes={}".format( - prepend_scheme, self.allowed_schemes - ) + "prepend_scheme='{}' is not in allowed_schemes={}".format(prepend_scheme, self.allowed_schemes) ) # if allowed_schemes is None, then we will defer testing # prepend_scheme's validity to a sub-method self.prepend_scheme = prepend_scheme def __call__(self, value): - if self.mode == 'generic': + if self.mode == "generic": subValidator = _isGenericUrl( - schemes=self.allowed_schemes, - prepend_scheme=self.prepend_scheme, - message=self.message + schemes=self.allowed_schemes, prepend_scheme=self.prepend_scheme, message=self.message ) - elif self.mode == 'http': + elif self.mode == "http": subValidator = _isHTTPUrl( schemes=self.allowed_schemes, prepend_scheme=self.prepend_scheme, tlds=self.allowed_tlds, - message=self.message + message=self.message, ) else: raise SyntaxError("invalid mode '{}' in isUrl".format(self.mode)) @@ -621,38 +587,27 @@ class isIPv4(Validator): """ message = "Invalid IPv4 address" - regex = re.compile( - r"^(([1-9]?\d|1\d\d|2[0-4]\d|25[0-5])\.){3}([1-9]?\d|1\d\d|2[0-4]\d|25[0-5])$" - ) + regex = re.compile(r"^(([1-9]?\d|1\d\d|2[0-4]\d|25[0-5])\.){3}([1-9]?\d|1\d\d|2[0-4]\d|25[0-5])$") numbers = (16777216, 65536, 256, 1) localhost = 2130706433 private = ((2886729728, 2886795263), (3232235520, 3232301055)) automatic = (2851995648, 2852061183) def __init__( - self, - min='0.0.0.0', - max='255.255.255.255', - invert=False, - localhost=None, - private=None, - auto=None, - message=None + self, min="0.0.0.0", max="255.255.255.255", invert=False, localhost=None, private=None, auto=None, message=None ): super().__init__(message=message) for n, value in enumerate((min, max)): temp = [] if isinstance(value, str): - temp.append(value.split('.')) + temp.append(value.split(".")) elif isinstance(value, (list, tuple)): - if len(value) == len( - list(filter(lambda item: isinstance(item, int), value)) - ) == 4: + if len(value) == len(list(filter(lambda item: isinstance(item, int), value))) == 4: temp.append(value) else: for item in value: if isinstance(item, str): - temp.append(item.split('.')) + temp.append(item.split(".")) elif isinstance(item, (list, tuple)): temp.append(item) numbers = [] @@ -673,31 +628,21 @@ def __init__( def __call__(self, value): if self.regex.match(value): number = 0 - for i, j in zip(self.numbers, value.split('.')): + for i, j in zip(self.numbers, value.split(".")): number += i * int(j) ok = False for bottom, top in zip(self.minip, self.maxip): if self.invert != (bottom <= number <= top): ok = True - if not ( - self.is_localhost is None or self.is_localhost == ( - number == self.localhost - ) - ): + if not (self.is_localhost is None or self.is_localhost == (number == self.localhost)): ok = False if not ( - self.is_private is None or self.is_private == ( - sum([ - el[0] <= number <= el[1] - for el in self.private - ]) > 0 - ) + self.is_private is None + or self.is_private == (sum([el[0] <= number <= el[1] for el in self.private]) > 0) ): ok = False if not ( - self.is_automatic is None or self.is_automatic == ( - self.automatic[0] <= number <= self.automatic[1] - ) + self.is_automatic is None or self.is_automatic == (self.automatic[0] <= number <= self.automatic[1]) ): ok = False if ok: @@ -736,7 +681,7 @@ def __init__( to4=None, teredo=None, subnets=None, - message=None + message=None, ): super().__init__(message=message) self.is_private = private @@ -749,9 +694,7 @@ def __init__( self.subnets = subnets if ipaddress is None: - raise RuntimeError( - "You need 'ipaddress' python module to use isIPv6 validator." - ) + raise RuntimeError("You need 'ipaddress' python module to use isIPv6 validator.") def __call__(self, value): try: @@ -768,9 +711,8 @@ def __call__(self, value): for network in self.subnets: try: ipnet = ipaddress.IPv6Network(network) - except (ipaddress.NetmaskValueError, - ipaddress.AddressValueError): - return value, translate('invalid subnet provided') + except (ipaddress.NetmaskValueError, ipaddress.AddressValueError): + return value, translate("invalid subnet provided") if ip in ipnet: ok = True @@ -780,22 +722,13 @@ def __call__(self, value): self.is_reserved = False self.is_multicast = False - if not ( - self.is_private is None or self.is_private == ip.is_private - ): + if not (self.is_private is None or self.is_private == ip.is_private): ok = False - if not ( - self.is_link_local is None or - self.is_link_local == ip.is_link_local - ): + if not (self.is_link_local is None or self.is_link_local == ip.is_link_local): ok = False - if not ( - self.is_reserved is None or self.is_reserved == ip.is_reserved - ): + if not (self.is_reserved is None or self.is_reserved == ip.is_reserved): ok = False - if not ( - self.is_multicast is None or self.is_multicast == ip.is_multicast - ): + if not (self.is_multicast is None or self.is_multicast == ip.is_multicast): ok = False if not (self.is_6to4 is None or self.is_6to4 == ip.is_6to4): ok = False @@ -821,8 +754,8 @@ class isIP(Validator): def __init__( self, - min='0.0.0.0', - max='255.255.255.255', + min="0.0.0.0", + max="255.255.255.255", invert=False, localhost=None, private=None, @@ -836,11 +769,11 @@ def __init__( teredo=None, subnets=None, ipv6=None, - message=None + message=None, ): super().__init__(message=message) - self.minip = min, - self.maxip = max, + self.minip = (min,) + self.maxip = (max,) self.invert = invert self.is_localhost = localhost self.is_private = private @@ -857,9 +790,7 @@ def __init__( self.is_ipv6 = ipv6 if ipaddress is None: - raise RuntimeError( - "You need 'ipaddress' python module to use isIP validator." - ) + raise RuntimeError("You need 'ipaddress' python module to use isIP validator.") def __call__(self, value): try: @@ -879,7 +810,7 @@ def __call__(self, value): localhost=self.is_localhost, private=self.is_private, auto=self.is_automatic, - message=self.message + message=self.message, )(value) elif self.is_ipv6 or isinstance(ip, ipaddress.IPv6Address): rv = isIPv6( @@ -891,7 +822,7 @@ def __call__(self, value): to4=self.is_6to4, teredo=self.is_teredo, subnets=self.subnets, - message=self.message + message=self.message, )(value) else: rv = (value, translate(self.message)) diff --git a/emmett/validators/helpers.py b/emmett/validators/helpers.py index 8e744077..8f7c8314 100644 --- a/emmett/validators/helpers.py +++ b/emmett/validators/helpers.py @@ -1,25 +1,25 @@ # -*- coding: utf-8 -*- """ - emmett.validators.helpers - ------------------------- +emmett.validators.helpers +------------------------- - Provides utilities for validators. +Provides utilities for validators. - Ported from the original validators of web2py (http://www.web2py.com) +Ported from the original validators of web2py (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:copyright: (c) by Massimo Di Pierro +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ import re - -from datetime import tzinfo, timedelta +from datetime import timedelta, tzinfo +from functools import reduce from io import StringIO # TODO: check unicode conversions from .._shortcuts import to_unicode, uuid from ..ctx import current -from ..security import simple_hash, DIGEST_ALG_BY_SIZE +from ..security import DIGEST_ALG_BY_SIZE, simple_hash _DEFAULT = lambda: None @@ -43,7 +43,7 @@ def is_empty(value, empty_regex=None): if isinstance(value, (str, bytes)): vclean = value.strip() if empty_regex is not None and empty_regex.match(vclean): - vclean = '' + vclean = "" return vclean, len(vclean) == 0 if isinstance(value, (list, dict)): return value, len(value) == 0 @@ -95,21 +95,21 @@ def __str__(self): if self.crypted: return self.crypted if self.crypt.key: - if ':' in self.crypt.key: - digest_alg, key = self.crypt.key.split(':', 1) + if ":" in self.crypt.key: + digest_alg, key = self.crypt.key.split(":", 1) else: digest_alg, key = self.crypt.digest_alg, self.crypt.key else: - digest_alg, key = self.crypt.digest_alg, '' + digest_alg, key = self.crypt.digest_alg, "" if self.crypt.salt: - if self.crypt.salt == True: - salt = str(uuid()).replace('-', '')[-16:] + if self.crypt.salt is True: + salt = str(uuid()).replace("-", "")[-16:] else: salt = self.crypt.salt else: - salt = '' + salt = "" hashed = simple_hash(self.password, key, salt, digest_alg) - self.crypted = '%s$%s$%s' % (digest_alg, salt, hashed) + self.crypted = "%s$%s$%s" % (digest_alg, salt, hashed) return self.crypted def __eq__(self, stored_password): @@ -117,30 +117,30 @@ def __eq__(self, stored_password): compares the current lazy crypted password with a stored password """ if isinstance(stored_password, self.__class__): - return ((self is stored_password) or - ((self.crypt.key == stored_password.crypt.key) and - (self.password == stored_password.password))) + return (self is stored_password) or ( + (self.crypt.key == stored_password.crypt.key) and (self.password == stored_password.password) + ) if self.crypt.key: - if ':' in self.crypt.key: - key = self.crypt.key.split(':')[1] + if ":" in self.crypt.key: + key = self.crypt.key.split(":")[1] else: key = self.crypt.key else: - key = '' + key = "" if stored_password is None: return False - elif stored_password.count('$') == 2: - (digest_alg, salt, hash) = stored_password.split('$') + elif stored_password.count("$") == 2: + (digest_alg, salt, hash) = stored_password.split("$") h = simple_hash(self.password, key, salt, digest_alg) - temp_pass = '%s$%s$%s' % (digest_alg, salt, h) + temp_pass = "%s$%s$%s" % (digest_alg, salt, h) else: # no salting # guess digest_alg digest_alg = DIGEST_ALG_BY_SIZE.get(len(stored_password), None) if not digest_alg: return False else: - temp_pass = simple_hash(self.password, key, '', digest_alg) + temp_pass = simple_hash(self.password, key, "", digest_alg) return temp_pass == stored_password def __ne__(self, other): @@ -148,7 +148,7 @@ def __ne__(self, other): def _escape_unicode(string): - ''' + """ Converts a unicode string into US-ASCII, using a simple conversion scheme. Each unicode character that does not have a US-ASCII equivalent is converted into a URL escaped form based on its hexadecimal value. @@ -162,14 +162,14 @@ def _escape_unicode(string): string: the US-ASCII escaped form of the inputted string @author: Jonathan Benn - ''' + """ returnValue = StringIO() for character in string: code = ord(character) if code > 0x7F: hexCode = hex(code) - returnValue.write('%' + hexCode[2:4] + '%' + hexCode[4:6]) + returnValue.write("%" + hexCode[2:4] + "%" + hexCode[4:6]) else: returnValue.write(character) @@ -177,7 +177,7 @@ def _escape_unicode(string): def _unicode_to_ascii_authority(authority): - ''' + """ Follows the steps in RFC 3490, Section 4 to convert a unicode authority string into its ASCII equivalent. For example, u'www.Alliancefran\xe7aise.nu' will be converted into @@ -196,40 +196,41 @@ def _unicode_to_ascii_authority(authority): authority @author: Jonathan Benn - ''' - label_split_regex = re.compile(u'[\u002e\u3002\uff0e\uff61]') + """ + label_split_regex = re.compile("[\u002e\u3002\uff0e\uff61]") - #RFC 3490, Section 4, Step 1 - #The encodings.idna Python module assumes that AllowUnassigned == True + # RFC 3490, Section 4, Step 1 + # The encodings.idna Python module assumes that AllowUnassigned == True - #RFC 3490, Section 4, Step 2 + # RFC 3490, Section 4, Step 2 labels = label_split_regex.split(authority) - #RFC 3490, Section 4, Step 3 - #The encodings.idna Python module assumes that UseSTD3ASCIIRules == False + # RFC 3490, Section 4, Step 3 + # The encodings.idna Python module assumes that UseSTD3ASCIIRules == False - #RFC 3490, Section 4, Step 4 - #We use the ToASCII operation because we are about to put the authority - #into an IDN-unaware slot + # RFC 3490, Section 4, Step 4 + # We use the ToASCII operation because we are about to put the authority + # into an IDN-unaware slot asciiLabels = [] try: import encodings.idna + for label in labels: if label: asciiLabels.append(encodings.idna.ToASCII(label)) else: - #encodings.idna.ToASCII does not accept an empty string, but - #it is necessary for us to allow for empty labels so that we - #don't modify the URL - asciiLabels.append('') - except: + # encodings.idna.ToASCII does not accept an empty string, but + # it is necessary for us to allow for empty labels so that we + # don't modify the URL + asciiLabels.append("") + except Exception: asciiLabels = [str(label) for label in labels] - #RFC 3490, Section 4, Step 5 - return str(reduce(lambda x, y: x + unichr(0x002E) + y, asciiLabels)) + # RFC 3490, Section 4, Step 5 + return str(reduce(lambda x, y: x + chr(0x002E) + y, asciiLabels)) def unicode_to_ascii_url(url, prepend_scheme): - ''' + """ Converts the inputed unicode url into a US-ASCII equivalent. This function goes a little beyond RFC 3490, which is limited in scope to the domain name (authority) only. Here, the functionality is expanded to what was observed @@ -259,223 +260,1016 @@ def unicode_to_ascii_url(url, prepend_scheme): string: a US-ASCII equivalent of the inputed url @author: Jonathan Benn - ''' - #convert the authority component of the URL into an ASCII punycode string, - #but encode the rest using the regular URI character encoding + """ + # convert the authority component of the URL into an ASCII punycode string, + # but encode the rest using the regular URI character encoding groups = url_split_regex.match(url).groups() - #If no authority was found + # If no authority was found if not groups[3]: - #Try appending a scheme to see if that fixes the problem - scheme_to_prepend = prepend_scheme or 'http' - groups = url_split_regex.match( - to_unicode(scheme_to_prepend) + u'://' + url).groups() - #if we still can't find the authority + # Try appending a scheme to see if that fixes the problem + scheme_to_prepend = prepend_scheme or "http" + groups = url_split_regex.match(to_unicode(scheme_to_prepend) + "://" + url).groups() + # if we still can't find the authority if not groups[3]: - raise Exception('No authority component found, ' + - 'could not decode unicode to US-ASCII') + raise Exception("No authority component found, " + "could not decode unicode to US-ASCII") - #We're here if we found an authority, let's rebuild the URL + # We're here if we found an authority, let's rebuild the URL scheme = groups[1] authority = groups[3] - path = groups[4] or '' - query = groups[5] or '' - fragment = groups[7] or '' + path = groups[4] or "" + query = groups[5] or "" + fragment = groups[7] or "" if prepend_scheme: - scheme = str(scheme) + '://' + scheme = str(scheme) + "://" else: - scheme = '' - return scheme + _unicode_to_ascii_authority(authority) +\ - _escape_unicode(path) + _escape_unicode(query) + str(fragment) + scheme = "" + return ( + scheme + _unicode_to_ascii_authority(authority) + _escape_unicode(path) + _escape_unicode(query) + str(fragment) + ) -url_split_regex = \ - re.compile(r'^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?') +url_split_regex = re.compile(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?") official_url_schemes = [ - 'aaa', 'aaas', 'acap', 'cap', 'cid', 'crid', 'data', 'dav', 'dict', - 'dns', 'fax', 'file', 'ftp', 'go', 'gopher', 'h323', 'http', 'https', - 'icap', 'im', 'imap', 'info', 'ipp', 'iris', 'iris.beep', 'iris.xpc', - 'iris.xpcs', 'iris.lws', 'ldap', 'mailto', 'mid', 'modem', 'msrp', - 'msrps', 'mtqp', 'mupdate', 'news', 'nfs', 'nntp', 'opaquelocktoken', - 'pop', 'pres', 'prospero', 'rtsp', 'service', 'shttp', 'sip', 'sips', - 'snmp', 'soap.beep', 'soap.beeps', 'tag', 'tel', 'telnet', 'tftp', - 'thismessage', 'tip', 'tv', 'urn', 'vemmi', 'wais', 'xmlrpc.beep', - 'xmlrpc.beep', 'xmpp', 'z39.50r', 'z39.50s' + "aaa", + "aaas", + "acap", + "cap", + "cid", + "crid", + "data", + "dav", + "dict", + "dns", + "fax", + "file", + "ftp", + "go", + "gopher", + "h323", + "http", + "https", + "icap", + "im", + "imap", + "info", + "ipp", + "iris", + "iris.beep", + "iris.xpc", + "iris.xpcs", + "iris.lws", + "ldap", + "mailto", + "mid", + "modem", + "msrp", + "msrps", + "mtqp", + "mupdate", + "news", + "nfs", + "nntp", + "opaquelocktoken", + "pop", + "pres", + "prospero", + "rtsp", + "service", + "shttp", + "sip", + "sips", + "snmp", + "soap.beep", + "soap.beeps", + "tag", + "tel", + "telnet", + "tftp", + "thismessage", + "tip", + "tv", + "urn", + "vemmi", + "wais", + "xmlrpc.beep", + "xmlrpc.beep", + "xmpp", + "z39.50r", + "z39.50s", ] unofficial_url_schemes = [ - 'about', 'adiumxtra', 'aim', 'afp', 'aw', 'callto', 'chrome', 'cvs', - 'ed2k', 'feed', 'fish', 'gg', 'gizmoproject', 'iax2', 'irc', 'ircs', - 'itms', 'jar', 'javascript', 'keyparc', 'lastfm', 'ldaps', 'magnet', - 'mms', 'msnim', 'mvn', 'notes', 'nsfw', 'psyc', 'paparazzi:http', - 'rmi', 'rsync', 'secondlife', 'sgn', 'skype', 'ssh', 'sftp', 'smb', - 'sms', 'soldat', 'steam', 'svn', 'teamspeak', 'unreal', 'ut2004', - 'ventrilo', 'view-source', 'webcal', 'wyciwyg', 'xfire', 'xri', 'ymsgr' + "about", + "adiumxtra", + "aim", + "afp", + "aw", + "callto", + "chrome", + "cvs", + "ed2k", + "feed", + "fish", + "gg", + "gizmoproject", + "iax2", + "irc", + "ircs", + "itms", + "jar", + "javascript", + "keyparc", + "lastfm", + "ldaps", + "magnet", + "mms", + "msnim", + "mvn", + "notes", + "nsfw", + "psyc", + "paparazzi:http", + "rmi", + "rsync", + "secondlife", + "sgn", + "skype", + "ssh", + "sftp", + "smb", + "sms", + "soldat", + "steam", + "svn", + "teamspeak", + "unreal", + "ut2004", + "ventrilo", + "view-source", + "webcal", + "wyciwyg", + "xfire", + "xri", + "ymsgr", ] official_top_level_domains = [ # a - 'abogado', 'ac', 'academy', 'accountants', 'active', 'actor', - 'ad', 'adult', 'ae', 'aero', 'af', 'ag', 'agency', 'ai', - 'airforce', 'al', 'allfinanz', 'alsace', 'am', 'amsterdam', 'an', - 'android', 'ao', 'apartments', 'aq', 'aquarelle', 'ar', 'archi', - 'army', 'arpa', 'as', 'asia', 'associates', 'at', 'attorney', - 'au', 'auction', 'audio', 'autos', 'aw', 'ax', 'axa', 'az', + "abogado", + "ac", + "academy", + "accountants", + "active", + "actor", + "ad", + "adult", + "ae", + "aero", + "af", + "ag", + "agency", + "ai", + "airforce", + "al", + "allfinanz", + "alsace", + "am", + "amsterdam", + "an", + "android", + "ao", + "apartments", + "aq", + "aquarelle", + "ar", + "archi", + "army", + "arpa", + "as", + "asia", + "associates", + "at", + "attorney", + "au", + "auction", + "audio", + "autos", + "aw", + "ax", + "axa", + "az", # b - 'ba', 'band', 'bank', 'bar', 'barclaycard', 'barclays', - 'bargains', 'bayern', 'bb', 'bd', 'be', 'beer', 'berlin', 'best', - 'bf', 'bg', 'bh', 'bi', 'bid', 'bike', 'bingo', 'bio', 'biz', - 'bj', 'black', 'blackfriday', 'bloomberg', 'blue', 'bm', 'bmw', - 'bn', 'bnpparibas', 'bo', 'boo', 'boutique', 'br', 'brussels', - 'bs', 'bt', 'budapest', 'build', 'builders', 'business', 'buzz', - 'bv', 'bw', 'by', 'bz', 'bzh', + "ba", + "band", + "bank", + "bar", + "barclaycard", + "barclays", + "bargains", + "bayern", + "bb", + "bd", + "be", + "beer", + "berlin", + "best", + "bf", + "bg", + "bh", + "bi", + "bid", + "bike", + "bingo", + "bio", + "biz", + "bj", + "black", + "blackfriday", + "bloomberg", + "blue", + "bm", + "bmw", + "bn", + "bnpparibas", + "bo", + "boo", + "boutique", + "br", + "brussels", + "bs", + "bt", + "budapest", + "build", + "builders", + "business", + "buzz", + "bv", + "bw", + "by", + "bz", + "bzh", # c - 'ca', 'cab', 'cal', 'camera', 'camp', 'cancerresearch', 'canon', - 'capetown', 'capital', 'caravan', 'cards', 'care', 'career', - 'careers', 'cartier', 'casa', 'cash', 'casino', 'cat', - 'catering', 'cbn', 'cc', 'cd', 'center', 'ceo', 'cern', 'cf', - 'cg', 'ch', 'channel', 'chat', 'cheap', 'christmas', 'chrome', - 'church', 'ci', 'citic', 'city', 'ck', 'cl', 'claims', - 'cleaning', 'click', 'clinic', 'clothing', 'club', 'cm', 'cn', - 'co', 'coach', 'codes', 'coffee', 'college', 'cologne', 'com', - 'community', 'company', 'computer', 'condos', 'construction', - 'consulting', 'contractors', 'cooking', 'cool', 'coop', - 'country', 'cr', 'credit', 'creditcard', 'cricket', 'crs', - 'cruises', 'cu', 'cuisinella', 'cv', 'cw', 'cx', 'cy', 'cymru', - 'cz', + "ca", + "cab", + "cal", + "camera", + "camp", + "cancerresearch", + "canon", + "capetown", + "capital", + "caravan", + "cards", + "care", + "career", + "careers", + "cartier", + "casa", + "cash", + "casino", + "cat", + "catering", + "cbn", + "cc", + "cd", + "center", + "ceo", + "cern", + "cf", + "cg", + "ch", + "channel", + "chat", + "cheap", + "christmas", + "chrome", + "church", + "ci", + "citic", + "city", + "ck", + "cl", + "claims", + "cleaning", + "click", + "clinic", + "clothing", + "club", + "cm", + "cn", + "co", + "coach", + "codes", + "coffee", + "college", + "cologne", + "com", + "community", + "company", + "computer", + "condos", + "construction", + "consulting", + "contractors", + "cooking", + "cool", + "coop", + "country", + "cr", + "credit", + "creditcard", + "cricket", + "crs", + "cruises", + "cu", + "cuisinella", + "cv", + "cw", + "cx", + "cy", + "cymru", + "cz", # d - 'dabur', 'dad', 'dance', 'dating', 'day', 'dclk', 'de', 'deals', - 'degree', 'delivery', 'democrat', 'dental', 'dentist', 'desi', - 'design', 'dev', 'diamonds', 'diet', 'digital', 'direct', - 'directory', 'discount', 'dj', 'dk', 'dm', 'dnp', 'do', 'docs', - 'domains', 'doosan', 'durban', 'dvag', 'dz', + "dabur", + "dad", + "dance", + "dating", + "day", + "dclk", + "de", + "deals", + "degree", + "delivery", + "democrat", + "dental", + "dentist", + "desi", + "design", + "dev", + "diamonds", + "diet", + "digital", + "direct", + "directory", + "discount", + "dj", + "dk", + "dm", + "dnp", + "do", + "docs", + "domains", + "doosan", + "durban", + "dvag", + "dz", # e - 'eat', 'ec', 'edu', 'education', 'ee', 'eg', 'email', 'emerck', - 'energy', 'engineer', 'engineering', 'enterprises', 'equipment', - 'er', 'es', 'esq', 'estate', 'et', 'eu', 'eurovision', 'eus', - 'events', 'everbank', 'exchange', 'expert', 'exposed', + "eat", + "ec", + "edu", + "education", + "ee", + "eg", + "email", + "emerck", + "energy", + "engineer", + "engineering", + "enterprises", + "equipment", + "er", + "es", + "esq", + "estate", + "et", + "eu", + "eurovision", + "eus", + "events", + "everbank", + "exchange", + "expert", + "exposed", # f - 'fail', 'fans', 'farm', 'fashion', 'feedback', 'fi', 'finance', - 'financial', 'firmdale', 'fish', 'fishing', 'fit', 'fitness', - 'fj', 'fk', 'flights', 'florist', 'flowers', 'flsmidth', 'fly', - 'fm', 'fo', 'foo', 'football', 'forsale', 'foundation', 'fr', - 'frl', 'frogans', 'fund', 'furniture', 'futbol', + "fail", + "fans", + "farm", + "fashion", + "feedback", + "fi", + "finance", + "financial", + "firmdale", + "fish", + "fishing", + "fit", + "fitness", + "fj", + "fk", + "flights", + "florist", + "flowers", + "flsmidth", + "fly", + "fm", + "fo", + "foo", + "football", + "forsale", + "foundation", + "fr", + "frl", + "frogans", + "fund", + "furniture", + "futbol", # g - 'ga', 'gal', 'gallery', 'garden', 'gb', 'gbiz', 'gd', 'gdn', - 'ge', 'gent', 'gf', 'gg', 'ggee', 'gh', 'gi', 'gift', 'gifts', - 'gives', 'gl', 'glass', 'gle', 'global', 'globo', 'gm', 'gmail', - 'gmo', 'gmx', 'gn', 'goldpoint', 'goog', 'google', 'gop', 'gov', - 'gp', 'gq', 'gr', 'graphics', 'gratis', 'green', 'gripe', 'gs', - 'gt', 'gu', 'guide', 'guitars', 'guru', 'gw', 'gy', + "ga", + "gal", + "gallery", + "garden", + "gb", + "gbiz", + "gd", + "gdn", + "ge", + "gent", + "gf", + "gg", + "ggee", + "gh", + "gi", + "gift", + "gifts", + "gives", + "gl", + "glass", + "gle", + "global", + "globo", + "gm", + "gmail", + "gmo", + "gmx", + "gn", + "goldpoint", + "goog", + "google", + "gop", + "gov", + "gp", + "gq", + "gr", + "graphics", + "gratis", + "green", + "gripe", + "gs", + "gt", + "gu", + "guide", + "guitars", + "guru", + "gw", + "gy", # h - 'hamburg', 'hangout', 'haus', 'healthcare', 'help', 'here', - 'hermes', 'hiphop', 'hiv', 'hk', 'hm', 'hn', 'holdings', - 'holiday', 'homes', 'horse', 'host', 'hosting', 'house', 'how', - 'hr', 'ht', 'hu', + "hamburg", + "hangout", + "haus", + "healthcare", + "help", + "here", + "hermes", + "hiphop", + "hiv", + "hk", + "hm", + "hn", + "holdings", + "holiday", + "homes", + "horse", + "host", + "hosting", + "house", + "how", + "hr", + "ht", + "hu", # i - 'ibm', 'id', 'ie', 'ifm', 'il', 'im', 'immo', 'immobilien', 'in', - 'industries', 'info', 'ing', 'ink', 'institute', 'insure', 'int', - 'international', 'investments', 'io', 'iq', 'ir', 'irish', 'is', - 'it', 'iwc', + "ibm", + "id", + "ie", + "ifm", + "il", + "im", + "immo", + "immobilien", + "in", + "industries", + "info", + "ing", + "ink", + "institute", + "insure", + "int", + "international", + "investments", + "io", + "iq", + "ir", + "irish", + "is", + "it", + "iwc", # j - 'jcb', 'je', 'jetzt', 'jm', 'jo', 'jobs', 'joburg', 'jp', - 'juegos', + "jcb", + "je", + "jetzt", + "jm", + "jo", + "jobs", + "joburg", + "jp", + "juegos", # k - 'kaufen', 'kddi', 'ke', 'kg', 'kh', 'ki', 'kim', 'kitchen', - 'kiwi', 'km', 'kn', 'koeln', 'kp', 'kr', 'krd', 'kred', 'kw', - 'ky', 'kyoto', 'kz', + "kaufen", + "kddi", + "ke", + "kg", + "kh", + "ki", + "kim", + "kitchen", + "kiwi", + "km", + "kn", + "koeln", + "kp", + "kr", + "krd", + "kred", + "kw", + "ky", + "kyoto", + "kz", # l - 'la', 'lacaixa', 'land', 'lat', 'latrobe', 'lawyer', 'lb', 'lc', - 'lds', 'lease', 'legal', 'lgbt', 'li', 'lidl', 'life', - 'lighting', 'limited', 'limo', 'link', 'lk', 'loans', - 'localhost', 'london', 'lotte', 'lotto', 'lr', 'ls', 'lt', - 'ltda', 'lu', 'luxe', 'luxury', 'lv', 'ly', + "la", + "lacaixa", + "land", + "lat", + "latrobe", + "lawyer", + "lb", + "lc", + "lds", + "lease", + "legal", + "lgbt", + "li", + "lidl", + "life", + "lighting", + "limited", + "limo", + "link", + "lk", + "loans", + "localhost", + "london", + "lotte", + "lotto", + "lr", + "ls", + "lt", + "ltda", + "lu", + "luxe", + "luxury", + "lv", + "ly", # m - 'ma', 'madrid', 'maison', 'management', 'mango', 'market', - 'marketing', 'marriott', 'mc', 'md', 'me', 'media', 'meet', - 'melbourne', 'meme', 'memorial', 'menu', 'mg', 'mh', 'miami', - 'mil', 'mini', 'mk', 'ml', 'mm', 'mn', 'mo', 'mobi', 'moda', - 'moe', 'monash', 'money', 'mormon', 'mortgage', 'moscow', - 'motorcycles', 'mov', 'mp', 'mq', 'mr', 'ms', 'mt', 'mu', - 'museum', 'mv', 'mw', 'mx', 'my', 'mz', + "ma", + "madrid", + "maison", + "management", + "mango", + "market", + "marketing", + "marriott", + "mc", + "md", + "me", + "media", + "meet", + "melbourne", + "meme", + "memorial", + "menu", + "mg", + "mh", + "miami", + "mil", + "mini", + "mk", + "ml", + "mm", + "mn", + "mo", + "mobi", + "moda", + "moe", + "monash", + "money", + "mormon", + "mortgage", + "moscow", + "motorcycles", + "mov", + "mp", + "mq", + "mr", + "ms", + "mt", + "mu", + "museum", + "mv", + "mw", + "mx", + "my", + "mz", # n - 'na', 'nagoya', 'name', 'navy', 'nc', 'ne', 'net', 'network', - 'neustar', 'new', 'nexus', 'nf', 'ng', 'ngo', 'nhk', 'ni', - 'nico', 'ninja', 'nl', 'no', 'np', 'nr', 'nra', 'nrw', 'ntt', - 'nu', 'nyc', 'nz', + "na", + "nagoya", + "name", + "navy", + "nc", + "ne", + "net", + "network", + "neustar", + "new", + "nexus", + "nf", + "ng", + "ngo", + "nhk", + "ni", + "nico", + "ninja", + "nl", + "no", + "np", + "nr", + "nra", + "nrw", + "ntt", + "nu", + "nyc", + "nz", # o - 'okinawa', 'om', 'one', 'ong', 'onl', 'ooo', 'org', 'organic', - 'osaka', 'otsuka', 'ovh', + "okinawa", + "om", + "one", + "ong", + "onl", + "ooo", + "org", + "organic", + "osaka", + "otsuka", + "ovh", # p - 'pa', 'paris', 'partners', 'parts', 'party', 'pe', 'pf', 'pg', - 'ph', 'pharmacy', 'photo', 'photography', 'photos', 'physio', - 'pics', 'pictures', 'pink', 'pizza', 'pk', 'pl', 'place', - 'plumbing', 'pm', 'pn', 'pohl', 'poker', 'porn', 'post', 'pr', - 'praxi', 'press', 'pro', 'prod', 'productions', 'prof', - 'properties', 'property', 'ps', 'pt', 'pub', 'pw', 'py', + "pa", + "paris", + "partners", + "parts", + "party", + "pe", + "pf", + "pg", + "ph", + "pharmacy", + "photo", + "photography", + "photos", + "physio", + "pics", + "pictures", + "pink", + "pizza", + "pk", + "pl", + "place", + "plumbing", + "pm", + "pn", + "pohl", + "poker", + "porn", + "post", + "pr", + "praxi", + "press", + "pro", + "prod", + "productions", + "prof", + "properties", + "property", + "ps", + "pt", + "pub", + "pw", + "py", # q - 'qa', 'qpon', 'quebec', + "qa", + "qpon", + "quebec", # r - 're', 'realtor', 'recipes', 'red', 'rehab', 'reise', 'reisen', - 'reit', 'ren', 'rentals', 'repair', 'report', 'republican', - 'rest', 'restaurant', 'reviews', 'rich', 'rio', 'rip', 'ro', - 'rocks', 'rodeo', 'rs', 'rsvp', 'ru', 'ruhr', 'rw', 'ryukyu', + "re", + "realtor", + "recipes", + "red", + "rehab", + "reise", + "reisen", + "reit", + "ren", + "rentals", + "repair", + "report", + "republican", + "rest", + "restaurant", + "reviews", + "rich", + "rio", + "rip", + "ro", + "rocks", + "rodeo", + "rs", + "rsvp", + "ru", + "ruhr", + "rw", + "ryukyu", # s - 'sa', 'saarland', 'sale', 'samsung', 'sarl', 'saxo', 'sb', 'sc', - 'sca', 'scb', 'schmidt', 'school', 'schule', 'schwarz', - 'science', 'scot', 'sd', 'se', 'services', 'sew', 'sexy', 'sg', - 'sh', 'shiksha', 'shoes', 'shriram', 'si', 'singles', 'sj', 'sk', - 'sky', 'sl', 'sm', 'sn', 'so', 'social', 'software', 'sohu', - 'solar', 'solutions', 'soy', 'space', 'spiegel', 'sr', 'st', - 'style', 'su', 'supplies', 'supply', 'support', 'surf', - 'surgery', 'suzuki', 'sv', 'sx', 'sy', 'sydney', 'systems', 'sz', + "sa", + "saarland", + "sale", + "samsung", + "sarl", + "saxo", + "sb", + "sc", + "sca", + "scb", + "schmidt", + "school", + "schule", + "schwarz", + "science", + "scot", + "sd", + "se", + "services", + "sew", + "sexy", + "sg", + "sh", + "shiksha", + "shoes", + "shriram", + "si", + "singles", + "sj", + "sk", + "sky", + "sl", + "sm", + "sn", + "so", + "social", + "software", + "sohu", + "solar", + "solutions", + "soy", + "space", + "spiegel", + "sr", + "st", + "style", + "su", + "supplies", + "supply", + "support", + "surf", + "surgery", + "suzuki", + "sv", + "sx", + "sy", + "sydney", + "systems", + "sz", # t - 'taipei', 'tatar', 'tattoo', 'tax', 'tc', 'td', 'technology', - 'tel', 'temasek', 'tennis', 'tf', 'tg', 'th', 'tienda', 'tips', - 'tires', 'tirol', 'tj', 'tk', 'tl', 'tm', 'tn', 'to', 'today', - 'tokyo', 'tools', 'top', 'toshiba', 'town', 'toys', 'tp', 'tr', - 'trade', 'training', 'travel', 'trust', 'tt', 'tui', 'tv', 'tw', - 'tz', + "taipei", + "tatar", + "tattoo", + "tax", + "tc", + "td", + "technology", + "tel", + "temasek", + "tennis", + "tf", + "tg", + "th", + "tienda", + "tips", + "tires", + "tirol", + "tj", + "tk", + "tl", + "tm", + "tn", + "to", + "today", + "tokyo", + "tools", + "top", + "toshiba", + "town", + "toys", + "tp", + "tr", + "trade", + "training", + "travel", + "trust", + "tt", + "tui", + "tv", + "tw", + "tz", # u - 'ua', 'ug', 'uk', 'university', 'uno', 'uol', 'us', 'uy', 'uz', + "ua", + "ug", + "uk", + "university", + "uno", + "uol", + "us", + "uy", + "uz", # v - 'va', 'vacations', 'vc', 've', 'vegas', 'ventures', - 'versicherung', 'vet', 'vg', 'vi', 'viajes', 'video', 'villas', - 'vision', 'vlaanderen', 'vn', 'vodka', 'vote', 'voting', 'voto', - 'voyage', 'vu', + "va", + "vacations", + "vc", + "ve", + "vegas", + "ventures", + "versicherung", + "vet", + "vg", + "vi", + "viajes", + "video", + "villas", + "vision", + "vlaanderen", + "vn", + "vodka", + "vote", + "voting", + "voto", + "voyage", + "vu", # w - 'wales', 'wang', 'watch', 'webcam', 'website', 'wed', 'wedding', - 'wf', 'whoswho', 'wien', 'wiki', 'williamhill', 'wme', 'work', - 'works', 'world', 'ws', 'wtc', 'wtf', + "wales", + "wang", + "watch", + "webcam", + "website", + "wed", + "wedding", + "wf", + "whoswho", + "wien", + "wiki", + "williamhill", + "wme", + "work", + "works", + "world", + "ws", + "wtc", + "wtf", # x - 'xn--1qqw23a', 'xn--3bst00m', 'xn--3ds443g', 'xn--3e0b707e', - 'xn--45brj9c', 'xn--45q11c', 'xn--4gbrim', 'xn--55qw42g', - 'xn--55qx5d', 'xn--6frz82g', 'xn--6qq986b3xl', 'xn--80adxhks', - 'xn--80ao21a', 'xn--80asehdb', 'xn--80aswg', 'xn--90a3ac', - 'xn--90ais', 'xn--b4w605ferd', 'xn--c1avg', 'xn--cg4bki', - 'xn--clchc0ea0b2g2a9gcd', 'xn--czr694b', 'xn--czrs0t', - 'xn--czru2d', 'xn--d1acj3b', 'xn--d1alf', 'xn--fiq228c5hs', - 'xn--fiq64b', 'xn--fiqs8s', 'xn--fiqz9s', 'xn--flw351e', - 'xn--fpcrj9c3d', 'xn--fzc2c9e2c', 'xn--gecrj9c', 'xn--h2brj9c', - 'xn--hxt814e', 'xn--i1b6b1a6a2e', 'xn--io0a7i', 'xn--j1amh', - 'xn--j6w193g', 'xn--kprw13d', 'xn--kpry57d', 'xn--kput3i', - 'xn--l1acc', 'xn--lgbbat1ad8j', 'xn--mgb9awbf', - 'xn--mgba3a4f16a', 'xn--mgbaam7a8h', 'xn--mgbab2bd', - 'xn--mgbayh7gpa', 'xn--mgbbh1a71e', 'xn--mgbc0a9azcg', - 'xn--mgberp4a5d4ar', 'xn--mgbx4cd0ab', 'xn--ngbc5azd', - 'xn--node', 'xn--nqv7f', 'xn--nqv7fs00ema', 'xn--o3cw4h', - 'xn--ogbpf8fl', 'xn--p1acf', 'xn--p1ai', 'xn--pgbs0dh', - 'xn--q9jyb4c', 'xn--qcka1pmc', 'xn--rhqv96g', 'xn--s9brj9c', - 'xn--ses554g', 'xn--unup4y', 'xn--vermgensberater-ctb', - 'xn--vermgensberatung-pwb', 'xn--vhquv', 'xn--wgbh1c', - 'xn--wgbl6a', 'xn--xhq521b', 'xn--xkc2al3hye2a', - 'xn--xkc2dl3a5ee0h', 'xn--yfro4i67o', 'xn--ygbi2ammx', - 'xn--zfr164b', 'xxx', 'xyz', + "xn--1qqw23a", + "xn--3bst00m", + "xn--3ds443g", + "xn--3e0b707e", + "xn--45brj9c", + "xn--45q11c", + "xn--4gbrim", + "xn--55qw42g", + "xn--55qx5d", + "xn--6frz82g", + "xn--6qq986b3xl", + "xn--80adxhks", + "xn--80ao21a", + "xn--80asehdb", + "xn--80aswg", + "xn--90a3ac", + "xn--90ais", + "xn--b4w605ferd", + "xn--c1avg", + "xn--cg4bki", + "xn--clchc0ea0b2g2a9gcd", + "xn--czr694b", + "xn--czrs0t", + "xn--czru2d", + "xn--d1acj3b", + "xn--d1alf", + "xn--fiq228c5hs", + "xn--fiq64b", + "xn--fiqs8s", + "xn--fiqz9s", + "xn--flw351e", + "xn--fpcrj9c3d", + "xn--fzc2c9e2c", + "xn--gecrj9c", + "xn--h2brj9c", + "xn--hxt814e", + "xn--i1b6b1a6a2e", + "xn--io0a7i", + "xn--j1amh", + "xn--j6w193g", + "xn--kprw13d", + "xn--kpry57d", + "xn--kput3i", + "xn--l1acc", + "xn--lgbbat1ad8j", + "xn--mgb9awbf", + "xn--mgba3a4f16a", + "xn--mgbaam7a8h", + "xn--mgbab2bd", + "xn--mgbayh7gpa", + "xn--mgbbh1a71e", + "xn--mgbc0a9azcg", + "xn--mgberp4a5d4ar", + "xn--mgbx4cd0ab", + "xn--ngbc5azd", + "xn--node", + "xn--nqv7f", + "xn--nqv7fs00ema", + "xn--o3cw4h", + "xn--ogbpf8fl", + "xn--p1acf", + "xn--p1ai", + "xn--pgbs0dh", + "xn--q9jyb4c", + "xn--qcka1pmc", + "xn--rhqv96g", + "xn--s9brj9c", + "xn--ses554g", + "xn--unup4y", + "xn--vermgensberater-ctb", + "xn--vermgensberatung-pwb", + "xn--vhquv", + "xn--wgbh1c", + "xn--wgbl6a", + "xn--xhq521b", + "xn--xkc2al3hye2a", + "xn--xkc2dl3a5ee0h", + "xn--yfro4i67o", + "xn--ygbi2ammx", + "xn--zfr164b", + "xxx", + "xyz", # y - 'yachts', 'yandex', 'ye', 'yodobashi', 'yoga', 'yokohama', - 'youtube', 'yt', + "yachts", + "yandex", + "ye", + "yodobashi", + "yoga", + "yokohama", + "youtube", + "yt", # z - 'za', 'zip', 'zm', 'zone', 'zuerich', 'zw' + "za", + "zip", + "zm", + "zone", + "zuerich", + "zw", ] diff --git a/emmett/validators/inside.py b/emmett/validators/inside.py index 6d3a36ff..c391909d 100644 --- a/emmett/validators/inside.py +++ b/emmett/validators/inside.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.validators.inside - ------------------------ +emmett.validators.inside +------------------------ - Validators that check presence/absence of given value in a set. +Validators that check presence/absence of given value in a set. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the web2py's validators (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro +Based on the web2py's validators (http://www.web2py.com) +:copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ from emmett_core.utils import cachedprop @@ -42,14 +42,11 @@ def _lt(self, val1, val2, eq=False): def __call__(self, value): minimum = self.minimum() if callable(self.minimum) else self.minimum maximum = self.maximum() if callable(self.maximum) else self.maximum - if ( - (minimum is None or self._gt(value, minimum, self.inc[0])) and - (maximum is None or self._lt(value, maximum, self.inc[1])) + if (minimum is None or self._gt(value, minimum, self.inc[0])) and ( + maximum is None or self._lt(value, maximum, self.inc[1]) ): return value, None - return value, translate( - self._range_error(self.message, minimum, maximum) - ) + return value, translate(self._range_error(self.message, minimum, maximum)) def _range_error(self, message, minimum, maximum): if message is None: @@ -66,22 +63,14 @@ def _range_error(self, message, minimum, maximum): class inSet(Validator): - def __init__( - self, - theset, - labels=None, - multiple=False, - zero=None, - sort=False, - message=None - ): + def __init__(self, theset, labels=None, multiple=False, zero=None, sort=False, message=None): super().__init__(message=message) self.multiple = multiple if ( - theset and - isinstance(theset, (tuple, list)) and - isinstance(theset[0], (tuple, list)) and - len(theset[0]) == 2 + theset + and isinstance(theset, (tuple, list)) + and isinstance(theset[0], (tuple, list)) + and len(theset[0]) == 2 ): lset, llabels = [], [] for item, label in theset: @@ -89,10 +78,7 @@ def __init__( llabels.append(str(label)) self.theset = lset self.labels = llabels - elif ( - theset and - isinstance(theset, dict) - ): + elif theset and isinstance(theset, dict): lset, llabels = [], [] for item, label in theset.items(): lset.append(str(item)) @@ -113,7 +99,7 @@ def options(self, zero=True): if self.sort: items.sort(options_sorter) if zero and self.zero is not None and not self.multiple: - items.insert(0, ('', self.zero)) + items.insert(0, ("", self.zero)) return items def __call__(self, value): @@ -126,25 +112,20 @@ def __call__(self, value): values = [value] else: values = [value] - failures = [ - x for x in values - if (to_unicode(x) or '') not in self.theset] + failures = [x for x in values if (to_unicode(x) or "") not in self.theset] if failures and self.theset: - if self.multiple and (value is None or value == ''): + if self.multiple and (value is None or value == ""): return ([], None) return value, translate(self.message) if self.multiple: - if ( - isinstance(self.multiple, (tuple, list)) and - not self.multiple[0] <= len(values) < self.multiple[1] - ): + if isinstance(self.multiple, (tuple, list)) and not self.multiple[0] <= len(values) < self.multiple[1]: return values, translate(self.message) return values, None return value, None class DBValidator(Validator): - def __init__(self, db, tablename, fieldname='id', dbset=None, message=None): + def __init__(self, db, tablename, fieldname="id", dbset=None, message=None): super().__init__(message=message) self.db = db self.tablename = tablename @@ -171,23 +152,9 @@ def field(self): class inDB(DBValidator): def __init__( - self, - db, - tablename, - fieldname='id', - dbset=None, - label_field=None, - multiple=False, - orderby=None, - message=None + self, db, tablename, fieldname="id", dbset=None, label_field=None, multiple=False, orderby=None, message=None ): - super().__init__( - db, - tablename, - fieldname=fieldname, - dbset=dbset, - message=message - ) + super().__init__(db, tablename, fieldname=fieldname, dbset=dbset, message=message) self.label_field = label_field self.multiple = multiple self.orderby = orderby @@ -209,18 +176,16 @@ def options(self, zero=True): items = [(r.id, self.db[self.tablename]._format % r) for r in records] else: items = [(r.id, r.id) for r in records] - #if self.sort: + # if self.sort: # items.sort(options_sorter) - #if zero and self.zero is not None and not self.multiple: + # if zero and self.zero is not None and not self.multiple: # items.insert(0, ('', self.zero)) return items def __call__(self, value): if self.multiple: values = value if isinstance(value, list) else [value] - records = self.dbset.where( - self.field.belongs(values) - ).select(self.field, distinct=True).column(self.field) + records = self.dbset.where(self.field.belongs(values)).select(self.field, distinct=True).column(self.field) if set(values).issubset(set(records)): return values, None else: @@ -231,11 +196,9 @@ def __call__(self, value): class notInDB(DBValidator): def __call__(self, value): - row = self.dbset.where( - self.field == value - ).select(limitby=(0, 1)).first() + row = self.dbset.where(self.field == value).select(limitby=(0, 1)).first() if row: - record_id = getattr(current, '_dbvalidation_record_id_', None) + record_id = getattr(current, "_dbvalidation_record_id_", None) if row.id != record_id: return value, translate(self.message) return value, None diff --git a/emmett/validators/process.py b/emmett/validators/process.py index 73eb4341..8e3b852b 100644 --- a/emmett/validators/process.py +++ b/emmett/validators/process.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- """ - emmett.validators.process - ------------------------- +emmett.validators.process +------------------------- - Validators that transform values. +Validators that transform values. - Ported from the original validators of web2py (http://www.web2py.com) +Ported from the original validators of web2py (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:copyright: (c) by Massimo Di Pierro +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ import re @@ -17,7 +17,7 @@ # TODO: check unicode conversions from .._shortcuts import to_unicode from .basic import Validator -from .helpers import translate, LazyCrypt +from .helpers import LazyCrypt, translate class Cleanup(Validator): @@ -28,7 +28,7 @@ def __init__(self, regex=None, message=None): self.regex = self.rule if regex is None else re.compile(regex) def __call__(self, value): - v = self.regex.sub('', (to_unicode(value) or '').strip()) + v = self.regex.sub("", (to_unicode(value) or "").strip()) return v, None @@ -91,7 +91,7 @@ def _urlify(self, s): # remove leading and trailing hyphens s = s.strip(r"-") # enforce maximum length - return s[:self.maxlen] + return s[: self.maxlen] class Crypt(Validator): @@ -126,16 +126,14 @@ class Crypt(Validator): an existing salted password """ - def __init__( - self, key=None, algorithm='pbkdf2(1000,20,sha512)', salt=True, message=None - ): + def __init__(self, key=None, algorithm="pbkdf2(1000,20,sha512)", salt=True, message=None): super().__init__(message=message) self.key = key self.digest_alg = algorithm self.salt = salt def __call__(self, value): - if getattr(value, '_emt_field_hashed_contents_', False): + if getattr(value, "_emt_field_hashed_contents_", False): return value, None crypt = LazyCrypt(self, value) if isinstance(value, LazyCrypt) and value == crypt: diff --git a/emmett/wrappers/request.py b/emmett/wrappers/request.py index a8b89796..707aa9a9 100644 --- a/emmett/wrappers/request.py +++ b/emmett/wrappers/request.py @@ -1,16 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.wrappers.request - ----------------------- +emmett.wrappers.request +----------------------- - Provides http request wrappers. +Provides http request wrappers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import pendulum - from emmett_core.http.wrappers.request import Request as _Request from emmett_core.utils import cachedprop diff --git a/emmett/wrappers/response.py b/emmett/wrappers/response.py index 7413c4c0..46fc1514 100644 --- a/emmett/wrappers/response.py +++ b/emmett/wrappers/response.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.wrappers.response - ------------------------ +emmett.wrappers.response +------------------------ - Provides response wrappers. +Provides response wrappers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from typing import Any @@ -34,11 +34,7 @@ def alerts(self, **kwargs): return get_flashed_messages(**kwargs) def _meta_tmpl(self): - return [ - (key, htmlescape(val)) for key, val in self.meta.items() - ] + return [(key, htmlescape(val)) for key, val in self.meta.items()] def _meta_tmpl_prop(self): - return [ - (key, htmlescape(val)) for key, val in self.meta_prop.items() - ] + return [(key, htmlescape(val)) for key, val in self.meta_prop.items()] diff --git a/emmett/wrappers/websocket.py b/emmett/wrappers/websocket.py index 92e4d76c..f1fc10f1 100644 --- a/emmett/wrappers/websocket.py +++ b/emmett/wrappers/websocket.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.wrappers.websocket - ------------------------- +emmett.wrappers.websocket +------------------------- - Provides http websocket wrappers. +Provides http websocket wrappers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from emmett_core.http.wrappers.websocket import Websocket as Websocket diff --git a/pyproject.toml b/pyproject.toml index cd827a40..c5053a58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,12 +75,77 @@ include = [ [tool.hatch.metadata] allow-direct-references = true +[tool.ruff] +line-length = 120 + +[tool.ruff.format] +quote-style = 'double' + +[tool.ruff.lint] +extend-select = [ + # E and F are enabled by default + 'B', # flake8-bugbear + 'C4', # flake8-comprehensions + 'C90', # mccabe + 'I', # isort + 'N', # pep8-naming + 'Q', # flake8-quotes + 'RUF100', # ruff (unused noqa) + 'S', # flake8-bandit + 'W', # pycodestyle +] +extend-ignore = [ + 'B006', # mutable function args are fine + 'B008', # function calls in args defaults are fine + 'B009', # getattr with constants is fine + 'B034', # re.split won't confuse us + 'B904', # rising without from is fine + 'E731', # assigning lambdas is fine + 'F403', # import * is fine + 'N801', # leave to us class naming + 'N802', # leave to us method naming + 'N806', # leave to us var naming + 'N811', # leave to us var naming + 'N814', # leave to us var naming + 'N818', # leave to us exceptions naming + 'S101', # assert is fine + 'S104', # leave to us security + 'S105', # leave to us security + 'S106', # leave to us security + 'S107', # leave to us security + 'S110', # pass on exceptions is fine + 'S301', # leave to us security + 'S324', # leave to us security +] +mccabe = { max-complexity = 44 } + +[tool.ruff.lint.isort] +combine-as-imports = true +lines-after-imports = 2 +known-first-party = ['emmett', 'tests'] + +[tool.ruff.lint.per-file-ignores] +'emmett/__init__.py' = ['F401'] +'emmett/http.py' = ['F401'] +'emmett/orm/__init__.py' = ['F401'] +'emmett/orm/engines/__init__.py' = ['F401'] +'emmett/orm/migrations/__init__.py' = ['F401'] +'emmett/orm/migrations/revisions.py' = ['B018'] +'emmett/tools/__init__.py' = ['F401'] +'emmett/tools/auth/__init__.py' = ['F401'] +'emmett/validators/__init__.py' = ['F401'] +'tests/**' = ['B017', 'B018', 'E711', 'E712', 'E741', 'F841', 'S110', 'S501'] + +[tool.pytest.ini_options] +asyncio_mode = 'auto' + [tool.uv] dev-dependencies = [ "ipaddress>=1.0", "pytest>=7.1", "pytest-asyncio>=0.15", "psycopg2-binary~=2.9; python_version != '3.13'", + "ruff~=0.5.0", ] [tool.uv.sources] diff --git a/tests/helpers.py b/tests/helpers.py index 3e29bc07..cedc5ccf 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,21 +1,23 @@ # -*- coding: utf-8 -*- """ - tests.helpers - ------------- +tests.helpers +------------- - Tests helpers +Tests helpers """ from contextlib import contextmanager from emmett_core.protocols.rsgi.test_client.scope import ScopeBuilder -from emmett.rsgi.wrappers import Request, Websocket + from emmett.ctx import RequestContext, WSContext, current from emmett.datastructures import sdict +from emmett.rsgi.wrappers import Request, Websocket from emmett.serializers import Serializers from emmett.wrappers.response import Response -json_dump = Serializers.get_for('json') + +json_dump = Serializers.get_for("json") class FakeRequestContext(RequestContext): @@ -31,7 +33,7 @@ def __init__(self): self._send_storage = [] async def receive(self): - return json_dump({'foo': 'bar'}) + return json_dump({"foo": "bar"}) async def send_str(self, data): self._send_storage.append(data) @@ -48,9 +50,7 @@ async def init(self): self.transport = FakeWSTransport() async def receive(self): - return sdict( - data=await self.transport.receive() - ) + return sdict(data=await self.transport.receive()) def close(self): pass @@ -60,11 +60,7 @@ class FakeWSContext(WSContext): def __init__(self, app, scope): self.app = app self._proto = FakeWsProto() - self.websocket = Websocket( - scope, - scope.path, - self._proto - ) + self.websocket = Websocket(scope, scope.path, self._proto) self._receive_storage = [] @property @@ -84,7 +80,7 @@ def current_ctx(path, app=None): def ws_ctx(path, app=None): builder = ScopeBuilder(path) scope_data = builder.get_data()[0] - scope_data.proto = 'ws' + scope_data.proto = "ws" token = current._init_(FakeWSContext(app, scope_data)) yield current current._close_(token) diff --git a/tests/test_auth.py b/tests/test_auth.py index b4debf9b..a37e9197 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,63 +1,65 @@ # -*- coding: utf-8 -*- """ - tests.auth - ---------- +tests.auth +---------- - Test Emmett Auth module +Test Emmett Auth module """ import os -import pytest import shutil + +import pytest + from emmett import App -from emmett.orm import Database, Field, Model, has_many, belongs_to +from emmett.orm import Database, Field, Model, belongs_to, has_many from emmett.sessions import SessionManager from emmett.tools import Auth, Mailer from emmett.tools.auth.models import AuthUser class User(AuthUser): - has_many('things') + has_many("things") gender = Field() class Thing(Model): - belongs_to('user') + belongs_to("user") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): rv = App(__name__) - rv.config.mailer.sender = 'nina@massivedynamics.com' + rv.config.mailer.sender = "nina@massivedynamics.com" rv.config.mailer.suppress = True rv.config.auth.single_template = True rv.config.auth.hmac_key = "foobar" - rv.pipeline = [SessionManager.cookies('foobar')] + rv.pipeline = [SessionManager.cookies("foobar")] return rv -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def client(app): return app.test_client() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def mailer(app): return Mailer(app) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(app): try: - shutil.rmtree(os.path.join(app.root_path, 'databases')) - except: + shutil.rmtree(os.path.join(app.root_path, "databases")) + except Exception: pass db = Database(app, auto_migrate=True) app.pipeline.append(db.pipe) return db -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def auth(app, _db, mailer): auth = Auth(app, _db, user_model=User) app.pipeline.append(auth.pipe) @@ -65,7 +67,7 @@ def auth(app, _db, mailer): return auth -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def db(_db, auth): with _db.connection(): _db.define_models(Thing) @@ -75,24 +77,12 @@ def db(_db, auth): def test_models(db): with db.connection(): user = User.create( - email="walter@massivedynamics.com", - password="pocketuniverse", - first_name="Walter", - last_name="Bishop" - ) - group = db.auth_groups.insert( - role="admin" - ) - group2 = db.auth_groups.insert( - role="moderator" - ) - db.auth_memberships.insert( - user=user.id, - auth_group=group - ) - db.auth_permissions.insert( - auth_group=group + email="walter@massivedynamics.com", password="pocketuniverse", first_name="Walter", last_name="Bishop" ) + group = db.auth_groups.insert(role="admin") + group2 = db.auth_groups.insert(role="moderator") + db.auth_memberships.insert(user=user.id, auth_group=group) + db.auth_permissions.insert(auth_group=group) user = db.users[1] assert len(user.auth_memberships()) == 1 assert user.auth_memberships()[0].user == 1 @@ -112,119 +102,128 @@ def test_models(db): def test_registration(mailer, db, client): - page = client.get('/auth/registration').data - assert 'E-mail' in page - assert 'First name' in page - assert 'Last name' in page - assert 'Password' in page - assert 'Confirm password' in page - assert 'Register' in page + page = client.get("/auth/registration").data + assert "E-mail" in page + assert "First name" in page + assert "Last name" in page + assert "Password" in page + assert "Confirm password" in page + assert "Register" in page with mailer.store_mails() as mailbox: - with client.get('/auth/registration').context as ctx: - req = client.post('/auth/registration', data={ - 'email': 'william@massivedynamics.com', - 'first_name': 'William', - 'last_name': 'Bell', - 'password': 'imtheceo', - 'password2': 'imtheceo', - '_csrf_token': list(ctx.session._csrf)[-1] - }, follow_redirects=True) + with client.get("/auth/registration").context as ctx: + req = client.post( + "/auth/registration", + data={ + "email": "william@massivedynamics.com", + "first_name": "William", + "last_name": "Bell", + "password": "imtheceo", + "password2": "imtheceo", + "_csrf_token": list(ctx.session._csrf)[-1], + }, + follow_redirects=True, + ) assert "We sent you an email, check your inbox" in req.data assert len(mailbox) == 1 mail = mailbox[0] assert mail.recipients == ["william@massivedynamics.com"] - assert mail.subject == 'Email verification' + assert mail.subject == "Email verification" mail_as_str = str(mail) - assert 'Hello william@massivedynamics.com!' in mail_as_str - assert 'verify your email' in mail_as_str - verification_code = mail_as_str.split( - "http://localhost/auth/email_verification/")[1].split(" ")[0] - req = client.get( - '/auth/email_verification/{}'.format(verification_code), - follow_redirects=True) + assert "Hello william@massivedynamics.com!" in mail_as_str + assert "verify your email" in mail_as_str + verification_code = mail_as_str.split("http://localhost/auth/email_verification/")[1].split(" ")[0] + req = client.get("/auth/email_verification/{}".format(verification_code), follow_redirects=True) assert "Account verification completed" in req.data def test_login(db, client): - page = client.get('/auth/login').data - assert 'E-mail' in page - assert 'Password' in page - assert 'Sign in' in page - with client.get('/auth/login').context as ctx: - req = client.post('/auth/login', data={ - 'email': 'william@massivedynamics.com', - 'password': 'imtheceo', - '_csrf_token': list(ctx.session._csrf)[-1] - }, follow_redirects=True) - assert 'William' in req.data - assert 'Bell' in req.data - assert 'Save' in req.data + page = client.get("/auth/login").data + assert "E-mail" in page + assert "Password" in page + assert "Sign in" in page + with client.get("/auth/login").context as ctx: + req = client.post( + "/auth/login", + data={ + "email": "william@massivedynamics.com", + "password": "imtheceo", + "_csrf_token": list(ctx.session._csrf)[-1], + }, + follow_redirects=True, + ) + assert "William" in req.data + assert "Bell" in req.data + assert "Save" in req.data def test_password_change(db, client): - with client.get('/auth/login').context as ctx: - with client.post('/auth/login', data={ - 'email': 'william@massivedynamics.com', - 'password': 'imtheceo', - '_csrf_token': list(ctx.session._csrf)[-1] - }, follow_redirects=True): - page = client.get('/auth/password_change').data - assert 'Current password' in page - assert 'New password' in page - assert 'Confirm password' in page - with client.get('/auth/password_change').context as ctx2: - with client.post('/auth/password_change', data={ - 'old_password': 'imtheceo', - 'new_password': 'imthebigceo', - 'new_password2': 'imthebigceo', - '_csrf_token': list(ctx2.session._csrf)[-1] - }, follow_redirects=True) as req: - assert 'Password changed successfully' in req.data - assert 'William' in req.data - assert 'Save' in req.data + with client.get("/auth/login").context as ctx: + with client.post( + "/auth/login", + data={ + "email": "william@massivedynamics.com", + "password": "imtheceo", + "_csrf_token": list(ctx.session._csrf)[-1], + }, + follow_redirects=True, + ): + page = client.get("/auth/password_change").data + assert "Current password" in page + assert "New password" in page + assert "Confirm password" in page + with client.get("/auth/password_change").context as ctx2: + with client.post( + "/auth/password_change", + data={ + "old_password": "imtheceo", + "new_password": "imthebigceo", + "new_password2": "imthebigceo", + "_csrf_token": list(ctx2.session._csrf)[-1], + }, + follow_redirects=True, + ) as req: + assert "Password changed successfully" in req.data + assert "William" in req.data + assert "Save" in req.data with db.connection(): assert ( - db.users(email='william@massivedynamics.com').password == - db.users.password.requires[-1]('imthebigceo')[0]) + db.users(email="william@massivedynamics.com").password == db.users.password.requires[-1]("imthebigceo")[0] + ) def test_password_retrieval(mailer, db, client): - page = client.get('/auth/password_retrieval').data - assert 'Email' in page - assert 'Retrieve password' in page + page = client.get("/auth/password_retrieval").data + assert "Email" in page + assert "Retrieve password" in page with mailer.store_mails() as mailbox: - with client.get('/auth/password_retrieval').context as ctx: - with client.post('/auth/password_retrieval', data={ - 'email': 'william@massivedynamics.com', - '_csrf_token': list(ctx.session._csrf)[-1] - }, follow_redirects=True) as req: - assert 'We sent you an email, check your inbox' in req.data + with client.get("/auth/password_retrieval").context as ctx: + with client.post( + "/auth/password_retrieval", + data={"email": "william@massivedynamics.com", "_csrf_token": list(ctx.session._csrf)[-1]}, + follow_redirects=True, + ) as req: + assert "We sent you an email, check your inbox" in req.data assert len(mailbox) == 1 mail = mailbox[0] assert mail.recipients == ["william@massivedynamics.com"] - assert mail.subject == 'Password reset requested' + assert mail.subject == "Password reset requested" mail_as_str = str(mail) - assert 'A password reset was requested for your account' in mail_as_str - reset_code = mail_as_str.split( - "http://localhost/auth/password_reset/")[1].split(" ")[0] - with client.get( - '/auth/password_reset/{}'.format(reset_code), - follow_redirects=True - ) as req: - assert 'New password' in req.data - assert 'Confirm password' in req.data - assert 'Reset password' in req.data + assert "A password reset was requested for your account" in mail_as_str + reset_code = mail_as_str.split("http://localhost/auth/password_reset/")[1].split(" ")[0] + with client.get("/auth/password_reset/{}".format(reset_code), follow_redirects=True) as req: + assert "New password" in req.data + assert "Confirm password" in req.data + assert "Reset password" in req.data with client.post( - '/auth/password_reset/{}'.format(reset_code), + "/auth/password_reset/{}".format(reset_code), data={ - 'password': 'imtheceo', - 'password2': 'imtheceo', - '_csrf_token': list(req.context.session._csrf)[-1]}, - follow_redirects=True + "password": "imtheceo", + "password2": "imtheceo", + "_csrf_token": list(req.context.session._csrf)[-1], + }, + follow_redirects=True, ) as req2: - assert 'Password changed successfully' in req2.data - assert 'Sign in' in req2.data + assert "Password changed successfully" in req2.data + assert "Sign in" in req2.data with db.connection(): - assert ( - db.users(email='william@massivedynamics.com').password == - db.users.password.requires[-1]('imtheceo')[0]) + assert db.users(email="william@massivedynamics.com").password == db.users.password.requires[-1]("imtheceo")[0] diff --git a/tests/test_cache.py b/tests/test_cache.py index 92b6f77d..15ffd259 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.cache - ----------- +tests.cache +----------- - Test Emmett cache module +Test Emmett cache module """ import pytest @@ -27,17 +27,17 @@ async def test_diskcache(): disk_cache = DiskCache() assert disk_cache._threshold == 500 - assert disk_cache('test', lambda: 2) == 2 - assert disk_cache('test', lambda: 3, 300) == 2 + assert disk_cache("test", lambda: 2) == 2 + assert disk_cache("test", lambda: 3, 300) == 2 - assert await disk_cache('test_loop', _await_2) == 2 - assert await disk_cache('test_loop', _await_3, 300) == 2 + assert await disk_cache("test_loop", _await_2) == 2 + assert await disk_cache("test_loop", _await_3, 300) == 2 - disk_cache.set('test', 3) - assert disk_cache.get('test') == 3 + disk_cache.set("test", 3) + assert disk_cache.get("test") == 3 - disk_cache.set('test', 4, 300) - assert disk_cache.get('test') == 4 + disk_cache.set("test", 4, 300) + assert disk_cache.get("test") == 4 disk_cache.clear() - assert disk_cache.get('test') is None + assert disk_cache.get("test") is None diff --git a/tests/test_logger.py b/tests/test_logger.py index 74128cf5..49638381 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- """ - tests.logger - ------------ +tests.logger +------------ - Test Emmett logging module +Test Emmett logging module """ import logging from emmett_core import log as logger + from emmett import App, sdict @@ -18,18 +19,14 @@ def _call_create_logger(app): def test_user_assign_valid_level(): app = App(__name__) - app.config.logging.pytest = sdict( - level='info' - ) + app.config.logging.pytest = sdict(level="info") result = _call_create_logger(app) assert result.handlers[-1].level == logging.INFO def test_user_assign_invaild_level(): app = App(__name__) - app.config.logging.pytest = sdict( - level='invalid' - ) + app.config.logging.pytest = sdict(level="invalid") result = _call_create_logger(app) assert result.handlers[-1].level == logging.WARNING diff --git a/tests/test_mailer.py b/tests/test_mailer.py index e978028e..2c598e77 100644 --- a/tests/test_mailer.py +++ b/tests/test_mailer.py @@ -1,30 +1,31 @@ # -*- coding: utf-8 -*- """ - tests.mailer - ------------ +tests.mailer +------------ - Test Emmett mailer +Test Emmett mailer """ import base64 import email -import pytest import re import time - from email.header import Header + +import pytest + from emmett import App from emmett.tools.mailer import Mailer, sanitize_address -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): rv = App(__name__) - rv.config.mailer.sender = 'support@example.com' + rv.config.mailer.sender = "support@example.com" return rv -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def mailer(app): rv = Mailer(app) return rv @@ -32,7 +33,7 @@ def mailer(app): def test_message_init(mailer): msg = mailer.mail(subject="subject", recipients="to@example.com") - assert msg.sender == 'support@example.com' + assert msg.sender == "support@example.com" assert msg.recipients == ["to@example.com"] @@ -45,32 +46,23 @@ def test_recipients(mailer): def test_all_recipients(mailer): msg = mailer.mail( - subject="subject", recipients=["somebody@here.com"], - cc=["cc@example.com"], bcc=["bcc@example.com"]) + subject="subject", recipients=["somebody@here.com"], cc=["cc@example.com"], bcc=["bcc@example.com"] + ) assert len(msg.all_recipients) == 3 msg.add_recipient("cc@example.com") assert len(msg.all_recipients) == 3 def test_sender(mailer): - msg = mailer.mail( - subject="subject", - sender=u'ÄÜÖ → ✓ >', - recipients=["to@example.com"]) - assert 'From: =?utf-8?b?w4TDnMOWIOKGkiDinJM=?= ' in \ - str(msg) + msg = mailer.mail(subject="subject", sender="ÄÜÖ → ✓ >", recipients=["to@example.com"]) + assert "From: =?utf-8?b?w4TDnMOWIOKGkiDinJM=?= " in str(msg) def test_sender_as_tuple(mailer): - msg = mailer.mail( - subject="testing", sender=("tester", "tester@example.com")) + msg = mailer.mail(subject="testing", sender=("tester", "tester@example.com")) assert msg.sender == "tester " - msg = mailer.mail( - subject="subject", - sender=(u"ÄÜÖ → ✓", 'from@example.com>'), - recipients=["to@example.com"]) - assert 'From: =?utf-8?b?w4TDnMOWIOKGkiDinJM=?= ' in \ - str(msg) + msg = mailer.mail(subject="subject", sender=("ÄÜÖ → ✓", "from@example.com>"), recipients=["to@example.com"]) + assert "From: =?utf-8?b?w4TDnMOWIOKGkiDinJM=?= " in str(msg) def test_reply_to(mailer): @@ -79,15 +71,14 @@ def test_reply_to(mailer): recipients=["to@example.com"], sender="spammer ", reply_to="somebody ", - body="testing") - h = Header( - "Reply-To: %s" % sanitize_address('somebody ')) + body="testing", + ) + h = Header("Reply-To: %s" % sanitize_address("somebody ")) assert h.encode() in str(msg) def test_missing_sender(mailer): - msg = mailer.mail( - subject="testing", recipients=["to@example.com"], body="testing") + msg = mailer.mail(subject="testing", recipients=["to@example.com"], body="testing") msg.sender = None with pytest.raises(AssertionError): msg.send() @@ -105,7 +96,8 @@ def test_bcc(mailer): subject="testing", recipients=["to@example.com"], body="testing", - bcc=["tosomeoneelse@example.com"]) + bcc=["tosomeoneelse@example.com"], + ) assert "tosomeoneelse@example.com" not in str(msg) @@ -115,43 +107,36 @@ def test_cc(mailer): subject="testing", recipients=["to@example.com"], body="testing", - cc=["tosomeoneelse@example.com"]) + cc=["tosomeoneelse@example.com"], + ) assert "Cc: tosomeoneelse@example.com" in str(msg) def test_attach(mailer): - msg = mailer.mail( - subject="testing", recipients=["to@example.com"], body="testing") + msg = mailer.mail(subject="testing", recipients=["to@example.com"], body="testing") msg.attach(data=b"this is a test", content_type="text/plain") a = msg.attachments[0] assert a.filename is None - assert a.disposition == 'attachment' - assert a.content_type == 'text/plain' + assert a.disposition == "attachment" + assert a.content_type == "text/plain" assert a.data == b"this is a test" def test_bad_subject(mailer): - msg = mailer.mail( - subject="testing\r\n", - sender="from@example.com", - body="testing", - recipients=["to@example.com"]) + msg = mailer.mail(subject="testing\r\n", sender="from@example.com", body="testing", recipients=["to@example.com"]) with pytest.raises(RuntimeError): msg.send() def test_subject(mailer): - msg = mailer.mail( - subject=u"sübject", - sender='from@example.com', - recipients=["to@example.com"]) - assert '=?utf-8?q?s=C3=BCbject?=' in str(msg) + msg = mailer.mail(subject="sübject", sender="from@example.com", recipients=["to@example.com"]) + assert "=?utf-8?q?s=C3=BCbject?=" in str(msg) def test_empty_subject(mailer): msg = mailer.mail(sender="from@example.com", recipients=["foo@bar.com"]) msg.body = "normal ascii text" - assert 'Subject:' not in str(msg) + assert "Subject:" not in str(msg) def test_multiline_subject(mailer): @@ -159,40 +144,31 @@ def test_multiline_subject(mailer): subject="testing\r\n testing\r\n testing \r\n \ttesting", sender="from@example.com", body="testing", - recipients=["to@example.com"]) + recipients=["to@example.com"], + ) msg_as_string = str(msg) assert "From: from@example.com" in msg_as_string assert "testing\r\n testing\r\n testing \r\n \ttesting" in msg_as_string msg = mailer.mail( - subject="testing\r\n testing\r\n ", - sender="from@example.com", - body="testing", - recipients=["to@example.com"]) + subject="testing\r\n testing\r\n ", sender="from@example.com", body="testing", recipients=["to@example.com"] + ) with pytest.raises(RuntimeError): msg.send() msg = mailer.mail( - subject="testing\r\n testing\r\n\t", - sender="from@example.com", - body="testing", - recipients=["to@example.com"]) + subject="testing\r\n testing\r\n\t", sender="from@example.com", body="testing", recipients=["to@example.com"] + ) with pytest.raises(RuntimeError): msg.send() msg = mailer.mail( - subject="testing\r\n testing\r\n\n", - sender="from@example.com", - body="testing", - recipients=["to@example.com"]) + subject="testing\r\n testing\r\n\n", sender="from@example.com", body="testing", recipients=["to@example.com"] + ) with pytest.raises(RuntimeError): msg.send() def test_bad_sender(mailer): - msg = mailer.mail( - subject="testing", - sender="from@example.com\r\n", - recipients=["to@example.com"], - body="testing") - assert 'From: from@example.com' in str(msg) + msg = mailer.mail(subject="testing", sender="from@example.com\r\n", recipients=["to@example.com"], body="testing") + assert "From: from@example.com" in str(msg) def test_bad_reply_to(mailer): @@ -201,11 +177,12 @@ def test_bad_reply_to(mailer): sender="from@example.com", reply_to="evil@example.com\r", recipients=["to@example.com"], - body="testing") + body="testing", + ) msg_as_string = str(msg) - assert 'From: from@example.com' in msg_as_string - assert 'To: to@example.com' in msg_as_string - assert 'Reply-To: evil@example.com' in msg_as_string + assert "From: from@example.com" in msg_as_string + assert "To: to@example.com" in msg_as_string + assert "Reply-To: evil@example.com" in msg_as_string def test_bad_recipient(mailer): @@ -213,8 +190,9 @@ def test_bad_recipient(mailer): subject="testing", sender="from@example.com", recipients=["to@example.com", "to\r\n@example.com"], - body="testing") - assert 'To: to@example.com' in str(msg) + body="testing", + ) + assert "To: to@example.com" in str(msg) def test_address_sanitize(mailer): @@ -222,86 +200,62 @@ def test_address_sanitize(mailer): subject="testing", sender="sender\r\n@example.com", reply_to="reply_to\r\n@example.com", - recipients=["recipient\r\n@example.com"]) + recipients=["recipient\r\n@example.com"], + ) msg_as_string = str(msg) - assert 'sender@example.com' in msg_as_string - assert 'reply_to@example.com' in msg_as_string - assert 'recipient@example.com' in msg_as_string + assert "sender@example.com" in msg_as_string + assert "reply_to@example.com" in msg_as_string + assert "recipient@example.com" in msg_as_string def test_plain_message(mailer): plain_text = "Hello Joe,\nHow are you?" - msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body=plain_text) + msg = mailer.mail(sender="from@example.com", subject="subject", recipients=["to@example.com"], body=plain_text) assert plain_text == msg.body - assert 'Content-Type: text/plain' in str(msg) + assert "Content-Type: text/plain" in str(msg) def test_plain_message_with_attachments(mailer): - msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body="hello") + msg = mailer.mail(sender="from@example.com", subject="subject", recipients=["to@example.com"], body="hello") msg.attach(data=b"this is a test", content_type="text/plain") - assert 'Content-Type: multipart/mixed' in str(msg) + assert "Content-Type: multipart/mixed" in str(msg) def test_plain_message_with_unicode_attachments(mailer): - msg = mailer.mail( - subject="subject", recipients=["to@example.com"], body="hello") - msg.attach( - data=b"this is a test", - content_type="text/plain", - filename=u'ünicöde ←→ ✓.txt') + msg = mailer.mail(subject="subject", recipients=["to@example.com"], body="hello") + msg.attach(data=b"this is a test", content_type="text/plain", filename="ünicöde ←→ ✓.txt") parsed = email.message_from_string(str(msg)) - assert ( - re.sub(r'\s+', ' ', parsed.get_payload()[1].get('Content-Disposition')) - in [ - 'attachment; filename*="UTF8\'\'' - '%C3%BCnic%C3%B6de%20%E2%86%90%E2%86%92%20%E2%9C%93.txt"', - 'attachment; filename*=UTF8\'\'' - '%C3%BCnic%C3%B6de%20%E2%86%90%E2%86%92%20%E2%9C%93.txt' - ]) + assert re.sub(r"\s+", " ", parsed.get_payload()[1].get("Content-Disposition")) in [ + "attachment; filename*=\"UTF8''" '%C3%BCnic%C3%B6de%20%E2%86%90%E2%86%92%20%E2%9C%93.txt"', + "attachment; filename*=UTF8''" "%C3%BCnic%C3%B6de%20%E2%86%90%E2%86%92%20%E2%9C%93.txt", + ] def test_html_message(mailer): html_text = "

Hello World

" - msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - html=html_text) + msg = mailer.mail(sender="from@example.com", subject="subject", recipients=["to@example.com"], html=html_text) assert html_text == msg.html - assert 'Content-Type: multipart/alternative' in str(msg) + assert "Content-Type: multipart/alternative" in str(msg) def test_json_message(mailer): json_text = '{"msg": "Hello World!}' msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - alts={'json': json_text}) - assert json_text == msg.alts['json'] - assert 'Content-Type: multipart/alternative' in str(msg) + sender="from@example.com", subject="subject", recipients=["to@example.com"], alts={"json": json_text} + ) + assert json_text == msg.alts["json"] + assert "Content-Type: multipart/alternative" in str(msg) def test_html_message_with_attachments(mailer): html_text = "

Hello World

" - plain_text = 'Hello World' + plain_text = "Hello World" msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body=plain_text, - html=html_text) + sender="from@example.com", subject="subject", recipients=["to@example.com"], body=plain_text, html=html_text + ) msg.attach(data=b"this is a test", content_type="text/plain") assert html_text == msg.html - assert 'Content-Type: multipart/alternative' in str(msg) + assert "Content-Type: multipart/alternative" in str(msg) parsed = email.message_from_string(str(msg)) assert len(parsed.get_payload()) == 2 body, attachment = parsed.get_payload() @@ -309,47 +263,41 @@ def test_html_message_with_attachments(mailer): plain, html = body.get_payload() assert plain.get_payload() == plain_text assert html.get_payload() == html_text - assert base64.b64decode(attachment.get_payload()) == b'this is a test' + assert base64.b64decode(attachment.get_payload()) == b"this is a test" def test_date(mailer): before = time.time() msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body="hello", - date=time.time()) + sender="from@example.com", subject="subject", recipients=["to@example.com"], body="hello", date=time.time() + ) after = time.time() assert before <= msg.date <= after fmt_date = email.utils.formatdate(msg.date, localtime=True) - assert 'Date: {}'.format(fmt_date) in str(msg) + assert "Date: {}".format(fmt_date) in str(msg) def test_msgid(mailer): - msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body="hello") + msg = mailer.mail(sender="from@example.com", subject="subject", recipients=["to@example.com"], body="hello") r = re.compile(r"<\S+@\S+>").match(msg.msgId) assert r is not None - assert 'Message-ID: {}'.format(msg.msgId) in str(msg) + assert "Message-ID: {}".format(msg.msgId) in str(msg) def test_unicode_addresses(mailer): msg = mailer.mail( subject="subject", - sender=u'ÄÜÖ → ✓ ', - recipients=[u"Ä ", u"Ü "], - cc=[u"Ö "]) + sender="ÄÜÖ → ✓ ", + recipients=["Ä ", "Ü "], + cc=["Ö "], + ) msg_as_string = str(msg) - a1 = sanitize_address(u"Ä ") - a2 = sanitize_address(u"Ü ") + a1 = sanitize_address("Ä ") + a2 = sanitize_address("Ü ") h1_a = Header("To: %s, %s" % (a1, a2)) h1_b = Header("To: %s, %s" % (a2, a1)) - h2 = Header("From: %s" % sanitize_address(u"ÄÜÖ → ✓ ")) - h3 = Header("Cc: %s" % sanitize_address(u"Ö ")) + h2 = Header("From: %s" % sanitize_address("ÄÜÖ → ✓ ")) + h3 = Header("Cc: %s" % sanitize_address("Ö ")) try: assert h1_a.encode() in msg_as_string except AssertionError: @@ -364,28 +312,27 @@ def test_extra_headers(mailer): subject="subject", recipients=["to@example.com"], body="hello", - extra_headers={'X-Extra-Header': 'Yes'}) - assert 'X-Extra-Header: Yes' in str(msg) + extra_headers={"X-Extra-Header": "Yes"}, + ) + assert "X-Extra-Header: Yes" in str(msg) def test_send(mailer): with mailer.store_mails() as outbox: - msg = mailer.mail( - subject="testing", recipients=["tester@example.com"], body="test") + msg = mailer.mail(subject="testing", recipients=["tester@example.com"], body="test") msg.send() assert msg.date is not None assert len(outbox) == 1 sent_msg = outbox[0] - assert sent_msg.sender == 'support@example.com' + assert sent_msg.sender == "support@example.com" def test_send_mail(mailer): with mailer.store_mails() as outbox: - mailer.send_mail( - subject="testing", recipients=["tester@example.com"], body="test") + mailer.send_mail(subject="testing", recipients=["tester@example.com"], body="test") assert len(outbox) == 1 sent_msg = outbox[0] - assert sent_msg.subject == 'testing' + assert sent_msg.subject == "testing" assert sent_msg.recipients == ["tester@example.com"] - assert sent_msg.body == 'test' - assert sent_msg.sender == 'support@example.com' + assert sent_msg.body == "test" + assert sent_msg.sender == "support@example.com" diff --git a/tests/test_migrations.py b/tests/test_migrations.py index 9b774177..940f5130 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.migrations - ---------------- +tests.migrations +---------------- - Test Emmett migrations engine +Test Emmett migrations engine """ import uuid @@ -11,9 +11,9 @@ import pytest from emmett import App -from emmett.orm import Database, Model, Field, belongs_to, refers_to -from emmett.orm.migrations.engine import MetaEngine, Engine -from emmett.orm.migrations.generation import MetaData, Comparator +from emmett.orm import Database, Field, Model, belongs_to, refers_to +from emmett.orm.migrations.engine import Engine, MetaEngine +from emmett.orm.migrations.generation import Comparator, MetaData class FakeEngine(Engine): @@ -23,10 +23,10 @@ def _log_and_exec(self, sql): self.sql_history.append(sql) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): rv = App(__name__) - rv.config.db.uri = 'sqlite:memory' + rv.config.db.uri = "sqlite:memory" return rv @@ -58,6 +58,7 @@ class StepOneThing(Model): name = Field() value = Field.float() + _step_one_sql = """CREATE TABLE "step_one_things"( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "name" CHAR(512), @@ -105,6 +106,7 @@ class StepTwoThing(Model): value = Field.float(default=8.8) available = Field.bool(default=True) + _step_two_sql = """CREATE TABLE "step_two_things"( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "name" CHAR(512) NOT NULL, @@ -136,6 +138,7 @@ class StepThreeThingThree(Model): tablename = "step_three_thing_ones" b = Field() + _step_three_sql = 'ALTER TABLE "step_three_thing_ones" ADD "b" CHAR(512);' _step_three_sql_drop = 'ALTER TABLE "step_three_thing_ones" DROP COLUMN "a";' @@ -178,6 +181,7 @@ class StepFourThingEdit(Model): available = Field.bool(default=True) asd = Field.int() + _step_four_sql = """ALTER TABLE "step_four_things" ALTER COLUMN "name" DROP NOT NULL; ALTER TABLE "step_four_things" ALTER COLUMN "value" DROP DEFAULT; ALTER TABLE "step_four_things" ALTER COLUMN "asd" TYPE INTEGER;""" @@ -202,32 +206,24 @@ class StepFiveThing(Model): value = Field.int() created_at = Field.datetime() - indexes = { - 'name': True, - ('name', 'value'): True - } + indexes = {"name": True, ("name", "value"): True} class StepFiveThingEdit(StepFiveThing): tablename = "step_five_things" - indexes = { - 'name': False, - 'name_created': { - 'fields': 'name', - 'expressions': lambda m: m.created_at.coalesce(None)} - } + indexes = {"name": False, "name_created": {"fields": "name", "expressions": lambda m: m.created_at.coalesce(None)}} _step_five_sql_before = [ 'CREATE UNIQUE INDEX "step_five_things_widx__code_unique" ON "step_five_things" ("code");', 'CREATE INDEX "step_five_things_widx__name" ON "step_five_things" ("name");', - 'CREATE INDEX "step_five_things_widx__name_value" ON "step_five_things" ("name","value");' + 'CREATE INDEX "step_five_things_widx__name_value" ON "step_five_things" ("name","value");', ] _step_five_sql_after = [ 'DROP INDEX "step_five_things_widx__name";', - 'CREATE INDEX "step_five_things_widx__name_created" ON "step_five_things" ("name",COALESCE("created_at",NULL));' + 'CREATE INDEX "step_five_things_widx__name_created" ON "step_five_things" ("name",COALESCE("created_at",NULL));', ] @@ -255,7 +251,7 @@ class StepSixThing(Model): class StepSixRelate(Model): - belongs_to('step_six_thing') + belongs_to("step_six_thing") name = Field() @@ -268,11 +264,13 @@ class StepSixRelate(Model): "name" CHAR(512), "step_six_thing" CHAR(512) );""" -_step_six_sql_fk = "".join([ - 'ALTER TABLE "step_six_relates" ADD CONSTRAINT ', - '"step_six_relates_ecnt__fk__stepsixthings_stepsixthing" FOREIGN KEY ', - '("step_six_thing") REFERENCES "step_six_things"("id") ON DELETE CASCADE;' -]) +_step_six_sql_fk = "".join( + [ + 'ALTER TABLE "step_six_relates" ADD CONSTRAINT ', + '"step_six_relates_ecnt__fk__stepsixthings_stepsixthing" FOREIGN KEY ', + '("step_six_thing") REFERENCES "step_six_things"("id") ON DELETE CASCADE;', + ] +) def test_step_six_id_types(app): @@ -286,22 +284,17 @@ def test_step_six_id_types(app): class StepSevenThing(Model): - primary_keys = ['foo', 'bar'] + primary_keys = ["foo", "bar"] foo = Field() bar = Field() class StepSevenRelate(Model): - refers_to({'foo': 'StepSevenThing.foo'}, {'bar': 'StepSevenThing.bar'}) + refers_to({"foo": "StepSevenThing.foo"}, {"bar": "StepSevenThing.bar"}) name = Field() - foreign_keys = { - "test": { - "fields": ["foo", "bar"], - "foreign_fields": ["foo", "bar"] - } - } + foreign_keys = {"test": {"fields": ["foo", "bar"], "foreign_fields": ["foo", "bar"]}} _step_seven_sql_t1 = """CREATE TABLE "step_seven_things"( @@ -315,11 +308,13 @@ class StepSevenRelate(Model): "foo" CHAR(512), "bar" CHAR(512) );""" -_step_seven_sql_fk = "".join([ - 'ALTER TABLE "step_seven_relates" ADD CONSTRAINT ', - '"step_seven_relates_ecnt__fk__stepseventhings_foo_bar" FOREIGN KEY ("foo","bar") ', - 'REFERENCES "step_seven_things"("foo","bar") ON DELETE CASCADE;', -]) +_step_seven_sql_fk = "".join( + [ + 'ALTER TABLE "step_seven_relates" ADD CONSTRAINT ', + '"step_seven_relates_ecnt__fk__stepseventhings_foo_bar" FOREIGN KEY ("foo","bar") ', + 'REFERENCES "step_seven_things"("foo","bar") ON DELETE CASCADE;', + ] +) def test_step_seven_composed_pks(app): diff --git a/tests/test_orm.py b/tests/test_orm.py index b99dc303..9dd8e11e 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -1,58 +1,63 @@ # -*- coding: utf-8 -*- """ - tests.orm - --------- +tests.orm +--------- - Test pyDAL implementation over Emmett. +Test pyDAL implementation over Emmett. """ -import pytest - from datetime import datetime, timedelta from uuid import uuid4 -from pydal.objects import Table +import pytest from pydal import Field as _Field -from emmett import App, sdict, now +from pydal.objects import Table + +from emmett import App, now, sdict from emmett.orm import ( - Database, Field, Model, + Database, + Field, + Model, + after_commit, + after_delete, + after_destroy, + after_insert, + after_save, + after_update, + before_commit, + before_delete, + before_destroy, + before_insert, + before_save, + before_update, + belongs_to, compute, - before_insert, after_insert, - before_update, after_update, - before_delete, after_delete, - before_save, after_save, - before_destroy, after_destroy, - before_commit, after_commit, - rowattr, rowmethod, - has_one, has_many, belongs_to, refers_to, - scope + has_many, + has_one, + refers_to, + rowattr, + rowmethod, + scope, ) +from emmett.orm.errors import MissingFieldsForCompute from emmett.orm.migrations.utils import generate_runtime_migration from emmett.orm.objects import TransactionOps -from emmett.orm.errors import MissingFieldsForCompute -from emmett.validators import isntEmpty, hasLength +from emmett.validators import hasLength, isntEmpty CALLBACK_OPS = { "before_insert": [], "before_update": [], "before_delete": [], - "before_save":[], + "before_save": [], "before_destroy": [], "after_insert": [], "after_update": [], "after_delete": [], - "after_save":[], - "after_destroy": [] -} -COMMIT_CALLBACKS = { - "all": [], - "insert": [], - "update": [], - "delete": [], - "save": [], - "destroy": [] + "after_save": [], + "after_destroy": [], } +COMMIT_CALLBACKS = {"all": [], "insert": [], "update": [], "delete": [], "save": [], "destroy": []} def _represent_f(value): @@ -72,73 +77,57 @@ class Stuff(Model): total_watch = Field.float() invisible = Field() - validation = { - "a": {'presence': True}, - "total": {"allow": "empty"}, - "total_watch": {"allow": "empty"} - } + validation = {"a": {"presence": True}, "total": {"allow": "empty"}, "total_watch": {"allow": "empty"}} - fields_rw = { - "invisible": False - } + fields_rw = {"invisible": False} - form_labels = { - "a": "A label" - } + form_labels = {"a": "A label"} - form_info = { - "a": "A comment" - } + form_info = {"a": "A comment"} - update_values = { - "a": "a_update" - } + update_values = {"a": "a_update"} - repr_values = { - "a": _represent_f - } + repr_values = {"a": _represent_f} - form_widgets = { - "a": _widget_f - } + form_widgets = {"a": _widget_f} - @compute('total') + @compute("total") def eval_total(self, row): return row.price * row.quantity - @compute('total_watch', watch=['price', 'quantity']) + @compute("total_watch", watch=["price", "quantity"]) def eval_total_watch(self, row): return row.price * row.quantity @before_insert def bi(self, fields): - CALLBACK_OPS['before_insert'].append(fields) + CALLBACK_OPS["before_insert"].append(fields) @after_insert def ai(self, fields, id): - CALLBACK_OPS['after_insert'].append((fields, id)) + CALLBACK_OPS["after_insert"].append((fields, id)) @before_update def bu(self, set, fields): - CALLBACK_OPS['before_update'].append((set, fields)) + CALLBACK_OPS["before_update"].append((set, fields)) @after_update def au(self, set, fields): - CALLBACK_OPS['after_update'].append((set, fields)) + CALLBACK_OPS["after_update"].append((set, fields)) @before_delete def bd(self, set): - CALLBACK_OPS['before_delete'].append(set) + CALLBACK_OPS["before_delete"].append(set) @after_delete def ad(self, set): - CALLBACK_OPS['after_delete'].append(set) + CALLBACK_OPS["after_delete"].append(set) - @rowattr('totalv') + @rowattr("totalv") def eval_total_v(self, row): return row.price * row.quantity - @rowmethod('totalm') + @rowmethod("totalm") def eval_total_m(self, row): return row.price * row.quantity @@ -148,57 +137,55 @@ def method_test(cls, t): class Person(Model): - has_many( - 'things', {'features': {'via': 'things'}}, {'pets': 'Dog.owner'}, - 'subscriptions') + has_many("things", {"features": {"via": "things"}}, {"pets": "Dog.owner"}, "subscriptions") name = Field() age = Field.int() class Thing(Model): - belongs_to('person') - has_many('features') + belongs_to("person") + has_many("features") name = Field() color = Field() class Feature(Model): - belongs_to('thing') - has_one('price') + belongs_to("thing") + has_one("price") name = Field() class Price(Model): - belongs_to('feature') + belongs_to("feature") value = Field.int() class Doctor(Model): - has_many('appointments', {'patients': {'via': 'appointments'}}) + has_many("appointments", {"patients": {"via": "appointments"}}) name = Field() class Patient(Model): - has_many('appointments', {'doctors': {'via': 'appointments'}}) + has_many("appointments", {"doctors": {"via": "appointments"}}) name = Field() class Appointment(Model): - belongs_to('patient', 'doctor') + belongs_to("patient", "doctor") date = Field.datetime() class User(Model): name = Field() has_many( - 'memberships', {'organizations': {'via': 'memberships'}}, - {'cover_orgs': { - 'via': 'memberships.organization', - 'where': lambda m: m.is_cover == True}}) + "memberships", + {"organizations": {"via": "memberships"}}, + {"cover_orgs": {"via": "memberships.organization", "where": lambda m: m.is_cover == True}}, + ) class Organization(Model): @@ -210,22 +197,23 @@ def admin_memberships3(self): return Membership.admins() has_many( - 'memberships', {'users': {'via': 'memberships'}}, - {'admin_memberships': {'target': 'Membership', 'scope': 'admins'}}, - {'admins': {'via': 'admin_memberships.user'}}, - {'admin_memberships2': { - 'target': 'Membership', 'where': lambda m: m.role == 'admin'}}, - {'admins2': {'via': 'admin_memberships2.user'}}, - {'admins3': {'via': 'admin_memberships3.user'}}) + "memberships", + {"users": {"via": "memberships"}}, + {"admin_memberships": {"target": "Membership", "scope": "admins"}}, + {"admins": {"via": "admin_memberships.user"}}, + {"admin_memberships2": {"target": "Membership", "where": lambda m: m.role == "admin"}}, + {"admins2": {"via": "admin_memberships2.user"}}, + {"admins3": {"via": "admin_memberships3.user"}}, + ) class Membership(Model): - belongs_to('user', 'organization') + belongs_to("user", "organization") role = Field() - @scope('admins') + @scope("admins") def filter_admins(self): - return self.role == 'admin' + return self.role == "admin" class House(Model): @@ -234,7 +222,7 @@ class House(Model): class Mouse(Model): tablename = "mice" - has_many('elephants') + has_many("elephants") name = Field() @@ -243,19 +231,19 @@ class NeedSplit(Model): class Zoo(Model): - has_many('animals', 'elephants', {'mice': {'via': 'elephants.mouse'}}) + has_many("animals", "elephants", {"mice": {"via": "elephants.mouse"}}) name = Field() class Animal(Model): - belongs_to('zoo') + belongs_to("zoo") name = Field() - @rowattr('doublename') + @rowattr("doublename") def get_double_name(self, row): return row.name * 2 - @rowattr('pretty') + @rowattr("pretty") def get_pretty(self, row): return row.name @@ -269,10 +257,10 @@ def bi2(self, *args, **kwargs): class Elephant(Animal): - belongs_to('mouse') + belongs_to("mouse") color = Field() - @rowattr('pretty') + @rowattr("pretty") def get_pretty(self, row): return row.name + " " + row.color @@ -282,24 +270,24 @@ def bi2(self, *args, **kwargs): class Dog(Model): - belongs_to({'owner': 'Person'}) + belongs_to({"owner": "Person"}) name = Field() class Subscription(Model): - belongs_to('person') + belongs_to("person") name = Field() status = Field.int() expires_at = Field.datetime() - STATUS = {'active': 1, 'suspended': 2, 'other': 3} + STATUS = {"active": 1, "suspended": 2, "other": 3} - @scope('expired') + @scope("expired") def get_expired(self): return self.expires_at < datetime.now() - @scope('of_status') + @scope("of_status") def filter_status(self, *statuses): if len(statuses) == 1: return self.status == self.STATUS[statuses[0]] @@ -362,7 +350,7 @@ def _compute_price(self, row): class SelfRef(Model): - refers_to({'parent': 'self'}) + refers_to({"parent": "self"}) name = Field.string() @@ -412,29 +400,43 @@ def _commit_watch_after_destroy(self, ctx): COMMIT_CALLBACKS["destroy"].append(("after", ctx)) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(): app = App(__name__) - db = Database( - app, config=sdict( - uri=f'sqlite://{uuid4().hex}.db', - auto_connect=True - ) - ) + db = Database(app, config=sdict(uri=f"sqlite://{uuid4().hex}.db", auto_connect=True)) db.define_models( - Stuff, Person, Thing, Feature, Price, Dog, Subscription, - Doctor, Patient, Appointment, - User, Organization, Membership, - House, Mouse, NeedSplit, Zoo, Animal, Elephant, - Product, Cart, CartElement, + Stuff, + Person, + Thing, + Feature, + Price, + Dog, + Subscription, + Doctor, + Patient, + Appointment, + User, + Organization, + Membership, + House, + Mouse, + NeedSplit, + Zoo, + Animal, + Elephant, + Product, + Cart, + CartElement, SelfRef, - CustomPKType, CustomPKName, CustomPKMulti, - CommitWatcher + CustomPKType, + CustomPKName, + CustomPKMulti, + CommitWatcher, ) return db -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(_db): migration = generate_runtime_migration(_db) migration.up() @@ -948,12 +950,13 @@ def test_relations(db): assert t[0].name == "apple" and t[0].color == "red" and t[0].person.id == p1.id f = p1.things()[0].features() assert len(f) == 1 - assert f[0].name == "tasty" and f[0].thing.id == t[0].id and \ - f[0].thing.person.id == p1.id + assert f[0].name == "tasty" and f[0].thing.id == t[0].id and f[0].thing.person.id == p1.id m = p1.things()[0].features()[0].price() assert ( - m.value == 5 and m.feature.id == f[0].id and - m.feature.thing.id == t[0].id and m.feature.thing.person.id == p1.id + m.value == 5 + and m.feature.id == f[0].id + and m.feature.thing.id == t[0].id + and m.feature.thing.person.id == p1.id ) p2.things.add(t2) assert p1.things.count() == 1 @@ -974,30 +977,30 @@ def test_relations(db): assert len(doctor.appointments()) == 1 assert len(patient.doctors()) == 1 assert len(patient.appointments()) == 1 - joe = db.User.insert(name='joe') - jim = db.User.insert(name='jim') - org = db.Organization.insert(name='') - org.users.add(joe, role='admin') - org.users.add(jim, role='manager') + joe = db.User.insert(name="joe") + jim = db.User.insert(name="jim") + org = db.Organization.insert(name="") + org.users.add(joe, role="admin") + org.users.add(jim, role="manager") assert len(org.users()) == 2 assert len(joe.organizations()) == 1 assert len(jim.organizations()) == 1 assert joe.organizations().first().id == org assert jim.organizations().first().id == org - assert joe.memberships().first().role == 'admin' - assert jim.memberships().first().role == 'manager' + assert joe.memberships().first().role == "admin" + assert jim.memberships().first().role == "manager" org.users.remove(joe) org.users.remove(jim) assert len(org.users(reload=True)) == 0 assert len(joe.organizations(reload=True)) == 0 assert len(jim.organizations(reload=True)) == 0 #: has_many with specified feld - db.Dog.insert(name='pongo', owner=p1) - assert len(p1.pets()) == 1 and p1.pets().first().name == 'pongo' + db.Dog.insert(name="pongo", owner=p1) + assert len(p1.pets()) == 1 and p1.pets().first().name == "pongo" #: has_many via with specified field - zoo = db.Zoo.insert(name='magic zoo') - mouse = db.Mouse.insert(name='jerry') - db.Elephant.insert(name='dumbo', color='pink', mouse=mouse, zoo=zoo) + zoo = db.Zoo.insert(name="magic zoo") + mouse = db.Mouse.insert(name="jerry") + db.Elephant.insert(name="dumbo", color="pink", mouse=mouse, zoo=zoo) assert len(zoo.mice()) == 1 @@ -1008,44 +1011,34 @@ def test_tablenames(db): def test_inheritance(db): - assert 'name' in db.Animal.fields - assert 'name' in db.Elephant.fields - assert 'zoo' in db.Animal.fields - assert 'zoo' in db.Elephant.fields - assert 'color' in db.Elephant.fields - assert 'color' not in db.Animal.fields - assert Elephant._all_virtuals_['get_double_name'] is \ - Animal._all_virtuals_['get_double_name'] - assert Elephant._all_virtuals_['get_pretty'] is not \ - Animal._all_virtuals_['get_pretty'] - assert Elephant._all_callbacks_['bi'] is \ - Animal._all_callbacks_['bi'] - assert Elephant._all_callbacks_['bi2'] is not \ - Animal._all_callbacks_['bi2'] + assert "name" in db.Animal.fields + assert "name" in db.Elephant.fields + assert "zoo" in db.Animal.fields + assert "zoo" in db.Elephant.fields + assert "color" in db.Elephant.fields + assert "color" not in db.Animal.fields + assert Elephant._all_virtuals_["get_double_name"] is Animal._all_virtuals_["get_double_name"] + assert Elephant._all_virtuals_["get_pretty"] is not Animal._all_virtuals_["get_pretty"] + assert Elephant._all_callbacks_["bi"] is Animal._all_callbacks_["bi"] + assert Elephant._all_callbacks_["bi2"] is not Animal._all_callbacks_["bi2"] def test_scopes(db): p = db.Person.insert(name="Walter", age=50) - s = db.Subscription.insert( - name="a", expires_at=datetime.now() - timedelta(hours=20), person=p, - status=1) - s2 = db.Subscription.insert( - name="b", expires_at=datetime.now() + timedelta(hours=20), person=p, - status=2) - db.Subscription.insert( - name="c", expires_at=datetime.now() + timedelta(hours=20), person=p, - status=3) + s = db.Subscription.insert(name="a", expires_at=datetime.now() - timedelta(hours=20), person=p, status=1) + s2 = db.Subscription.insert(name="b", expires_at=datetime.now() + timedelta(hours=20), person=p, status=2) + db.Subscription.insert(name="c", expires_at=datetime.now() + timedelta(hours=20), person=p, status=3) rows = db(db.Subscription).expired().select() assert len(rows) == 1 and rows[0].id == s rows = p.subscriptions.expired().select() assert len(rows) == 1 and rows[0].id == s rows = Subscription.expired().select() assert len(rows) == 1 and rows[0].id == s - rows = db(db.Subscription).of_status('active', 'suspended').select() + rows = db(db.Subscription).of_status("active", "suspended").select() assert len(rows) == 2 and rows[0].id == s and rows[1].id == s2 - rows = p.subscriptions.of_status('active', 'suspended').select() + rows = p.subscriptions.of_status("active", "suspended").select() assert len(rows) == 2 and rows[0].id == s and rows[1].id == s2 - rows = Subscription.of_status('active', 'suspended').select() + rows = Subscription.of_status("active", "suspended").select() assert len(rows) == 2 and rows[0].id == s and rows[1].id == s2 @@ -1054,7 +1047,7 @@ def test_relations_scopes(db): org = db.Organization.insert(name="Los pollos hermanos") org.users.add(gus, role="admin") frank = db.User.insert(name="Frank") - org.users.add(frank, role='manager') + org.users.add(frank, role="manager") assert org.admins.count() == 1 assert org.admins2.count() == 1 assert org.admins3.count() == 1 @@ -1100,24 +1093,13 @@ def test_relations_scopes(db): def test_model_where(db): - assert Subscription.where(lambda s: s.status == 1).query == \ - db(db.Subscription.status == 1).query + assert Subscription.where(lambda s: s.status == 1).query == db(db.Subscription.status == 1).query def test_model_first(db): p = db.Person.insert(name="Walter", age=50) - db.Subscription.insert( - name="a", - expires_at=datetime.now() + timedelta(hours=20), - person=p, - status=1 - ) - db.Subscription.insert( - name="b", - expires_at=datetime.now() + timedelta(hours=20), - person=p, - status=1 - ) + db.Subscription.insert(name="a", expires_at=datetime.now() + timedelta(hours=20), person=p, status=1) + db.Subscription.insert(name="b", expires_at=datetime.now() + timedelta(hours=20), person=p, status=1) db.CustomPKType.insert(id="a") db.CustomPKType.insert(id="b") db.CustomPKName.insert(name="a") @@ -1126,38 +1108,23 @@ def test_model_first(db): db.CustomPKMulti.insert(first_name="foo", last_name="baz") db.CustomPKMulti.insert(first_name="bar", last_name="baz") - assert Subscription.first().id == Subscription.all().select( - orderby=Subscription.id, - limitby=(0, 1) - ).first().id - assert CustomPKType.first().id == CustomPKType.all().select( - orderby=CustomPKType.id, - limitby=(0, 1) - ).first().id - assert CustomPKName.first().name == CustomPKName.all().select( - orderby=CustomPKName.name, - limitby=(0, 1) - ).first().name - assert CustomPKMulti.first() == CustomPKMulti.all().select( - orderby=CustomPKMulti.first_name|CustomPKMulti.last_name, - limitby=(0, 1) - ).first() + assert Subscription.first().id == Subscription.all().select(orderby=Subscription.id, limitby=(0, 1)).first().id + assert CustomPKType.first().id == CustomPKType.all().select(orderby=CustomPKType.id, limitby=(0, 1)).first().id + assert ( + CustomPKName.first().name == CustomPKName.all().select(orderby=CustomPKName.name, limitby=(0, 1)).first().name + ) + assert ( + CustomPKMulti.first() + == CustomPKMulti.all() + .select(orderby=CustomPKMulti.first_name | CustomPKMulti.last_name, limitby=(0, 1)) + .first() + ) def test_model_last(db): p = db.Person.insert(name="Walter", age=50) - db.Subscription.insert( - name="a", - expires_at=datetime.now() + timedelta(hours=20), - person=p, - status=1 - ) - db.Subscription.insert( - name="b", - expires_at=datetime.now() + timedelta(hours=20), - person=p, - status=1 - ) + db.Subscription.insert(name="a", expires_at=datetime.now() + timedelta(hours=20), person=p, status=1) + db.Subscription.insert(name="b", expires_at=datetime.now() + timedelta(hours=20), person=p, status=1) db.CustomPKType.insert(id="a") db.CustomPKType.insert(id="b") db.CustomPKName.insert(name="a") @@ -1166,19 +1133,14 @@ def test_model_last(db): db.CustomPKMulti.insert(first_name="foo", last_name="baz") db.CustomPKMulti.insert(first_name="bar", last_name="baz") - assert Subscription.last().id == Subscription.all().select( - orderby=~Subscription.id, - limitby=(0, 1) - ).first().id - assert CustomPKType.last().id == CustomPKType.all().select( - orderby=~CustomPKType.id, - limitby=(0, 1) - ).first().id - assert CustomPKName.last().name == CustomPKName.all().select( - orderby=~CustomPKName.name, - limitby=(0, 1) - ).first().name - assert CustomPKMulti.last() == CustomPKMulti.all().select( - orderby=~CustomPKMulti.first_name|~CustomPKMulti.last_name, - limitby=(0, 1) - ).first() + assert Subscription.last().id == Subscription.all().select(orderby=~Subscription.id, limitby=(0, 1)).first().id + assert CustomPKType.last().id == CustomPKType.all().select(orderby=~CustomPKType.id, limitby=(0, 1)).first().id + assert ( + CustomPKName.last().name == CustomPKName.all().select(orderby=~CustomPKName.name, limitby=(0, 1)).first().name + ) + assert ( + CustomPKMulti.last() + == CustomPKMulti.all() + .select(orderby=~CustomPKMulti.first_name | ~CustomPKMulti.last_name, limitby=(0, 1)) + .first() + ) diff --git a/tests/test_orm_connections.py b/tests/test_orm_connections.py index d4c7e048..df881763 100644 --- a/tests/test_orm_connections.py +++ b/tests/test_orm_connections.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.orm_connections - --------------------- +tests.orm_connections +--------------------- - Test pyDAL connection implementation over Emmett. +Test pyDAL connection implementation over Emmett. """ import pytest @@ -12,17 +12,10 @@ from emmett.orm import Database -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def db(): app = App(__name__) - db = Database( - app, - config=sdict( - uri='sqlite:memory', - auto_migrate=True, - auto_connect=False - ) - ) + db = Database(app, config=sdict(uri="sqlite:memory", auto_migrate=True, auto_connect=False)) return db diff --git a/tests/test_orm_gis.py b/tests/test_orm_gis.py index 0da40b86..d63542eb 100644 --- a/tests/test_orm_gis.py +++ b/tests/test_orm_gis.py @@ -1,21 +1,21 @@ # -*- coding: utf-8 -*- """ - tests.orm_gis - ------------- +tests.orm_gis +------------- - Test ORM GIS features +Test ORM GIS features """ import os + import pytest from emmett import App, sdict -from emmett.orm import Database, Model, Field, geo +from emmett.orm import Database, Field, Model, geo from emmett.orm.migrations.utils import generate_runtime_migration -require_postgres = pytest.mark.skipif( - not os.environ.get("POSTGRES_URI"), reason="No postgres database" -) + +require_postgres = pytest.mark.skipif(not os.environ.get("POSTGRES_URI"), reason="No postgres database") class Geography(Model): @@ -40,23 +40,15 @@ class Geometry(Model): multipolygon = Field.geometry("MULTIPOLYGON") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(): app = App(__name__) - db = Database( - app, - config=sdict( - uri=f"postgres://{os.environ.get('POSTGRES_URI')}" - ) - ) - db.define_models( - Geography, - Geometry - ) + db = Database(app, config=sdict(uri=f"postgres://{os.environ.get('POSTGRES_URI')}")) + db.define_models(Geography, Geometry) return db -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(_db): migration = generate_runtime_migration(_db) with _db.connection(): @@ -93,15 +85,9 @@ def test_gis_insert(db): multipoint=geo.MultiPoint((1, 1), (2, 2)), multiline=geo.MultiLine(((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))), multipolygon=geo.MultiPolygon( - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) - ) - ) + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), + ), ) row.save() @@ -111,35 +97,17 @@ def test_gis_insert(db): assert not row.point.groups assert row.line == "LINESTRING({})".format( - ",".join([ - " ".join(f"{v}.000000" for v in tup) - for tup in [ - (0, 0), - (20, 80), - (80, 80) - ] - ]) + ",".join([" ".join(f"{v}.000000" for v in tup) for tup in [(0, 0), (20, 80), (80, 80)]]) ) assert row.line.geometry == "LINESTRING" assert row.line.coordinates == ((0, 0), (20, 80), (80, 80)) assert not row.line.groups assert row.polygon == "POLYGON(({}))".format( - ",".join([ - " ".join(f"{v}.000000" for v in tup) - for tup in [ - (0, 0), - (150, 0), - (150, 10), - (0, 10), - (0, 0) - ] - ]) + ",".join([" ".join(f"{v}.000000" for v in tup) for tup in [(0, 0), (150, 0), (150, 10), (0, 10), (0, 0)]]) ) assert row.polygon.geometry == "POLYGON" - assert row.polygon.coordinates == ( - ((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)), - ) + assert row.polygon.coordinates == (((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)),) assert not row.polygon.groups assert row.multipoint == "MULTIPOINT((1.000000 1.000000),(2.000000 2.000000))" @@ -150,69 +118,48 @@ def test_gis_insert(db): assert row.multipoint.groups[1] == geo.Point(2, 2) assert row.multiline == "MULTILINESTRING({})".format( - ",".join([ - "({})".format( - ",".join([ - " ".join(f"{v}.000000" for v in tup) - for tup in group - ]) - ) for group in [ - ((1, 1), (2, 2), (3, 3)), - ((1, 1), (4, 4), (5, 5)) + ",".join( + [ + "({})".format(",".join([" ".join(f"{v}.000000" for v in tup) for tup in group])) + for group in [((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))] ] - ]) + ) ) assert row.multiline.geometry == "MULTILINESTRING" - assert row.multiline.coordinates == ( - ((1, 1), (2, 2), (3, 3)), - ((1, 1), (4, 4), (5, 5)) - ) + assert row.multiline.coordinates == (((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))) assert len(row.multiline.groups) == 2 assert row.multiline.groups[0] == geo.Line((1, 1), (2, 2), (3, 3)) assert row.multiline.groups[1] == geo.Line((1, 1), (4, 4), (5, 5)) assert row.multipolygon == "MULTIPOLYGON({})".format( - ",".join([ - "({})".format( - ",".join([ - "({})".format( - ",".join([ - " ".join(f"{v}.000000" for v in tup) - for tup in group - ]) - ) for group in polygon - ]) - ) for polygon in [ - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) + ",".join( + [ + "({})".format( + ",".join( + [ + "({})".format(",".join([" ".join(f"{v}.000000" for v in tup) for tup in group])) + for group in polygon + ] + ) ) + for polygon in [ + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), + ] ] - ]) + ) ) assert row.multipolygon.geometry == "MULTIPOLYGON" assert row.multipolygon.coordinates == ( - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) - ) + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), ) assert len(row.multipolygon.groups) == 2 assert row.multipolygon.groups[0] == geo.Polygon( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) + ((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0)) ) assert row.multipolygon.groups[1] == geo.Polygon( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) + ((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1)) ) @@ -226,15 +173,9 @@ def test_gis_select(db): multipoint=geo.MultiPoint((1, 1), (2, 2)), multiline=geo.MultiLine(((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))), multipolygon=geo.MultiPolygon( - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) - ) - ) + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), + ), ) row.save() row = model.get(row.id) @@ -248,9 +189,7 @@ def test_gis_select(db): assert not row.line.groups assert row.polygon.geometry == "POLYGON" - assert row.polygon.coordinates == ( - ((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)), - ) + assert row.polygon.coordinates == (((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)),) assert not row.polygon.groups assert row.multipoint.geometry == "MULTIPOINT" @@ -260,31 +199,20 @@ def test_gis_select(db): assert row.multipoint.groups[1] == geo.Point(2, 2) assert row.multiline.geometry == "MULTILINESTRING" - assert row.multiline.coordinates == ( - ((1, 1), (2, 2), (3, 3)), - ((1, 1), (4, 4), (5, 5)) - ) + assert row.multiline.coordinates == (((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))) assert len(row.multiline.groups) == 2 assert row.multiline.groups[0] == geo.Line((1, 1), (2, 2), (3, 3)) assert row.multiline.groups[1] == geo.Line((1, 1), (4, 4), (5, 5)) assert row.multipolygon.geometry == "MULTIPOLYGON" assert row.multipolygon.coordinates == ( - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) - ) + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), ) assert len(row.multipolygon.groups) == 2 assert row.multipolygon.groups[0] == geo.Polygon( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) + ((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0)) ) assert row.multipolygon.groups[1] == geo.Polygon( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) + ((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1)) ) diff --git a/tests/test_orm_pks.py b/tests/test_orm_pks.py index 54b12f8f..f877ff6f 100644 --- a/tests/test_orm_pks.py +++ b/tests/test_orm_pks.py @@ -1,25 +1,24 @@ # -*- coding: utf-8 -*- """ - tests.orm_pks - ------------- +tests.orm_pks +------------- - Test ORM primary keys hendling +Test ORM primary keys hendling """ import os -import pytest - from uuid import uuid4 +import pytest + from emmett import App, sdict -from emmett.orm import Database, Model, Field, belongs_to, has_many +from emmett.orm import Database, Field, Model, belongs_to, has_many from emmett.orm.errors import SaveException from emmett.orm.helpers import RowReferenceMixin from emmett.orm.migrations.utils import generate_runtime_migration -require_postgres = pytest.mark.skipif( - not os.environ.get("POSTGRES_URI"), reason="No postgres database" -) + +require_postgres = pytest.mark.skipif(not os.environ.get("POSTGRES_URI"), reason="No postgres database") class Standard(Model): @@ -101,7 +100,7 @@ class DoctorCustom(Model): has_many( {"appointments": "AppointmentCustom"}, {"patients": {"via": "appointments.patient_custom"}}, - {"symptoms_to_treat": {"via": "patients.symptoms"}} + {"symptoms_to_treat": {"via": "patients.symptoms"}}, ) id = Field.string(default=lambda: uuid4().hex) @@ -114,7 +113,7 @@ class DoctorMulti(Model): has_many( {"appointments": "AppointmentMulti"}, {"patients": {"via": "appointments.patient_multi"}}, - {"symptoms_to_treat": {"via": "patients.symptoms"}} + {"symptoms_to_treat": {"via": "patients.symptoms"}}, ) foo = Field.string(default=lambda: uuid4().hex) @@ -128,7 +127,7 @@ class PatientCustom(Model): has_many( {"appointments": "AppointmentCustom"}, {"symptoms": "SymptomCustom.patient"}, - {"doctors": {"via": "appointments.doctor_custom"}} + {"doctors": {"via": "appointments.doctor_custom"}}, ) code = Field.string(default=lambda: uuid4().hex) @@ -141,7 +140,7 @@ class PatientMulti(Model): has_many( {"appointments": "AppointmentMulti"}, {"symptoms": "SymptomMulti.patient"}, - {"doctors": {"via": "appointments.doctor_multi"}} + {"doctors": {"via": "appointments.doctor_multi"}}, ) foo = Field.string(default=lambda: uuid4().hex) @@ -185,35 +184,18 @@ class AppointmentMulti(Model): name = Field.string() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(): app = App(__name__) - db = Database( - app, - config=sdict( - uri=f'sqlite://{uuid4().hex}.db', - auto_connect=True - ) - ) - db.define_models( - Standard, - CustomType, - CustomName, - CustomMulti - ) + db = Database(app, config=sdict(uri=f"sqlite://{uuid4().hex}.db", auto_connect=True)) + db.define_models(Standard, CustomType, CustomName, CustomMulti) return db -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _pgs(): app = App(__name__) - db = Database( - app, - config=sdict( - uri=f"postgres://{os.environ.get('POSTGRES_URI')}", - auto_connect=True - ) - ) + db = Database(app, config=sdict(uri=f"postgres://{os.environ.get('POSTGRES_URI')}", auto_connect=True)) db.define_models( SourceCustom, SourceMulti, @@ -228,12 +210,12 @@ def _pgs(): PatientMulti, AppointmentMulti, SymptomCustom, - SymptomMulti + SymptomMulti, ) return db -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(_db): migration = generate_runtime_migration(_db) migration.up() @@ -241,7 +223,7 @@ def db(_db): migration.down() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def pgs(_pgs): migration = generate_runtime_migration(_pgs) migration.up() @@ -282,7 +264,7 @@ def test_save_insert(db): assert done assert row._concrete assert row.id - assert type(row.id) == int + assert type(row.id) is int row = CustomType.new(id="test1", foo="test2", bar="test3") done = row.save() @@ -770,9 +752,7 @@ def test_row(pgs): assert sm2.dest_multi_multis.count() == 0 dmm1.source_multi = sm2 - assert set(dmm1._changes.keys()).issubset( - {"source_multi", "source_multi_foo", "source_multi_bar"} - ) + assert set(dmm1._changes.keys()).issubset({"source_multi", "source_multi_foo", "source_multi_bar"}) dmm1.save() assert sm1.dest_multi_multis.count() == 0 assert sm2.dest_multi_multis.count() == 1 diff --git a/tests/test_orm_row.py b/tests/test_orm_row.py index 86367d5f..645f3235 100644 --- a/tests/test_orm_row.py +++ b/tests/test_orm_row.py @@ -1,20 +1,18 @@ # -*- coding: utf-8 -*- """ - tests.orm_row - ------------- +tests.orm_row +------------- - Test ORM row objects +Test ORM row objects """ import pickle -import pytest - from uuid import uuid4 -from emmett import App, sdict, now -from emmett.orm import ( - Database, Model, Field, belongs_to, has_many, refers_to, rowattr, rowmethod -) +import pytest + +from emmett import App, now, sdict +from emmett.orm import Database, Field, Model, belongs_to, has_many, refers_to, rowattr, rowmethod from emmett.orm.errors import ValidationError from emmett.orm.helpers import RowReferenceMixin from emmett.orm.migrations.utils import generate_runtime_migration @@ -70,20 +68,15 @@ class Crypted(Model): bar = Field.password() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(): app = App(__name__) - db = Database( - app, - config=sdict( - uri=f'sqlite://{uuid4().hex}.db', - auto_connect=True - ) - ) + db = Database(app, config=sdict(uri=f"sqlite://{uuid4().hex}.db", auto_connect=True)) db.define_models(One, Two, Three, Override, Crypted) return db -@pytest.fixture(scope='function') + +@pytest.fixture(scope="function") def db(_db): migration = generate_runtime_migration(_db) migration.up() @@ -96,80 +89,80 @@ def test_rowclass(db): db.Two.insert(one=ret, foo="test1", bar="test2") ret._allocate_() - assert type(ret._refrecord) == One._instance_()._rowclass_ + assert type(ret._refrecord) is One._instance_()._rowclass_ row = One.get(ret.id) - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = One.first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = db(db.One).select().first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = db(db.One).select(db.One.ALL).first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = One.all().select().first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = One.where(lambda m: m.id != None).select().first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = db(db.One).select().first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = db(db.One).select(db.One.ALL).first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = One.all().select(One.bar).first() - assert type(row) == Row + assert type(row) is Row row = db(db.One).select(One.bar).first() - assert type(row) == Row + assert type(row) is Row row = One.all().join("twos").select().first() - assert type(row) == One._instance_()._rowclass_ - assert type(row.twos().first()) == Two._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ + assert type(row.twos().first()) is Two._instance_()._rowclass_ row = One.all().join("twos").select(One.table.ALL, Two.table.ALL).first() - assert type(row) == One._instance_()._rowclass_ - assert type(row.twos().first()) == Two._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ + assert type(row.twos().first()) is Two._instance_()._rowclass_ row = One.all().join("twos").select(One.foo, Two.foo).first() - assert type(row) == Row - assert type(row.ones) == Row - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is Row + assert type(row.twos) is Row row = db(Two.one == One.id).select().first() - assert type(row) == Row - assert type(row.ones) == One._instance_()._rowclass_ - assert type(row.twos) == Two._instance_()._rowclass_ + assert type(row) is Row + assert type(row.ones) is One._instance_()._rowclass_ + assert type(row.twos) is Two._instance_()._rowclass_ row = db(Two.one == One.id).select(One.table.ALL, Two.foo).first() - assert type(row) == Row - assert type(row.ones) == One._instance_()._rowclass_ - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is One._instance_()._rowclass_ + assert type(row.twos) is Row row = db(Two.one == One.id).select(One.foo, Two.foo).first() - assert type(row) == Row - assert type(row.ones) == Row - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is Row + assert type(row.twos) is Row for row in db(Two.one == One.id).iterselect(): - assert type(row) == Row - assert type(row.ones) == One._instance_()._rowclass_ - assert type(row.twos) == Two._instance_()._rowclass_ + assert type(row) is Row + assert type(row.ones) is One._instance_()._rowclass_ + assert type(row.twos) is Two._instance_()._rowclass_ for row in db(Two.one == One.id).iterselect(One.table.ALL, Two.foo): - assert type(row) == Row - assert type(row.ones) == One._instance_()._rowclass_ - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is One._instance_()._rowclass_ + assert type(row.twos) is Row for row in db(Two.one == One.id).iterselect(One.foo, Two.foo): - assert type(row) == Row - assert type(row.ones) == Row - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is Row + assert type(row.twos) is Row def test_concrete(db): diff --git a/tests/test_orm_transactions.py b/tests/test_orm_transactions.py index 782de130..194d3749 100644 --- a/tests/test_orm_transactions.py +++ b/tests/test_orm_transactions.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.orm_transactions - ---------------------- +tests.orm_transactions +---------------------- - Test pyDAL transactions implementation over Emmett. +Test pyDAL transactions implementation over Emmett. """ import pytest @@ -16,17 +16,15 @@ class Register(Model): value = Field.int() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def db(): app = App(__name__) - db = Database( - app, config=sdict( - uri='sqlite:memory', auto_migrate=True, auto_connect=True)) + db = Database(app, config=sdict(uri="sqlite:memory", auto_migrate=True, auto_connect=True)) db.define_models(Register) return db -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def cleanup(request, db): def teardown(): Register.all().delete() @@ -41,7 +39,7 @@ def _save(*vals): def _values_in_register(*vals): - db_vals = Register.all().select(orderby=Register.value).column('value') + db_vals = Register.all().select(orderby=Register.value).column("value") return db_vals == list(vals) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index ca3efc5c..986deac6 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,27 +1,27 @@ # -*- coding: utf-8 -*- """ - tests.pipeline - -------------- +tests.pipeline +-------------- - Test Emmett pipeline +Test Emmett pipeline """ import asyncio -import pytest - from contextlib import contextmanager +import pytest from helpers import current_ctx as _current_ctx, ws_ctx as _ws_ctx -from emmett_core.serializers import _json_type -from emmett import App, request, websocket, abort + +from emmett import App, abort, request, websocket from emmett.ctx import current from emmett.http import HTTP -from emmett.pipeline import Pipe, Injector from emmett.parsers import Parsers +from emmett.pipeline import Injector, Pipe from emmett.serializers import Serializers -json_load = Parsers.get_for('json') -json_dump = Serializers.get_for('json') + +json_load = Parsers.get_for("json") +json_dump = Serializers.get_for("json") class PipeException(Exception): @@ -45,43 +45,43 @@ def store_parallel(self, status): self.parallel_storage.append(self.__class__.__name__ + "." + status) async def on_pipe_success(self): - self.store_linear('success') + self.store_linear("success") async def on_pipe_failure(self): - self.store_linear('failure') + self.store_linear("failure") class FlowStorePipeCommon(FlowStorePipe): async def open(self): - self.store_parallel('open') + self.store_parallel("open") async def close(self): - self.store_parallel('close') + self.store_parallel("close") async def pipe(self, next_pipe, **kwargs): - self.store_linear('pipe') + self.store_linear("pipe") return await next_pipe(**kwargs) class FlowStorePipeSplit(FlowStorePipe): async def open_request(self): - self.store_parallel('open_request') + self.store_parallel("open_request") async def open_ws(self): - self.store_parallel('open_ws') + self.store_parallel("open_ws") async def close_request(self): - self.store_parallel('close_request') + self.store_parallel("close_request") async def close_ws(self): - self.store_parallel('close_ws') + self.store_parallel("close_ws") async def pipe_request(self, next_pipe, **kwargs): - self.store_linear('pipe_request') + self.store_linear("pipe_request") return await next_pipe(**kwargs) async def pipe_ws(self, next_pipe, **kwargs): - self.store_linear('pipe_ws') + self.store_linear("pipe_ws") return await next_pipe(**kwargs) @@ -91,13 +91,13 @@ class Pipe1(FlowStorePipeCommon): class Pipe2(FlowStorePipeSplit): async def pipe_request(self, next_pipe, **kwargs): - self.store_linear('pipe_request') + self.store_linear("pipe_request") if request.query_params.skip: return "block" return await next_pipe(**kwargs) async def pipe_ws(self, next_pipe, **kwargs): - self.store_linear('pipe_ws') + self.store_linear("pipe_ws") if websocket.query_params.skip: return await next_pipe(**kwargs) @@ -140,18 +140,18 @@ async def close(self): class PipeSR1(FlowStorePipeSplit): def on_receive(self, data): data = json_load(data) - return dict(pipe1r='receive_inject', **data) + return dict(pipe1r="receive_inject", **data) def on_send(self, data): - return json_dump(dict(pipe1s='send_inject', **data)) + return json_dump(dict(pipe1s="send_inject", **data)) class PipeSR2(FlowStorePipeSplit): def on_receive(self, data): - return dict(pipe2r='receive_inject', **data) + return dict(pipe2r="receive_inject", **data) def on_send(self, data): - return dict(pipe2s='send_inject', **data) + return dict(pipe2s="send_inject", **data) class CTXInjector(Injector): @@ -251,7 +251,7 @@ def parallel_flows_are_equal(flow, ctx): return set(flow) == set(ctx._pipeline_parallel_storage) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): app = App(__name__) app.pipeline = [Pipe1(), Pipe2(), Pipe3()] @@ -270,11 +270,11 @@ def error(): @app.route(pipeline=[ExcPipeOpen(), Pipe4()]) def open_error(): - return '' + return "" @app.route(pipeline=[ExcPipeClose(), Pipe4()]) def close_error(): - return '' + return "" @app.route(pipeline=[Pipe4()]) def pipe4(): @@ -282,7 +282,7 @@ def pipe4(): @app.websocket() async def ws_ok(): - await websocket.send('ok') + await websocket.send("ok") @app.websocket() def ws_error(): @@ -306,7 +306,7 @@ async def ws_inject(): current._receive_storage.append(data) await websocket.send(data) - mod = app.module(__name__, 'mod', url_prefix='mod') + mod = app.module(__name__, "mod", url_prefix="mod") mod.pipeline = [Pipe5()] @mod.route() @@ -325,15 +325,15 @@ def ws_pipe5(): def ws_pipe6(): return - inj = app.module(__name__, 'inj', url_prefix='inj') + inj = app.module(__name__, "inj", url_prefix="inj") inj.pipeline = [GlobalInjector(), ScopedInjector()] - @inj.route(template='test.html') + @inj.route(template="test.html") def injpipe(): - return {'posts': []} + return {"posts": []} - mg1 = app.module(__name__, 'mg1', url_prefix='mg1') - mg2 = app.module(__name__, 'mg2', url_prefix='mg2') + mg1 = app.module(__name__, "mg1", url_prefix="mg1") + mg2 = app.module(__name__, "mg2", url_prefix="mg2") mg1.pipeline = [Pipe5()] mg2.pipeline = [Pipe6()] mg = app.module_group(mg1, mg2) @@ -342,7 +342,7 @@ def injpipe(): async def pipe_mg(): return "mg" - mgc = mg.module(__name__, 'mgc', url_prefix='mgc') + mgc = mg.module(__name__, "mgc", url_prefix="mgc") mgc.pipeline = [Pipe7()] @mgc.route() @@ -354,24 +354,30 @@ async def pipe_mgc(): @pytest.mark.asyncio async def test_ok_flow(app): - with request_ctx(app, '/ok') as ctx: + with request_ctx(app, "/ok") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', - 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_ok') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] - linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', - 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + with ws_ctx(app, "/ws_ok") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_ws", "Pipe3.open", "Pipe3.close", "Pipe2.close_ws", "Pipe1.close"] + linear_flow = ["Pipe1.pipe", "Pipe2.pipe_ws", "Pipe3.pipe", "Pipe3.success", "Pipe2.success", "Pipe1.success"] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -379,13 +385,23 @@ async def test_ok_flow(app): @pytest.mark.asyncio async def test_httperror_flow(app): - with request_ctx(app, '/http_error') as ctx: + with request_ctx(app, "/http_error") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', - 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] try: await ctx.dispatch() except HTTP: @@ -396,13 +412,23 @@ async def test_httperror_flow(app): @pytest.mark.asyncio async def test_error_flow(app): - with request_ctx(app, '/error') as ctx: + with request_ctx(app, "/error") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', - 'Pipe3.failure', 'Pipe2.failure', 'Pipe1.failure'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe3.failure", + "Pipe2.failure", + "Pipe1.failure", + ] try: await ctx.dispatch() except Exception: @@ -410,13 +436,9 @@ async def test_error_flow(app): assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_error') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] - linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', - 'Pipe3.failure', 'Pipe2.failure', 'Pipe1.failure'] + with ws_ctx(app, "/ws_error") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_ws", "Pipe3.open", "Pipe3.close", "Pipe2.close_ws", "Pipe1.close"] + linear_flow = ["Pipe1.pipe", "Pipe2.pipe_ws", "Pipe3.pipe", "Pipe3.failure", "Pipe2.failure", "Pipe1.failure"] try: await ctx.dispatch() except Exception: @@ -427,9 +449,8 @@ async def test_error_flow(app): @pytest.mark.asyncio async def test_open_error(app): - with request_ctx(app, '/open_error') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe4.open'] + with request_ctx(app, "/open_error") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_request", "Pipe3.open", "Pipe4.open"] linear_flow = [] try: await ctx.dispatch() @@ -438,9 +459,8 @@ async def test_open_error(app): assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_open_error') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', 'Pipe4.open'] + with ws_ctx(app, "/ws_open_error") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_ws", "Pipe3.open", "Pipe4.open"] linear_flow = [] try: await ctx.dispatch() @@ -452,16 +472,30 @@ async def test_open_error(app): @pytest.mark.asyncio async def test_close_error(app): - with request_ctx(app, '/close_error') as ctx: + with request_ctx(app, "/close_error") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'ExcPipeClose.open', 'Pipe4.open', - 'Pipe4.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "ExcPipeClose.open", + "Pipe4.open", + "Pipe4.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', - 'ExcPipeClose.pipe', 'Pipe4.pipe', - 'Pipe4.success', 'ExcPipeClose.success', 'Pipe3.success', - 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "ExcPipeClose.pipe", + "Pipe4.pipe", + "Pipe4.success", + "ExcPipeClose.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] try: await ctx.dispatch() except PipeException as e: @@ -469,16 +503,30 @@ async def test_close_error(app): assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_close_error') as ctx: + with ws_ctx(app, "/ws_close_error") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'ExcPipeClose.open', 'Pipe4.open', - 'Pipe4.close', 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "ExcPipeClose.open", + "Pipe4.open", + "Pipe4.close", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', - 'ExcPipeClose.pipe', 'Pipe4.pipe', - 'Pipe4.success', 'ExcPipeClose.success', 'Pipe3.success', - 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "ExcPipeClose.pipe", + "Pipe4.pipe", + "Pipe4.success", + "ExcPipeClose.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] try: await ctx.dispatch() except PipeException as e: @@ -489,24 +537,23 @@ async def test_close_error(app): @pytest.mark.asyncio async def test_flow_interrupt(app): - with request_ctx(app, '/ok?skip=yes') as ctx: + with request_ctx(app, "/ok?skip=yes") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] - linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', - 'Pipe2.success', 'Pipe1.success'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] + linear_flow = ["Pipe1.pipe", "Pipe2.pipe_request", "Pipe2.success", "Pipe1.success"] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_ok?skip=yes') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] - linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', - 'Pipe2.success', 'Pipe1.success'] + with ws_ctx(app, "/ws_ok?skip=yes") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_ws", "Pipe3.open", "Pipe3.close", "Pipe2.close_ws", "Pipe1.close"] + linear_flow = ["Pipe1.pipe", "Pipe2.pipe_ws", "Pipe2.success", "Pipe1.success"] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -514,24 +561,52 @@ async def test_flow_interrupt(app): @pytest.mark.asyncio async def test_pipeline_composition(app): - with request_ctx(app, '/pipe4') as ctx: + with request_ctx(app, "/pipe4") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe4.open', - 'Pipe4.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe4.open", + "Pipe4.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe4.pipe', - 'Pipe4.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe4.pipe", + "Pipe4.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_pipe4') as ctx: + with ws_ctx(app, "/ws_pipe4") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', 'Pipe4.open', - 'Pipe4.close', 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "Pipe4.open", + "Pipe4.close", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', 'Pipe4.pipe', - 'Pipe4.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "Pipe4.pipe", + "Pipe4.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -539,24 +614,52 @@ async def test_pipeline_composition(app): @pytest.mark.asyncio async def test_module_pipeline(app): - with request_ctx(app, '/mod/pipe5') as ctx: + with request_ctx(app, "/mod/pipe5") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe5.open', - 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe5.open", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/mod/ws_pipe5') as ctx: + with ws_ctx(app, "/mod/ws_pipe5") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', 'Pipe5.open', - 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "Pipe5.open", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -564,32 +667,60 @@ async def test_module_pipeline(app): @pytest.mark.asyncio async def test_module_pipeline_composition(app): - with request_ctx(app, '/mod/pipe6') as ctx: + with request_ctx(app, "/mod/pipe6") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe5.open', - 'Pipe6.open', - 'Pipe6.close', 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_request', - 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe5.open", + "Pipe6.open", + "Pipe6.close", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe6.pipe', - 'Pipe6.success', 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', - 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe6.pipe", + "Pipe6.success", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/mod/ws_pipe6') as ctx: + with ws_ctx(app, "/mod/ws_pipe6") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', 'Pipe5.open', - 'Pipe6.open', - 'Pipe6.close', 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_ws', - 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "Pipe5.open", + "Pipe6.open", + "Pipe6.close", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe6.pipe', - 'Pipe6.success', 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', - 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe6.pipe", + "Pipe6.success", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -597,24 +728,52 @@ async def test_module_pipeline_composition(app): @pytest.mark.asyncio async def test_module_group_pipeline(app): - with request_ctx(app, '/mg1/pipe_mg') as ctx: + with request_ctx(app, "/mg1/pipe_mg") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe5.open', - 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe5.open", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with request_ctx(app, '/mg2/pipe_mg') as ctx: + with request_ctx(app, "/mg2/pipe_mg") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe6.open', - 'Pipe6.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe6.open", + "Pipe6.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe6.pipe', - 'Pipe6.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe6.pipe", + "Pipe6.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -622,32 +781,60 @@ async def test_module_group_pipeline(app): @pytest.mark.asyncio async def test_module_group_pipeline_composition(app): - with request_ctx(app, '/mg1/mgc/pipe_mgc') as ctx: + with request_ctx(app, "/mg1/mgc/pipe_mgc") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe5.open', - 'Pipe7.open', - 'Pipe7.close', 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_request', - 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe5.open", + "Pipe7.open", + "Pipe7.close", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe7.pipe', - 'Pipe7.success', 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', - 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe7.pipe", + "Pipe7.success", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with request_ctx(app, '/mg2/mgc/pipe_mgc') as ctx: + with request_ctx(app, "/mg2/mgc/pipe_mgc") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe6.open', - 'Pipe7.open', - 'Pipe7.close', 'Pipe6.close', 'Pipe3.close', 'Pipe2.close_request', - 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe6.open", + "Pipe7.open", + "Pipe7.close", + "Pipe6.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe6.pipe', - 'Pipe7.pipe', - 'Pipe7.success', 'Pipe6.success', 'Pipe3.success', 'Pipe2.success', - 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe6.pipe", + "Pipe7.pipe", + "Pipe7.success", + "Pipe6.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -655,48 +842,61 @@ async def test_module_group_pipeline_composition(app): @pytest.mark.asyncio async def test_receive_send_flow(app): - with ws_ctx(app, '/ws_inject') as ctx: + with ws_ctx(app, "/ws_inject") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'PipeSR1.open_ws', 'PipeSR2.open_ws', - 'PipeSR2.close_ws', 'PipeSR1.close_ws', - 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "PipeSR1.open_ws", + "PipeSR2.open_ws", + "PipeSR2.close_ws", + "PipeSR1.close_ws", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', - 'PipeSR1.pipe_ws', 'PipeSR2.pipe_ws', - 'PipeSR2.success', 'PipeSR1.success', - 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "PipeSR1.pipe_ws", + "PipeSR2.pipe_ws", + "PipeSR2.success", + "PipeSR1.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - assert ctx.ctx._receive_storage[-1] == { - 'foo': 'bar', - 'pipe1r': 'receive_inject', 'pipe2r': 'receive_inject' - } + assert ctx.ctx._receive_storage[-1] == {"foo": "bar", "pipe1r": "receive_inject", "pipe2r": "receive_inject"} assert json_load(ctx.ctx._send_storage[-1]) == { - 'foo': 'bar', - 'pipe1r': 'receive_inject', 'pipe2r': 'receive_inject', - 'pipe1s': 'send_inject', 'pipe2s': 'send_inject' + "foo": "bar", + "pipe1r": "receive_inject", + "pipe2r": "receive_inject", + "pipe1s": "send_inject", + "pipe2s": "send_inject", } @pytest.mark.asyncio async def test_injectors(app): - with request_ctx(app, '/inj/injpipe') as ctx: + with request_ctx(app, "/inj/injpipe") as ctx: current.app = app await ctx.dispatch() env = ctx.ctx._pipeline_generic_storage[0] - assert env['posts'] == [] - assert env['foo'] == "bar" - assert env['bar'] == "baz" - assert env['staticm']("test") == "test" - assert env['boundm']("test") == ("bar", "test") - assert env['prop'] == "baz" + assert env["posts"] == [] + assert env["foo"] == "bar" + assert env["bar"] == "baz" + assert env["staticm"]("test") == "test" + assert env["boundm"]("test") == ("bar", "test") + assert env["prop"] == "baz" - env = env['scoped'] + env = env["scoped"] assert env.foo == "bar" assert env.bar == "baz" assert env.staticm("test") == "test" diff --git a/tests/test_routing.py b/tests/test_routing.py index 1eb2c9d5..5005cd2f 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,18 +1,18 @@ # -*- coding: utf-8 -*- """ - tests.routing - ------------- +tests.routing +------------- - Test Emmett routing module +Test Emmett routing module """ -import pendulum -import pytest - from contextlib import contextmanager -from helpers import FakeRequestContext +import pendulum +import pytest from emmett_core.protocols.rsgi.test_client.scope import ScopeBuilder +from helpers import FakeRequestContext + from emmett import App, abort, url from emmett.ctx import current from emmett.datastructures import sdict @@ -27,156 +27,146 @@ def current_ctx(app, path): current._close_(token) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): app = App(__name__) - app.languages = ['en', 'it'] - app.language_default = 'en' + app.languages = ["en", "it"] + app.language_default = "en" app.language_force_on_url = True @app.route() def test_route(): - return 'Test Router' + return "Test Router" @app.route() def test_404(): - abort(404, 'Not found, dude') + abort(404, "Not found, dude") - @app.route('/test2//') + @app.route("/test2//") def test_route2(a, b): - return 'Test Router' + return "Test Router" - @app.route('/test3//foo(/)?(.)?') + @app.route("/test3//foo(/)?(.)?") def test_route3(a, b, c): - return 'Test Router' + return "Test Router" - @app.route('/test_int/') + @app.route("/test_int/") def test_route_int(a): - return 'Test Router' + return "Test Router" - @app.route('/test_float/') + @app.route("/test_float/") def test_route_float(a): - return 'Test Router' + return "Test Router" - @app.route('/test_date/') + @app.route("/test_date/") def test_route_date(a): - return 'Test Router' + return "Test Router" - @app.route('/test_alpha/') + @app.route("/test_alpha/") def test_route_alpha(a): - return 'Test Router' + return "Test Router" - @app.route('/test_str/') + @app.route("/test_str/") def test_route_str(a): - return 'Test Router' + return "Test Router" - @app.route('/test_any/') + @app.route("/test_any/") def test_route_any(a): - return 'Test Router' - - @app.route( - '/test_complex' - '/' - '/' - '/' - '/' - '/' - '/' - ) + return "Test Router" + + @app.route("/test_complex" "/" "/" "/" "/" "/" "/") def test_route_complex(a, b, c, d, e, f): - current._reqargs = {"a": a, "b":b, "c": c, "d": d, "e": e, "f": f} - return 'Test Router' + current._reqargs = {"a": a, "b": b, "c": c, "d": d, "e": e, "f": f} + return "Test Router" return app def test_routing(app): - with current_ctx(app, '/test_int') as ctx: + with current_ctx(app, "/test_int") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_int/a') as ctx: + with current_ctx(app, "/test_int/a") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_int/1.1') as ctx: + with current_ctx(app, "/test_int/1.1") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_int/2000-01-01') as ctx: + with current_ctx(app, "/test_int/2000-01-01") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_int/1') as ctx: + with current_ctx(app, "/test_int/1") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_int' + assert route.name == "test_routing.test_route_int" - with current_ctx(app, '/test_float') as ctx: + with current_ctx(app, "/test_float") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_float/a.a') as ctx: + with current_ctx(app, "/test_float/a.a") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_float/1') as ctx: + with current_ctx(app, "/test_float/1") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_float/1.1') as ctx: + with current_ctx(app, "/test_float/1.1") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_float' + assert route.name == "test_routing.test_route_float" - with current_ctx(app, '/test_date') as ctx: + with current_ctx(app, "/test_date") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_date/2000-01-01') as ctx: + with current_ctx(app, "/test_date/2000-01-01") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_date' + assert route.name == "test_routing.test_route_date" - with current_ctx(app, '/test_alpha') as ctx: + with current_ctx(app, "/test_alpha") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_alpha/a1') as ctx: + with current_ctx(app, "/test_alpha/a1") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_alpha/a-a') as ctx: + with current_ctx(app, "/test_alpha/a-a") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_alpha/a') as ctx: + with current_ctx(app, "/test_alpha/a") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_alpha' + assert route.name == "test_routing.test_route_alpha" - with current_ctx(app, '/test_str') as ctx: + with current_ctx(app, "/test_str") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_str/a/b') as ctx: + with current_ctx(app, "/test_str/a/b") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_str/a1-') as ctx: + with current_ctx(app, "/test_str/a1-") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_str' + assert route.name == "test_routing.test_route_str" - with current_ctx(app, '/test_any') as ctx: + with current_ctx(app, "/test_any") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_any/a/b') as ctx: + with current_ctx(app, "/test_any/a/b") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_any' + assert route.name == "test_routing.test_route_any" @pytest.mark.asyncio async def test_route_args(app): - with current_ctx( - app, '/test_complex/1/1.2/2000-12-01/foo/foo1/bar/baz' - ) as ctx: + with current_ctx(app, "/test_complex/1/1.2/2000-12-01/foo/foo1/bar/baz") as ctx: # route, args = app._router_http.match(ctx.request) # assert route.name == 'test_routing.test_route_complex' # assert args['a'] == 1 @@ -186,109 +176,108 @@ async def test_route_args(app): # assert args['e'] == 'foo1' # assert args['f'] == 'bar/baz' await app._router_http.dispatch(ctx.request, sdict()) - assert ctx.request.name == 'test_routing.test_route_complex' + assert ctx.request.name == "test_routing.test_route_complex" args = current._reqargs - assert args['a'] == 1 - assert round(args['b'], 1) == 1.2 - assert args['c'] == pendulum.datetime(2000, 12, 1) - assert args['d'] == 'foo' - assert args['e'] == 'foo1' - assert args['f'] == 'bar/baz' + assert args["a"] == 1 + assert round(args["b"], 1) == 1.2 + assert args["c"] == pendulum.datetime(2000, 12, 1) + assert args["d"] == "foo" + assert args["e"] == "foo1" + assert args["f"] == "bar/baz" @pytest.mark.asyncio async def test_routing_valid_route(app): - with current_ctx(app, '/it/test_route') as ctx: + with current_ctx(app, "/it/test_route") as ctx: response = await app._router_http.dispatch(ctx.request, ctx.response) assert response.status_code == ctx.response.status == 200 - assert response.body == 'Test Router' - assert ctx.request.language == 'it' + assert response.body == "Test Router" + assert ctx.request.language == "it" @pytest.mark.asyncio async def test_routing_not_found_route(app): - with current_ctx(app, '/') as ctx: + with current_ctx(app, "/") as ctx: with pytest.raises(HTTPResponse) as excinfo: await app._router_http.dispatch(ctx.request, ctx.response) assert excinfo.value.status_code == 404 - assert excinfo.value.body == b'Resource not found' + assert excinfo.value.body == b"Resource not found" @pytest.mark.asyncio async def test_routing_exception_route(app): - with current_ctx(app, '/test_404') as ctx: + with current_ctx(app, "/test_404") as ctx: with pytest.raises(HTTPResponse) as excinfo: await app._router_http.dispatch(ctx.request, ctx.response) assert excinfo.value.status_code == 404 - assert excinfo.value.body == 'Not found, dude' + assert excinfo.value.body == "Not found, dude" def test_static_url(app): - link = url('static', 'file') - assert link == '/static/file' + link = url("static", "file") + assert link == "/static/file" app.config.static_version_urls = True - app.config.static_version = '1.0.0' - link = url('static', 'js/foo.js', language='it') - assert link == '/it/static/_1.0.0/js/foo.js' + app.config.static_version = "1.0.0" + link = url("static", "js/foo.js", language="it") + assert link == "/it/static/_1.0.0/js/foo.js" def test_module_url(app): - with current_ctx(app, '/') as ctx: - ctx.request.language = 'it' - link = url('test_route') - assert link == '/it/test_route' - link = url('test_route2') - assert link == '/it/test2' - link = url('test_route2', [2]) - assert link == '/it/test2/2' - link = url('test_route2', [2, 'foo']) - assert link == '/it/test2/2/foo' - link = url('test_route3') - assert link == '/it/test3' - link = url('test_route3', [2]) - assert link == '/it/test3/2/foo' - link = url('test_route3', [2, 'bar']) - assert link == '/it/test3/2/foo/bar' - link = url('test_route3', [2, 'bar', 'json']) - assert link == '/it/test3/2/foo/bar.json' - link = url( - 'test_route3', [2, 'bar', 'json'], {'foo': 'bar', 'bar': 'foo'}) - lsplit = link.split('?') - assert lsplit[0] == '/it/test3/2/foo/bar.json' - assert lsplit[1] in ['foo=bar&bar=foo', 'bar=foo&foo=bar'] + with current_ctx(app, "/") as ctx: + ctx.request.language = "it" + link = url("test_route") + assert link == "/it/test_route" + link = url("test_route2") + assert link == "/it/test2" + link = url("test_route2", [2]) + assert link == "/it/test2/2" + link = url("test_route2", [2, "foo"]) + assert link == "/it/test2/2/foo" + link = url("test_route3") + assert link == "/it/test3" + link = url("test_route3", [2]) + assert link == "/it/test3/2/foo" + link = url("test_route3", [2, "bar"]) + assert link == "/it/test3/2/foo/bar" + link = url("test_route3", [2, "bar", "json"]) + assert link == "/it/test3/2/foo/bar.json" + link = url("test_route3", [2, "bar", "json"], {"foo": "bar", "bar": "foo"}) + lsplit = link.split("?") + assert lsplit[0] == "/it/test3/2/foo/bar.json" + assert lsplit[1] in ["foo=bar&bar=foo", "bar=foo&foo=bar"] def test_global_url_prefix(app): - app._router_http._prefix_main = '/foo' + app._router_http._prefix_main = "/foo" app._router_http._prefix_main_len = 3 - with current_ctx(app, '/') as ctx: + with current_ctx(app, "/") as ctx: app.config.static_version_urls = False - ctx.request.language = 'en' + ctx.request.language = "en" - link = url('test_route') - assert link == '/foo/test_route' + link = url("test_route") + assert link == "/foo/test_route" - link = url('static', 'js/foo.js') - assert link == '/foo/static/js/foo.js' + link = url("static", "js/foo.js") + assert link == "/foo/static/js/foo.js" app.config.static_version_urls = True - app.config.static_version = '1.0.0' + app.config.static_version = "1.0.0" - link = url('static', 'js/foo.js') - assert link == '/foo/static/_1.0.0/js/foo.js' + link = url("static", "js/foo.js") + assert link == "/foo/static/_1.0.0/js/foo.js" app.config.static_version_urls = False - ctx.request.language = 'it' + ctx.request.language = "it" - link = url('test_route') - assert link == '/foo/it/test_route' + link = url("test_route") + assert link == "/foo/it/test_route" - link = url('static', 'js/foo.js') - assert link == '/foo/it/static/js/foo.js' + link = url("static", "js/foo.js") + assert link == "/foo/it/static/js/foo.js" app.config.static_version_urls = True - app.config.static_version = '1.0.0' + app.config.static_version = "1.0.0" - link = url('static', 'js/foo.js') - assert link == '/foo/it/static/_1.0.0/js/foo.js' + link = url("static", "js/foo.js") + assert link == "/foo/it/static/_1.0.0/js/foo.js" diff --git a/tests/test_templates.py b/tests/test_templates.py index b4acf839..cfeb6511 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,21 +1,21 @@ # -*- coding: utf-8 -*- """ - tests.templates - --------------- +tests.templates +--------------- - Test Emmett templating module +Test Emmett templating module """ import pytest - from helpers import current_ctx + from emmett import App -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): app = App(__name__) - app.config.templates_escape = 'all' + app.config.templates_escape = "all" app.config.templates_adjust_indent = False app.config.templates_auto_reload = True return app @@ -24,31 +24,34 @@ def app(): def test_helpers(app): templater = app.templater r = templater._render(source="{{ include_helpers }}") - assert r == '' + \ - '' + assert ( + r + == '' + + '' + ) def test_meta(app): - with current_ctx('/', app) as ctx: + with current_ctx("/", app) as ctx: ctx.response.meta.foo = "bar" ctx.response.meta_prop.foo = "bar" templater = app.templater - r = templater._render( - source="{{ include_meta }}", - context={'current': ctx}) - assert r == '' + \ - '' + r = templater._render(source="{{ include_meta }}", context={"current": ctx}) + assert r == '' + '' def test_static(app): templater = app.templater s = "{{include_static 'foo.js'}}\n{{include_static 'bar.css'}}" r = templater._render(source=s) - assert r == '\n' + assert ( + r + == '\n' + ) rendered_value = """ @@ -78,20 +81,19 @@ def test_static(app): """.format( - helpers="".join([ - f'' - for name in ["jquery.min.js", "helpers.js"] - ]) + helpers="".join( + [ + f'' + for name in ["jquery.min.js", "helpers.js"] + ] + ) ) def test_render(app): - with current_ctx('/', app) as ctx: - ctx.language = 'it' - r = app.templater.render( - 'test.html', { - 'current': ctx, 'posts': [{'title': 'foo'}, {'title': 'bar'}] - } + with current_ctx("/", app) as ctx: + ctx.language = "it" + r = app.templater.render("test.html", {"current": ctx, "posts": [{"title": "foo"}, {"title": "bar"}]}) + assert "\n".join([l.strip() for l in r.splitlines() if l.strip()]) == "\n".join( + [l.strip() for l in rendered_value[1:].splitlines()] ) - assert "\n".join([l.strip() for l in r.splitlines() if l.strip()]) == \ - "\n".join([l.strip() for l in rendered_value[1:].splitlines()]) diff --git a/tests/test_translator.py b/tests/test_translator.py index 925cbbcf..ac868035 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.translator - ---------------- +tests.translator +---------------- - Test Emmett translator module +Test Emmett translator module """ import pytest @@ -13,17 +13,17 @@ from emmett.locals import T -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): return App(__name__) def _make_translation(language): - return str(T('partly cloudy', lang=language)) + return str(T("partly cloudy", lang=language)) def test_translation(app): - current.language = 'en' - assert _make_translation('it') == 'nuvolosità variabile' - assert _make_translation('de') == 'teilweise bewölkt' - assert _make_translation('ru') == 'переменная облачность' + current.language = "en" + assert _make_translation("it") == "nuvolosità variabile" + assert _make_translation("de") == "teilweise bewölkt" + assert _make_translation("ru") == "переменная облачность" diff --git a/tests/test_utils.py b/tests/test_utils.py index 4228c3d3..18a2f081 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.utils - ----------- +tests.utils +----------- - Test Emmett utils engine +Test Emmett utils engine """ from emmett.datastructures import sdict @@ -11,29 +11,29 @@ def test_is_valid_ip_address(): - result_localhost = is_valid_ip_address('127.0.0.1') + result_localhost = is_valid_ip_address("127.0.0.1") assert result_localhost is True - result_unknown = is_valid_ip_address('unknown') + result_unknown = is_valid_ip_address("unknown") assert result_unknown is False - result_ipv4_valid = is_valid_ip_address('::ffff:192.168.0.1') + result_ipv4_valid = is_valid_ip_address("::ffff:192.168.0.1") assert result_ipv4_valid is True - result_ipv4_valid = is_valid_ip_address('192.168.256.1') + result_ipv4_valid = is_valid_ip_address("192.168.256.1") assert result_ipv4_valid is False - result_ipv6_valid = is_valid_ip_address('fd40:363d:ee85::') + result_ipv6_valid = is_valid_ip_address("fd40:363d:ee85::") assert result_ipv6_valid is True - result_ipv6_valid = is_valid_ip_address('fd40:363d:ee85::1::') + result_ipv6_valid = is_valid_ip_address("fd40:363d:ee85::1::") assert result_ipv6_valid is False def test_dict_to_sdict(): - result_sdict = dict_to_sdict({'test': 'dict'}) + result_sdict = dict_to_sdict({"test": "dict"}) assert isinstance(result_sdict, sdict) - assert result_sdict.test == 'dict' + assert result_sdict.test == "dict" result_number = dict_to_sdict(1) assert not isinstance(result_number, sdict) diff --git a/tests/test_validators.py b/tests/test_validators.py index 2deebd93..fb343529 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,21 +1,47 @@ # -*- coding: utf-8 -*- """ - tests.validators - ---------------- +tests.validators +---------------- - Test Emmett validators over pyDAL. +Test Emmett validators over pyDAL. """ +from datetime import datetime, timedelta + import pytest -from datetime import datetime, timedelta from emmett import App, sdict -from emmett.orm import Database, Model, Field, has_many, belongs_to +from emmett.orm import Database, Field, Model, belongs_to, has_many from emmett.validators import ( - isEmptyOr, hasLength, isInt, isFloat, isDate, isTime, isDatetime, isJSON, - isntEmpty, inSet, inDB, isEmail, isUrl, isIP, isImage, inRange, Equals, - Lower, Upper, Cleanup, Urlify, Crypt, notInDB, Allow, Not, Matches, Any, - isList) + Allow, + Any, + Cleanup, + Crypt, + Equals, + Lower, + Matches, + Not, + Upper, + Urlify, + hasLength, + inDB, + inRange, + inSet, + isDate, + isDatetime, + isEmail, + isEmptyOr, + isFloat, + isImage, + isInt, + isIP, + isJSON, + isList, + isntEmpty, + isTime, + isUrl, + notInDB, +) class A(Model): @@ -49,11 +75,9 @@ class B(Model): tablename = "b" a = Field() - b = Field(validation={'len': {'gte': 5}}) + b = Field(validation={"len": {"gte": 5}}) - validation = { - 'a': {'len': {'gte': 5}} - } + validation = {"a": {"len": {"gte": 5}}} class Consist(Model): @@ -65,12 +89,12 @@ class Consist(Model): emailsplit = Field.string_list() validation = { - 'email': {'is': 'email'}, - 'url': {'is': 'url'}, - 'ip': {'is': 'ip'}, - 'image': {'is': 'image'}, - 'emails': {'is': 'list:email'}, - 'emailsplit': {'is': {'list:email': {'splitter': ',;'}}} + "email": {"is": "email"}, + "url": {"is": "url"}, + "ip": {"is": "ip"}, + "image": {"is": "image"}, + "emails": {"is": "list:email"}, + "emailsplit": {"is": {"list:email": {"splitter": ",;"}}}, } @@ -81,10 +105,10 @@ class Len(Model): d = Field() validation = { - 'a': {'len': 5}, - 'b': {'len': {'gt': 4, 'lt': 13}}, - 'c': {'len': {'gte': 5, 'lte': 12}}, - 'd': {'len': {'range': (5, 13)}} + "a": {"len": 5}, + "b": {"len": {"gt": 4, "lt": 13}}, + "c": {"len": {"gte": 5, "lte": 12}}, + "d": {"len": {"range": (5, 13)}}, } @@ -92,10 +116,7 @@ class Inside(Model): a = Field() b = Field.int() - validation = { - 'a': {'in': ['a', 'b']}, - 'b': {'in': {'range': (1, 5)}} - } + validation = {"a": {"in": ["a", "b"]}, "b": {"in": {"range": (1, 5)}}} class Num(Model): @@ -103,11 +124,7 @@ class Num(Model): b = Field.int() c = Field.int() - validation = { - 'a': {'gt': 0}, - 'b': {'lt': 5}, - 'c': {'gt': 0, 'lte': 4} - } + validation = {"a": {"gt": 0}, "b": {"lt": 5}, "c": {"gt": 0, "lte": 4}} class Proc(Model): @@ -119,12 +136,12 @@ class Proc(Model): f = Field.password() validation = { - 'a': {'lower': True}, - 'b': {'upper': True}, - 'c': {'clean': True}, - 'd': {'urlify': True}, - 'e': {'len': {'range': (6, 25)}, 'crypt': True}, - 'f': {'len': {'gt': 5, 'lt': 25}, 'crypt': 'md5'} + "a": {"lower": True}, + "b": {"upper": True}, + "c": {"clean": True}, + "d": {"urlify": True}, + "e": {"len": {"range": (6, 25)}, "crypt": True}, + "f": {"len": {"gt": 5, "lt": 25}, "crypt": "md5"}, } @@ -133,60 +150,47 @@ class Eq(Model): b = Field.int() c = Field.float() - validation = { - 'a': {'equals': 'asd'}, - 'b': {'equals': 2}, - 'c': {'not': {'equals': 2.4}} - } + validation = {"a": {"equals": "asd"}, "b": {"equals": 2}, "c": {"not": {"equals": 2.4}}} class Match(Model): a = Field() b = Field() - validation = { - 'a': {'match': 'ab'}, - 'b': {'match': {'expression': 'ab', 'strict': True}} - } + validation = {"a": {"match": "ab"}, "b": {"match": {"expression": "ab", "strict": True}}} class Anyone(Model): a = Field() - validation = { - 'a': {'any': {'is': 'email', 'in': ['foo', 'bar']}} - } + validation = {"a": {"any": {"is": "email", "in": ["foo", "bar"]}}} class Person(Model): - has_many('things') + has_many("things") - name = Field(validation={'empty': False}) - surname = Field(validation={'presence': True}) + name = Field(validation={"empty": False}) + surname = Field(validation={"presence": True}) class Thing(Model): - belongs_to('person') + belongs_to("person") name = Field() color = Field() uid = Field(unique=True) - validation = { - 'name': {'presence': True}, - 'color': {'in': ['blue', 'red']}, - 'uid': {'empty': False} - } + validation = {"name": {"presence": True}, "color": {"in": ["blue", "red"]}, "uid": {"empty": False}} class Allowed(Model): - a = Field(validation={'in': ['a', 'b'], 'allow': None}) - b = Field(validation={'in': ['a', 'b'], 'allow': 'empty'}) - c = Field(validation={'in': ['a', 'b'], 'allow': 'blank'}) + a = Field(validation={"in": ["a", "b"], "allow": None}) + b = Field(validation={"in": ["a", "b"], "allow": "empty"}) + c = Field(validation={"in": ["a", "b"], "allow": "blank"}) class Mixed(Model): - belongs_to('person') + belongs_to("person") date = Field.date() type = Field() @@ -197,27 +201,21 @@ class Mixed(Model): psw = Field.password() validation = { - 'date': {'format': '%d/%m/%Y', 'gt': lambda: datetime.utcnow().date()}, - 'type': {'in': ['a', 'b'], 'allow': None}, - 'inside': {'in': ['asd', 'lol']}, - 'number': {'allow': 'blank'}, - 'dont': {'empty': True}, - 'yep': {'presence': True}, - 'psw': {'len': {'range': (6, 25)}, 'crypt': True} + "date": {"format": "%d/%m/%Y", "gt": lambda: datetime.utcnow().date()}, + "type": {"in": ["a", "b"], "allow": None}, + "inside": {"in": ["asd", "lol"]}, + "number": {"allow": "blank"}, + "dont": {"empty": True}, + "yep": {"presence": True}, + "psw": {"len": {"range": (6, 25)}, "crypt": True}, } -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def db(): app = App(__name__) - db = Database( - app, config=sdict( - uri='sqlite://validators.db', auto_connect=True, - auto_migrate=True)) - db.define_models([ - A, AA, AAA, B, Consist, Len, Inside, Num, Eq, Match, Anyone, Proc, - Person, Thing, Allowed, Mixed - ]) + db = Database(app, config=sdict(uri="sqlite://validators.db", auto_connect=True, auto_migrate=True)) + db.define_models([A, AA, AAA, B, Consist, Len, Inside, Num, Eq, Match, Anyone, Proc, Person, Thing, Allowed, Mixed]) return db @@ -349,243 +347,243 @@ def test_allow(db): def test_validation(db): #: 'is' is_data = { - 'name': 'foo', - 'val': 1, - 'fval': 1.5, - 'text': 'Lorem ipsum', - 'password': 'notverysecret', - 'd': '{:%Y-%m-%d}'.format(datetime.utcnow()), - 't': '15:23', - 'dt': '2015-12-23T15:23:00', - 'json': '{}' + "name": "foo", + "val": 1, + "fval": 1.5, + "text": "Lorem ipsum", + "password": "notverysecret", + "d": "{:%Y-%m-%d}".format(datetime.utcnow()), + "t": "15:23", + "dt": "2015-12-23T15:23:00", + "json": "{}", } errors = A.validate(is_data) assert not errors d = dict(is_data) - d['val'] = 'foo' + d["val"] = "foo" errors = A.validate(d) - assert 'val' in errors and len(errors) == 1 + assert "val" in errors and len(errors) == 1 d = dict(is_data) - d['fval'] = 'bar' + d["fval"] = "bar" errors = A.validate(d) - assert 'fval' in errors and len(errors) == 1 + assert "fval" in errors and len(errors) == 1 d = dict(is_data) - d['d'] = 'foo' + d["d"] = "foo" errors = A.validate(d) - assert 'd' in errors and len(errors) == 1 + assert "d" in errors and len(errors) == 1 d = dict(is_data) - d['t'] = 'bar' + d["t"] = "bar" errors = A.validate(d) - assert 't' in errors and len(errors) == 1 + assert "t" in errors and len(errors) == 1 d = dict(is_data) - d['dt'] = 'foo' + d["dt"] = "foo" errors = A.validate(d) - assert 'dt' in errors and len(errors) == 1 + assert "dt" in errors and len(errors) == 1 d = dict(is_data) - d['json'] = 'bar' + d["json"] = "bar" errors = A.validate(d) - assert 'json' in errors and len(errors) == 1 - errors = Consist.validate({'email': 'foo'}) - assert 'email' in errors - errors = Consist.validate({'url': 'notanurl'}) - assert 'url' in errors - errors = Consist.validate({'url': 'http://domain.com/'}) - assert 'url' not in errors - errors = Consist.validate({'url': 'http://domain.com'}) - assert 'url' not in errors - errors = Consist.validate({'url': 'domain.com'}) - assert 'url' not in errors - errors = Consist.validate({'ip': 'foo'}) - assert 'ip' in errors - errors = Consist.validate({'emails': 'foo'}) - assert 'emails' in errors - errors = Consist.validate({'emailsplit': 'foo'}) - assert 'emailsplit' in errors - errors = Consist.validate({'emailsplit': '1@asd.com, 2@asd.com'}) - assert 'emailsplit' not in errors - errors = Consist.validate({'emails': ['1@asd.com', '2@asd.com']}) - assert 'emails' not in errors + assert "json" in errors and len(errors) == 1 + errors = Consist.validate({"email": "foo"}) + assert "email" in errors + errors = Consist.validate({"url": "notanurl"}) + assert "url" in errors + errors = Consist.validate({"url": "http://domain.com/"}) + assert "url" not in errors + errors = Consist.validate({"url": "http://domain.com"}) + assert "url" not in errors + errors = Consist.validate({"url": "domain.com"}) + assert "url" not in errors + errors = Consist.validate({"ip": "foo"}) + assert "ip" in errors + errors = Consist.validate({"emails": "foo"}) + assert "emails" in errors + errors = Consist.validate({"emailsplit": "foo"}) + assert "emailsplit" in errors + errors = Consist.validate({"emailsplit": "1@asd.com, 2@asd.com"}) + assert "emailsplit" not in errors + errors = Consist.validate({"emails": ["1@asd.com", "2@asd.com"]}) + assert "emails" not in errors #: 'len' - len_data = {'a': '12345', 'b': '12345', 'c': '12345', 'd': '12345'} + len_data = {"a": "12345", "b": "12345", "c": "12345", "d": "12345"} errors = Len.validate(len_data) assert not errors d = dict(len_data) - d['a'] = 'ciao' + d["a"] = "ciao" errors = Len.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(len_data) - d['b'] = 'ciao' + d["b"] = "ciao" errors = Len.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(len_data) - d['b'] = '1234567890123' + d["b"] = "1234567890123" errors = Len.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(len_data) - d['c'] = 'ciao' + d["c"] = "ciao" errors = Len.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 d = dict(len_data) - d['c'] = '1234567890123' + d["c"] = "1234567890123" errors = Len.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 d = dict(len_data) - d['d'] = 'ciao' + d["d"] = "ciao" errors = Len.validate(d) - assert 'd' in errors and len(errors) == 1 + assert "d" in errors and len(errors) == 1 d = dict(len_data) - d['d'] = '1234567890123' + d["d"] = "1234567890123" errors = Len.validate(d) - assert 'd' in errors and len(errors) == 1 + assert "d" in errors and len(errors) == 1 #: 'in' - in_data = {'a': 'a', 'b': 2} + in_data = {"a": "a", "b": 2} errors = Inside.validate(in_data) assert not errors d = dict(in_data) - d['a'] = 'c' + d["a"] = "c" errors = Inside.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(in_data) - d['b'] = 0 + d["b"] = 0 errors = Inside.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(in_data) - d['b'] = 7 + d["b"] = 7 errors = Inside.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 #: 'gt', 'lt', 'gte', 'lte' - num_data = {'a': 1, 'b': 4, 'c': 2} + num_data = {"a": 1, "b": 4, "c": 2} errors = Num.validate(num_data) assert not errors d = dict(num_data) - d['a'] = 0 + d["a"] = 0 errors = Num.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(num_data) - d['b'] = 5 + d["b"] = 5 errors = Num.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(num_data) - d['c'] = 0 + d["c"] = 0 errors = Num.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 d = dict(num_data) - d['c'] = 5 + d["c"] = 5 errors = Num.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 #: 'equals' - eq_data = {'a': 'asd', 'b': 2, 'c': 2.3} + eq_data = {"a": "asd", "b": 2, "c": 2.3} errors = Eq.validate(eq_data) assert not errors d = dict(eq_data) - d['a'] = 'lol' + d["a"] = "lol" errors = Eq.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(eq_data) - d['b'] = 3 + d["b"] = 3 errors = Eq.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 #: 'not' d = dict(eq_data) - d['c'] = 2.4 + d["c"] = 2.4 errors = Eq.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 #: 'match' - match_data = {'a': 'abc', 'b': 'ab'} + match_data = {"a": "abc", "b": "ab"} errors = Match.validate(match_data) assert not errors d = dict(match_data) - d['a'] = 'lol' + d["a"] = "lol" errors = Match.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(match_data) - d['b'] = 'abc' + d["b"] = "abc" errors = Match.validate(d) - assert 'b' in errors and len(errors) == 1 - d['b'] = 'lol' + assert "b" in errors and len(errors) == 1 + d["b"] = "lol" errors = Match.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 #: 'any' - errors = Anyone.validate({'a': 'foo'}) + errors = Anyone.validate({"a": "foo"}) assert not errors - errors = Anyone.validate({'a': 'walter@massivedynamics.com'}) + errors = Anyone.validate({"a": "walter@massivedynamics.com"}) assert not errors - errors = Anyone.validate({'a': 'lol'}) - assert 'a' in errors + errors = Anyone.validate({"a": "lol"}) + assert "a" in errors #: 'allow' - allow_data = {'a': 'a', 'b': 'a', 'c': 'a'} + allow_data = {"a": "a", "b": "a", "c": "a"} errors = Allowed.validate(allow_data) assert not errors d = dict(allow_data) - d['a'] = None + d["a"] = None errors = Allowed.validate(d) assert not errors - d['a'] = 'foo' + d["a"] = "foo" errors = Allowed.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(allow_data) - d['b'] = '' + d["b"] = "" errors = Allowed.validate(d) assert not errors - d['b'] = None + d["b"] = None errors = Allowed.validate(d) assert not errors - d['b'] = 'foo' + d["b"] = "foo" errors = Allowed.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(allow_data) - d['c'] = '' + d["c"] = "" errors = Allowed.validate(d) assert not errors - d['c'] = None + d["c"] = None errors = Allowed.validate(d) assert not errors - d['c'] = 'foo' + d["c"] = "foo" errors = Allowed.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 #: processing validators - assert Proc.a.validate('Asd')[0] == 'asd' - assert Proc.b.validate('Asd')[0] == 'ASD' - assert Proc.d.validate('Two Words')[0] == 'two-words' - psw = str(Proc.e.validate('somepassword')[0]) - assert psw[:23] == 'pbkdf2(1000,20,sha512)$' - assert psw[23:] != 'somepassword' - psw = str(Proc.f.validate('somepassword')[0]) - assert psw[:4] == 'md5$' - assert psw[4:] != 'somepassword' + assert Proc.a.validate("Asd")[0] == "asd" + assert Proc.b.validate("Asd")[0] == "ASD" + assert Proc.d.validate("Two Words")[0] == "two-words" + psw = str(Proc.e.validate("somepassword")[0]) + assert psw[:23] == "pbkdf2(1000,20,sha512)$" + assert psw[23:] != "somepassword" + psw = str(Proc.f.validate("somepassword")[0]) + assert psw[:4] == "md5$" + assert psw[4:] != "somepassword" #: 'presence' - mario = {'name': 'mario'} + mario = {"name": "mario"} errors = Person.validate(mario) - assert 'surname' in errors + assert "surname" in errors assert len(errors) == 1 #: 'presence' with reference, 'unique' - thing = {'name': 'a', 'person': 5, 'color': 'blue', 'uid': 'lol'} + thing = {"name": "a", "person": 5, "color": "blue", "uid": "lol"} errors = Thing.validate(thing) - assert 'person' in errors + assert "person" in errors assert len(errors) == 1 - mario = {'name': 'mario', 'surname': 'draghi'} + mario = {"name": "mario", "surname": "draghi"} mario = Person.create(mario) assert len(mario.errors.keys()) == 0 assert mario.id == 1 - thing = {'name': 'euro', 'person': mario.id, 'color': 'red', 'uid': 'lol'} + thing = {"name": "euro", "person": mario.id, "color": "red", "uid": "lol"} thing = Thing.create(thing) assert len(thing.errors.keys()) == 0 - thing = {'name': 'euro2', 'person': mario.id, 'color': 'red', 'uid': 'lol'} + thing = {"name": "euro2", "person": mario.id, "color": "red", "uid": "lol"} errors = Thing.validate(thing) assert len(errors) == 1 - assert 'uid' in errors + assert "uid" in errors def test_multi(db): p = db.Person(name="mario") base_data = { - 'date': '{0:%d/%m/%Y}'.format(datetime.utcnow()+timedelta(days=1)), - 'type': 'a', - 'inside': 'asd', - 'number': 1, - 'yep': 'asd', - 'psw': 'password', - 'person': p.id + "date": "{0:%d/%m/%Y}".format(datetime.utcnow() + timedelta(days=1)), + "type": "a", + "inside": "asd", + "number": 1, + "yep": "asd", + "psw": "password", + "person": p.id, } #: everything ok res = Mixed.create(base_data) @@ -593,67 +591,67 @@ def test_multi(db): assert len(res.errors.keys()) == 0 #: invalid belongs vals = dict(base_data) - del vals['person'] + del vals["person"] res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'person' in res.errors + assert "person" in res.errors #: invalid date range vals = dict(base_data) - vals['date'] = '{0:%d/%m/%Y}'.format(datetime.utcnow()-timedelta(days=2)) + vals["date"] = "{0:%d/%m/%Y}".format(datetime.utcnow() - timedelta(days=2)) res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'date' in res.errors + assert "date" in res.errors #: invalid date format - vals['date'] = '76-12-1249' + vals["date"] = "76-12-1249" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'date' in res.errors + assert "date" in res.errors #: invalid in vals = dict(base_data) - vals['type'] = ' ' + vals["type"] = " " res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'type' in res.errors + assert "type" in res.errors #: empty number vals = dict(base_data) - vals['number'] = None + vals["number"] = None res = Mixed.create(vals) assert res.id == 2 assert len(res.errors.keys()) == 0 #: invalid number vals = dict(base_data) - vals['number'] = 'asd' + vals["number"] = "asd" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'number' in res.errors + assert "number" in res.errors #: invalid empty vals = dict(base_data) - vals['dont'] = '2' + vals["dont"] = "2" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'dont' in res.errors + assert "dont" in res.errors #: invalid presence vals = dict(base_data) - vals['yep'] = '' + vals["yep"] = "" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'yep' in res.errors + assert "yep" in res.errors #: invalid password vals = dict(base_data) - vals['psw'] = '' + vals["psw"] = "" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'psw' in res.errors - vals['psw'] = 'aksjfalsdkjflkasjdflkajsldkjfalslkdfjaslkdjf' + assert "psw" in res.errors + vals["psw"] = "aksjfalsdkjflkasjdflkajsldkjfalslkdfjaslkdjf" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'psw' in res.errors + assert "psw" in res.errors diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index dd9b0468..a600d449 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,36 +1,37 @@ # -*- coding: utf-8 -*- """ - tests.wrappers - -------------- +tests.wrappers +-------------- - Test Emmett wrappers module +Test Emmett wrappers module """ -from helpers import current_ctx from emmett_core.protocols.rsgi.test_client.scope import ScopeBuilder +from helpers import current_ctx + from emmett.rsgi.wrappers import Request from emmett.wrappers.response import Response def test_request(): scope, _ = ScopeBuilder( - path='/?foo=bar', - method='GET', + path="/?foo=bar", + method="GET", ).get_data() request = Request(scope, None, None) - assert request.query_params == {'foo': 'bar'} - assert request.client == '127.0.0.1' + assert request.query_params == {"foo": "bar"} + assert request.client == "127.0.0.1" def test_response(): response = Response() assert response.status == 200 - assert response.headers['content-type'] == 'text/plain' + assert response.headers["content-type"] == "text/plain" def test_req_ctx(): - with current_ctx('/?foo=bar') as ctx: + with current_ctx("/?foo=bar") as ctx: assert isinstance(ctx.request, Request) assert isinstance(ctx.response, Response)