diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml index 54264462..6e0de274 100644 --- a/.github/FUNDING.yml +++ b/.github/FUNDING.yml @@ -1 +1,2 @@ github: [gi0baro] +polar: emmett-framework diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..b03c1233 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + target-branch: "master" + directory: "/" + schedule: + interval: "monthly" diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..70de72bc --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,30 @@ +name: lint + +on: + pull_request: + types: [opened, synchronize] + branches: + - master + +env: + PYTHON_VERSION: 3.12 + +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ env.PYTHON_VERSION }} + uses: actions/setup-python@v5 + with: + python-version: ${{ env.PYTHON_VERSION }} + - name: Install uv + uses: astral-sh/setup-uv@v3 + - name: Install dependencies + run: | + uv sync --dev + - name: Lint + run: | + source .venv/bin/activate + make lint diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index b72c4ff7..dbd96919 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -7,20 +7,24 @@ on: jobs: publish: runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/emmett + permissions: + id-token: write steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: - python-version: 3.9 - - name: Install and configure Poetry - uses: gi0baro/setup-poetry-bin@v1 - with: - virtualenvs-in-project: true - - name: Publish + python-version: 3.12 + - name: Install uv + uses: astral-sh/setup-uv@v3 + - name: Build distributions run: | - poetry build - poetry publish - env: - POETRY_PYPI_TOKEN_PYPI: ${{ secrets.PYPI_TOKEN }} + uv build + - name: Publish package to pypi + uses: pypa/gh-action-pypi-publish@release/v1 + with: + skip-existing: true diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 8f7ae027..0381e55d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -3,17 +3,20 @@ name: Tests on: push: branches: - - "**" - tags-ignore: - - "**" + - master + - release pull_request: + types: [opened, synchronize] + branches: + - master jobs: Linux: runs-on: ubuntu-latest strategy: + fail-fast: false matrix: - python-version: [3.8, 3.9, '3.10', '3.11', '3.12'] + python-version: [3.8, 3.9, '3.10', '3.11', '3.12', '3.13'] services: postgres: @@ -25,68 +28,63 @@ jobs: - 5432:5432 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install and configure Poetry - uses: gi0baro/setup-poetry-bin@v1 - with: - virtualenvs-in-project: true + - name: Install uv + uses: astral-sh/setup-uv@v3 - name: Install dependencies run: | - poetry install -v + uv sync --dev - name: Test env: POSTGRES_URI: postgres:postgres@localhost:5432/test run: | - poetry run pytest -v tests + uv run pytest -v tests MacOS: - runs-on: macos-13 + runs-on: macos-latest strategy: + fail-fast: false matrix: - python-version: [3.8, 3.9, '3.10', '3.11', '3.12'] + python-version: [3.8, 3.9, '3.10', '3.11', '3.12', '3.13'] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install and configure Poetry - uses: gi0baro/setup-poetry-bin@v1 - with: - virtualenvs-in-project: true + - name: Install uv + uses: astral-sh/setup-uv@v3 - name: Install dependencies run: | - poetry install -v + uv sync --dev - name: Test run: | - poetry run pytest -v tests + uv run pytest -v tests Windows: runs-on: windows-latest strategy: + fail-fast: false matrix: - python-version: [3.8, 3.9, '3.10', '3.11', '3.12'] + python-version: [3.8, 3.9, '3.10', '3.11', '3.12', '3.13'] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install and configure Poetry - uses: gi0baro/setup-poetry-bin@v1 - with: - virtualenvs-in-project: true + - name: Install uv + uses: astral-sh/setup-uv@v3 - name: Install dependencies - shell: bash run: | - poetry install -v + uv sync --dev - name: Test shell: bash run: | - poetry run pytest -v tests + uv run pytest -v tests diff --git a/.gitignore b/.gitignore index 488a0efd..b322a6fb 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,7 @@ build/* dist/* Emmett.egg-info/* poetry.lock +uv.lock examples/*/databases examples/*/logs diff --git a/CHANGES.md b/CHANGES.md index 5364d5b7..bff6a498 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,6 +1,21 @@ Emmett changelog ================ +Version 2.6 +----------- + +Released on October 14th 2024, codename Gutenberg + +- Added Python 3.13 support +- Re-implemented router matching algorithm in Rust +- Re-implemented multipart parsing in Rust +- Added `Response` wrap methods +- Refactored `Request.files` implementation +- Support iteration on `Request.body` +- Added `iter`, `aiter`, `http` and `snippet` to routes' outputs +- Testing client is now using RSGI protocol in place of ASGI +- Logger now uses stdout in place of files under default configuration + Version 2.5 ----------- diff --git a/LICENSE b/LICENSE index 67f51855..a048352d 100644 --- a/LICENSE +++ b/LICENSE @@ -26,21 +26,3 @@ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Some parts in Emmett's components are derived from web2py framework, -and contains the original copyright credits: - Copyrighted (c) by Massimo Di Pierro (2007-2013) - -Due to the original license limitations, some of these components are -distributed under LGPLv3 license, in particular: - - * emmett/security.py - * emmett/validators - -For futher details, check out the web2py's license: - https://github.com/web2py/web2py/blob/master/LICENSE - - -Emmett also contains third party software in the 'libs' directory: each -file/module in this directory is distributed under its original license. diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..14fd8afd --- /dev/null +++ b/Makefile @@ -0,0 +1,19 @@ +.DEFAULT_GOAL := all +pysources = emmett tests + +.PHONY: format +format: + ruff check --fix $(pysources) + ruff format $(pysources) + +.PHONY: lint +lint: + ruff check $(pysources) + ruff format --check $(pysources) + +.PHONY: test +test: + pytest -v tests + +.PHONY: all +all: format lint test diff --git a/README.md b/README.md index c89bed1d..97f4a359 100644 --- a/README.md +++ b/README.md @@ -38,12 +38,9 @@ async def todo(): return {'tasks': tasks} ``` -[![pip version](https://img.shields.io/pypi/v/emmett.svg?style=flat)](https://pypi.python.org/pypi/emmett) -![Tests Status](https://github.com/emmett-framework/emmett/workflows/Tests/badge.svg) - ## Documentation -The documentation is available at [https://emmett.sh/docs](https://emmett.sh/docs). +The documentation is available at [https://emmett.sh/docs](https://emmett.sh/docs). The sources are available under the [docs folder](https://github.com/emmett-framework/emmett/tree/master/docs). ## Examples @@ -76,6 +73,4 @@ We would be very glad if you contributed to the project in one or all of these w Emmett is released under the BSD License. -However, due to original license limitations, some components are included -in Emmett under their original licenses. Please check the LICENSE file for -more details. +However, due to original license limitations, contents under [validators](https://github.com/emmett-framework/emmett/tree/master/emmett/validators) and [libs](https://github.com/emmett-framework/emmett/tree/master/emmett/libs) are included in Emmett under their original licenses. Please check the source code for more details. diff --git a/docs/forms.md b/docs/forms.md index aa027923..539482b7 100644 --- a/docs/forms.md +++ b/docs/forms.md @@ -91,6 +91,7 @@ Here is the complete list of parameters accepted by `Form` class: > **Note:** the `fields` and `exclude_fields` parameters should not be used together. If you need to hide just a few fields, you'd better using the `exclude_fields`, and you should use `fields` if you have to show only few table fields. The advantages of these parameters are lost if you use both. ### Uploads with forms + As we saw above, the `upload` parameter of forms needs an URL for download. Let's focus a bit on uploads and see an example to completely understand this requirement. Let's say you want to handle the upload of avatar images from your user. So, in your model you would have an upload field: @@ -99,14 +100,14 @@ Let's say you want to handle the upload of avatar images from your user. So, in avatar = Field.upload() ``` -and the forms produced by Emmett will handle uploads for you. How would you display this image in your template? You need a streaming function like this: +and the forms produced by Emmett will handle uploads for you. How would you display this image in your template? You need a route which will send back the uploaded files' contents: ```python -from emmett.helpers import stream_dbfile +from emmett import response @app.route("/download/") async def download(filename): - stream_dbfile(db, filename) + return response.wrap_dbfile(db, filename) ``` and then, in your template, you can create an `img` tag pointing to the `download` function you've just exposed: diff --git a/docs/html.md b/docs/html.md new file mode 100644 index 00000000..a87cedfb --- /dev/null +++ b/docs/html.md @@ -0,0 +1,94 @@ +HTML without templates +====================== + +As we saw in the [templates chapter](./templates), Emmett comes with a template engine out of the box, which you can use to render HTML. + +Under specific circumstances though, it might be convenient generating HTML directly in your route code, using the Python language. To support these scenarios, Emmett provides few helpers under the `html` module. Let's see them in details. + +The `tag` helper +---------------- + +The `tag` object is the main interface provided by Emmett to produce HTML contents from Python code. It dinamically produces HTML elements based on its attributes, so you can produce even custom elements: + +```python +from emmett.html import tag + +# an empty

+p = tag.p() +# a custom element +card = tag.card() +# a custom element +list_item = tag["list-item"]() +``` + +Every element produced by the `tag` helper accepts both nested contents and attributes, with the caveat HTML attributes needs to start with `_`: + +```python +#

Hello world

+p = tag.p("Hello world") +#

bar

+div = tag.div(tag.p("bar"), _class="foo") +``` + +> **Note:** the reasons behind the underscore notation for HTML attributes are mainly: +> - to avoid issues with Python reserved words (eg: `class`) +> - to keep the ability to set custom attributes on the HTML objects in Python code but prevent those attributes to be rendered + +Mind that the `tag` helper already takes care of *self-closing* elements and escaping contents, so you don't have to worry about those. + +> – That's cool dude, but what if I need to set several attributes with the same prefix? +> – *Like with HTMX? Sure, just use a dictionary* + +```python +# +btn = tag.button( + "Click me", + _hx={ + "post": url("clicked"), + "swap": "outerHTML" + } +) +``` + +The `cat` helper +---------------- + +Sometimes you may need to stack together HTML elements without a parent. For such cases, the `cat` helper can be used: + +```python +from emmett.html import cat, tag + +#

hello

world

+multi_p = cat(tag.p("hello"), tag.p("world")) +``` + +Building deep stacks +-------------------- + +All the elements produced with the `tag` helper supports `with` statements, so you can easily manage even complicated stacks. For instance the following code: + +```python +root = tag.div(_class="root") +with root: + with tag.div(_class="lv1"): + with tag.div(_class="lvl2"): + tag.p("foo") + tag.p("bar") + +str(root) +``` + +will produce the following HTML: + +```html +
+
+
+

foo

+

bar

+
+
+
+``` + +> **Note:** when compared to templates, HTML generation from Python will be noticeably slower. For cases in which you want to render long and almost static HTML contents, using templates is preferable. diff --git a/docs/quickstart.md b/docs/quickstart.md index ce8b3240..0de8565f 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -244,6 +244,8 @@ The output will be a JSON object with the converted content of your Python dicti The `service` module has other helpers, like *XML* format: go further in the [Services chapter](./services) of the documentation. +Also, Emmett allows you to respond with streams and files: check the [Responses chapter](./response#wrapping-methods) of the documentation for further details. + Dealing with requests --------------------- diff --git a/docs/request.md b/docs/request.md index 1b7ec111..f32a33f2 100644 --- a/docs/request.md +++ b/docs/request.md @@ -41,15 +41,18 @@ Now, let's see how to deal with request variables. ### Request variables -Emmett's `request` object also provides three important attributes about the active request: +*Changed in version 2.6* + +Emmett's `request` object also provides four important attributes about the active request: | attribute | awaitable | description | | --- | --- | --- | | query_params | no | contains the URL query parameters | +| body | yes | contains the raw (bytes) request body | | body_params | yes | contains parameters passed into the request body | | files | yes | contains files passed into the request body | -All three attributes are `sdict` objects and they work in the same way, within the exception of requiring `await` or not, and an example may help you understand their dynamic: +All the attributes but `body` are `sdict` objects and they work in the same way, within the exception of requiring `await` or not, and an example may help you understand their dynamic: ```python from emmett import App, request @@ -90,6 +93,50 @@ Simple: the `request`'s params attributes will look like this: You can always access the variables you need. +#### Request files + +The `files` attribute works in the same way of `body_params` for multipart requests, but its values are objects wrapping the underlying file. + +These objects have some useful attributes, specifically: + +| attribute | description | +| --- | --- | +| filename | name of the file | +| content\_type | MIME type of the file | +| size | file size | + +Also, these object provides two methods to interact with the file contents: the `read` method, which allows you to load the file content, and the async `save` method, which allows you to directly store the file contents into a file-like object. + +```python +@app.route() +async def multipart_load(): + files = await request.files + # at this point you can either: + # i) read all the file contents + data = files.myfile.read() + # ii) read up to 4k of the file contents + data = files.myfile.read(4096) + # iii) store the file + await files.myfile.save(f"some/destination/{files.myfile.filename}") +``` + +#### Working with raw requests' bodies + +The `body` attribute gives you direct access to the request body. It's an awaitable object, so you can either load the whole body or iterate over it: + +```python +@app.route() +async def post(): + raw_body = await request.body + +@app.route() +async def iterpost(): + async for raw_chunk in request.body: + # do something +``` + +> **Note:** you cannot mix the two approaches. Also, directly interacting with `body` will prevent you to use `body_params` and `files` within the same request. + Errors and redirects -------------------- diff --git a/docs/response.md b/docs/response.md index 8b7545ac..55cf2e24 100644 --- a/docs/response.md +++ b/docs/response.md @@ -78,6 +78,61 @@ Then, in your template, you can just write: and you will have all the meta tags included in your HTML. +Wrapping methods +---------------- + +*New in version 2.6* + +Emmett `Response` object also provides some *wrapping* methods in order to respond with files or streams of data, specifically: + +- `wrap_iter` +- `wrap_aiter` +- `wrap_file` +- `wrap_io` + +These methods can be used to produce responses from iterators and files. + +### Iterable responses + +The `wrap_iter` and `wrap_aiter` methods are very similar, both accepts iterables: you can use the latter for asynchronous iterators: + +```python +def iterator(): + for _ in range(3): + yield b"hello" + +async def aiterator(): + for _ in range(3): + yield b"hello" + +@app.route() +async def response_iter(): + return response.wrap_iter(iterator()) + +@app.route() +async def response_aiter(): + return response.wrap_aiter(aiterator()) +``` + +### File responses + +You can produce responses from file using two different methods in Emmett: + +- `wrap_file` when you want to create a response from a path +- `wrap_io` when you want to create a response from a *file-like* object + +```python +@app.route("/file/") +async def file(name): + return response.wrap_file(f"assets/{name}") + + +@app.route("/io/") +async def io(name): + with open(f"assets/{name}", "r") as f: + return response.wrap_io(f) +``` + Message flashing ---------------- diff --git a/docs/routing.md b/docs/routing.md index b0dc4bbc..1dee90ee 100644 --- a/docs/routing.md +++ b/docs/routing.md @@ -110,7 +110,7 @@ folder. When you need to use a different template name, just tell Emmett to load ### Output -*New in version 2.0* +*Changed in version 2.6* The `output` parameter can be used to increase Emmett's performance in building the proper response from the exposed function. Here is the list of accepted outputs: @@ -120,6 +120,10 @@ The `output` parameter can be used to increase Emmett's performance in building | bytes | `bytes` string return value | | str | `str` return value | | template | `dict` return value to be used in templates | +| snippet | `tuple` return value composed by a template string and a `dict` | +| iter | iterable (of `bytes`) return value | +| aiter | async iterable (of `bytes`) return value | +| http | `HTTPResponse` return value | Under normal circumstances, the default behaviour is the best for most of usage cases. diff --git a/docs/templates.md b/docs/templates.md index 48728e89..2b3b391e 100644 --- a/docs/templates.md +++ b/docs/templates.md @@ -54,6 +54,18 @@ code is that you have to write `pass` after the statements to tell Emmett where the Python block ends. Normally, Python uses indentation for this, but HTML is not structured the same way and just undoing the indentation would be ambiguous. +### Template snippets + +*New in version 2.6* + +For cases in which you want to render just a simple template block, Emmett also support *snippets*, which avoid the need of creating a template file: + +```python +@app.route("/div/") +async def echo_div(msg): + return '
{{ =message }}
', {"message": msg} +``` + Template structure ------------------- diff --git a/docs/tree.yml b/docs/tree.yml index fa7d76ce..06aea704 100644 --- a/docs/tree.yml +++ b/docs/tree.yml @@ -6,6 +6,7 @@ - app_and_modules - routing - templates +- html - request - response - websocket diff --git a/docs/upgrading.md b/docs/upgrading.md index f1960472..46885279 100644 --- a/docs/upgrading.md +++ b/docs/upgrading.md @@ -13,6 +13,75 @@ Just as a remind, you can update Emmett using *pip*: $ pip install -U emmett ``` +Version 2.6 +----------- + +Emmett 2.6 release is focused on modernising part of the codebase. In this release we also rewrote part of the router and the request parsers in Rust, providing additional performance gains to all kind of applications. + +This release introduces some minor breaking changes and few deprecations, while introducing some new features. + +### Breaking changes + +#### Request files' content is now spooled + +Prior to Emmett 2.6 the contents of `Request.files` were loaded in memory on parsing. This might have led to issues with memory allocations, thus in 2.6 the file contents are spooled to temprorary files on disk. +The main consequence for this change is that code relying on the old behavior is now subject to errors. This only involves code relying on the previous – and undocumented – `stream` attribute of the `files` object values; all the other interfaces (iteration, `save` method, etc.) are still the same. +In case your application code falls into this scope, you should change the involved lines accordingly: + +```python +files = await request.files +# prior to 2.6 +data = files.myfile.stream.read() +# from 2.6 +data = files.myfile.read() +``` + +#### Default logger configuration + +With Emmett 2.6 the default logger configuration will now use the standard output rather than a rotating file handler. + +This is considered a *minor breaking change*, as it involves the default configuration, and thus can be set to the previous one: + +```python +app.config.logging.production.file.no = 4 +app.config.logging.production.file.max_size = 5 * 1024 * 1024 +``` + +### Deprecations + +#### Stream helpers + +Stream helpers like `stream_file` and `stream_dbfile` are now deprecated in favour of the newly introduced [response wrap methods](./response#wrapping-methods). + +Code involving this methods like: + +```python +from emmett.helpers import stream_dbfile + +@app.route("/download/") +async def download(filename): + stream_dbfile(db, filename) +``` + +should be converted to the new format: + +```python +from emmett import response + +@app.route("/download/") +async def download(filename): + return response.wrap_dbfile(db, filename) +``` + +### New features + +- [Response *wrap* methods](./response#wrapping-methods) +- Ability to return [template snippets](./templates#template-snippets) in routes +- Support for `iter`, `aiter` and `http` [route outputs](./routing#output) +- The request body now [supports iteration](./request#request-variables) + +Emmett 2.6 also introduces support for Python 3.13. + Version 2.5 ----------- diff --git a/emmett/__init__.py b/emmett/__init__.py index 018add3f..0ed4bb2f 100644 --- a/emmett/__init__.py +++ b/emmett/__init__.py @@ -1,3 +1,4 @@ +from . import _internal from .app import App, AppModule from .cache import Cache from .ctx import current @@ -8,5 +9,5 @@ from .http import redirect from .locals import T, now, request, response, session, websocket from .orm import Field -from .pipeline import Pipe, Injector +from .pipeline import Injector, Pipe from .routing.urls import url diff --git a/emmett/__main__.py b/emmett/__main__.py index 6dedb24f..5e4c7d2b 100644 --- a/emmett/__main__.py +++ b/emmett/__main__.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.__main__ - ---------------- +emmett.__main__ +---------------- - Alias for Emmett CLI. +Alias for Emmett CLI. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from emmett.cli import main + main(as_module=True) diff --git a/emmett/__version__.py b/emmett/__version__.py index 177f9a09..e5e59e38 100644 --- a/emmett/__version__.py +++ b/emmett/__version__.py @@ -1 +1 @@ -__version__ = "2.5.13" +__version__ = "2.6.0" diff --git a/emmett/_internal.py b/emmett/_internal.py index 15f4bf6d..1b1aae5a 100644 --- a/emmett/_internal.py +++ b/emmett/_internal.py @@ -1,392 +1,35 @@ # -*- coding: utf-8 -*- """ - emmett._internal - ---------------- +emmett._internal +---------------- - Provides internally used helpers and objects. +Provides internally used helpers and objects. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Several parts of this code comes from Flask and Werkzeug. - :copyright: (c) 2014 by Armin Ronacher. +Several parts of this code comes from Flask and Werkzeug. +:copyright: (c) 2014 by Armin Ronacher. - :license: BSD-3-Clause +:license: BSD-3-Clause """ from __future__ import annotations -import asyncio import datetime -import os -import pkgutil -import sys -import traceback -import warnings - -from functools import partial -from shutil import copyfileobj -from types import ModuleType -from typing import Any, Generic, Optional import pendulum -from .typing import T - - -class ProxyMixin: - def _get_robj(self): - raise NotImplementedError - - def __getitem__(self, key): - return self._get_robj()[key] - - def __setitem__(self, key, value): - self._get_robj()[key] = value - - def __delitem__(self, key): - del self._get_robj()[key] - - def __getattr__(self, name: str) -> Any: - return getattr(self._get_robj(), name) - - def __setattr__(self, name, value): - setattr(self._get_robj(), name, value) - - def __delattr__(self, name): - delattr(self._get_robj(), name) - - def __bool__(self): - try: - return bool(self._get_robj()) - except RuntimeError: - return False - - def __eq__(self, obj) -> bool: - return self._get_robj() == obj - - def __ne__(self, obj) -> bool: - return self._get_robj() != obj - - def __call__(self, *args, **kwargs): - return self._get_robj()(*args, **kwargs) - - def __iter__(self): - return iter(self._get_robj()) - - def __contains__(self, element): - return element in self._get_robj() - - def __dir__(self): - try: - return dir(self._get_robj()) - except RuntimeError: - return [] - - @property - def __dict__(self): - try: - return self._get_robj().__dict__ - except RuntimeError: - raise AttributeError('__dict__') - - def __str__(self): - return str(self._get_robj()) - - def __repr__(self): - try: - obj = self._get_robj() - except RuntimeError: - return '<%s unbound>' % self.__class__.__name__ - return repr(obj) - - -class ObjectProxy(ProxyMixin, Generic[T]): - __slots__ = ('__obj', '__name__') - - def __init__(self, obj: Any, name: str): - object.__setattr__(self, '_ObjectProxy__obj', obj) - object.__setattr__(self, '__name__', name) - - def _get_robj(self) -> T: - return getattr(self.__obj, self.__name__) - - -class ContextVarProxy(ProxyMixin, Generic[T]): - __slots__ = ('__obj', '__name__') - - def __init__(self, obj: Any, name: str): - object.__setattr__(self, '_ContextVarProxy__obj', obj) - object.__setattr__(self, '__name__', name) - - def _get_robj(self) -> T: - return getattr(self.__obj.get(), self.__name__) - - -class ImmutableListMixin: - _hash_cache = None - - def __hash__(self) -> Optional[int]: # type: ignore - if self._hash_cache is not None: - return self._hash_cache - rv = self._hash_cache = hash(tuple(self)) # type: ignore - return rv - - def __reduce_ex__(self, protocol): - return type(self), (list(self),) - - def __delitem__(self, key): - _is_immutable(self) - - def __iadd__(self, other): - _is_immutable(self) - - def __imul__(self, other): - _is_immutable(self) - - def __setitem__(self, key, value): - _is_immutable(self) - - def append(self, item): - _is_immutable(self) - - def remove(self, itme): - _is_immutable(self) - - def extend(self, iterable): - _is_immutable(self) - - def insert(self, pos, value): - _is_immutable(self) - - def pop(self, index=-1): - _is_immutable(self) - - def reverse(self): - _is_immutable(self) - - def sort(self, cmp=None, key=None, reverse=None): - _is_immutable(self) - - -class ImmutableList(ImmutableListMixin, list): # type: ignore - def __repr__(self): - return '%s(%s)' % ( - self.__class__.__name__, list.__repr__(self) - ) - - -class LoopFileCtxWrapper: - __slots__ = ('_coro', '_obj') - - def __init__(self, coro): - self._coro = coro - self._obj = None - - def __await__(self): - return self._coro.__await__() - - async def __aenter__(self): - self._obj = await self._coro - return self._obj - - async def __aexit__(self, exc_type, exc, tb): - await self._obj.close() - self._obj = None - - -class LoopFileWrapper: - __slots__ = ('_file', '_loop') - - def __init__(self, f, loop=None): - self._file = f - self._loop = loop or asyncio.get_running_loop() - - async def read(self, *args, **kwargs): - return await self._loop.run_in_executor( - None, partial(self._file.read, *args, **kwargs)) - - async def write(self, *args, **kwargs): - return await self._loop.run_in_executor( - None, partial(self._file.write, *args, **kwargs)) - - async def close(self, *args, **kwargs): - return await self._loop.run_in_executor( - None, partial(self._file.close, *args, **kwargs)) - - def __getattr__(self, name): - return getattr(self._file, name) - - -def _is_immutable(self): - raise TypeError('%r objects are immutable' % self.__class__.__name__) - - -async def _loop_open_file(loop, *args, **kwargs): - f = await loop.run_in_executor(None, partial(open, *args, **kwargs)) - return LoopFileWrapper(f, loop) - - -def loop_open_file(*args, **kwargs): - return LoopFileCtxWrapper( - _loop_open_file(asyncio.get_running_loop(), *args, **kwargs) - ) - - -async def loop_copyfileobj(fsrc, fdst, length=None): - return await asyncio.get_running_loop().run_in_executor( - None, partial(copyfileobj, fsrc, fdst, length) - ) - - -#: application loaders -def get_app_module( - module_name: str, - raise_on_failure: bool = True -) -> Optional[ModuleType]: - try: - __import__(module_name) - except ImportError: - if sys.exc_info()[-1].tb_next: - raise RuntimeError( - f"While importing '{module_name}', an ImportError was raised:" - f"\n\n{traceback.format_exc()}" - ) - elif raise_on_failure: - raise RuntimeError(f"Could not import '{module_name}'.") - else: - return - return sys.modules[module_name] - - -def find_best_app(module: ModuleType) -> Any: - #: Given a module instance this tries to find the best possible - # application in the module. - from .app import App # noqa - - # Search for the most common names first. - for attr_name in ('app', 'application'): - app = getattr(module, attr_name, None) - if isinstance(app, App): - return app - - # Otherwise find the only object that is an App instance. - matches = [v for k, v in module.__dict__.items() if isinstance(v, App)] - - if len(matches) == 1: - return matches[0] - raise RuntimeError( - f"Failed to find Emmett application in module '{module.__name__}'." - ) - - -def locate_app(module_name: str, app_name: str, raise_on_failure: bool = True) -> Any: - module = get_app_module(module_name, raise_on_failure=raise_on_failure) - if app_name: - return getattr(module, app_name, None) - return find_best_app(module) - - -#: deprecation helpers -class RemovedInNextVersionWarning(DeprecationWarning): - pass - - -class deprecated(object): - def __init__(self, old_method_name, new_method_name, class_name=None, s=0): - self.class_name = class_name - self.old_method_name = old_method_name - self.new_method_name = new_method_name - self.additional_stack = s - - def __call__(self, f): - def wrapped(*args, **kwargs): - warn_of_deprecation( - self.old_method_name, self.new_method_name, self.class_name, - 3 + self.additional_stack) - return f(*args, **kwargs) - return wrapped - - -warnings.simplefilter('always', RemovedInNextVersionWarning) - - -def warn_of_deprecation(old_name, new_name, prefix=None, stack=2): - msg = "%(old)s is deprecated, use %(new)s instead." - if prefix: - msg = "%(prefix)s." + msg - warnings.warn( - msg % {'old': old_name, 'new': new_name, 'prefix': prefix}, - RemovedInNextVersionWarning, stack) - - -#: app init helpers -def get_root_path(import_name): - """Returns the path of the package or cwd if that cannot be found.""" - # Module already imported and has a file attribute. Use that first. - mod = sys.modules.get(import_name) - if mod is not None and hasattr(mod, '__file__'): - return os.path.dirname(os.path.abspath(mod.__file__)) - - # Next attempt: check the loader. - loader = pkgutil.get_loader(import_name) - - # Loader does not exist or we're referring to an unloaded main module - # or a main module without path (interactive sessions), go with the - # current working directory. - if loader is None or import_name == '__main__': - return os.getcwd() - - # For .egg, zipimporter does not have get_filename until Python 2.7. - # Some other loaders might exhibit the same behavior. - if hasattr(loader, 'get_filename'): - filepath = loader.get_filename(import_name) - else: - # Fall back to imports. - __import__(import_name) - mod = sys.modules[import_name] - filepath = getattr(mod, '__file__', None) - - # If we don't have a filepath it might be because we are a - # namespace package. In this case we pick the root path from the - # first module that is contained in our package. - if filepath is None: - raise RuntimeError('No root path can be found for the provided ' - 'module "%s". This can happen because the ' - 'module came from an import hook that does ' - 'not provide file name information or because ' - 'it\'s a namespace package. In this case ' - 'the root path needs to be explicitly ' - 'provided.' % import_name) - - # filepath is import_name.py for a module, or __init__.py for a package. - return os.path.dirname(os.path.abspath(filepath)) - - -def create_missing_app_folders(app): - try: - for subfolder in ['languages', 'logs', 'static']: - path = os.path.join(app.root_path, subfolder) - if not os.path.exists(path): - os.mkdir(path) - except Exception: - pass - #: monkey patches def _pendulum_to_datetime(obj): return datetime.datetime( - obj.year, obj.month, obj.day, - obj.hour, obj.minute, obj.second, obj.microsecond, - tzinfo=obj.tzinfo + obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond, tzinfo=obj.tzinfo ) def _pendulum_to_naive_datetime(obj): - obj = obj.in_timezone('UTC') - return datetime.datetime( - obj.year, obj.month, obj.day, - obj.hour, obj.minute, obj.second, obj.microsecond - ) + obj = obj.in_timezone("UTC") + return datetime.datetime(obj.year, obj.month, obj.day, obj.hour, obj.minute, obj.second, obj.microsecond) def _pendulum_json(obj): diff --git a/emmett/_reloader.py b/emmett/_reloader.py index 58e1a167..4dcf06a0 100644 --- a/emmett/_reloader.py +++ b/emmett/_reloader.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett._reloader - ---------------- +emmett._reloader +---------------- - Provides auto-reloading utilities. +Provides auto-reloading utilities. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Adapted from werkzeug code (http://werkzeug.pocoo.org) - :copyright: (c) 2015 by Armin Ronacher. +Adapted from werkzeug code (http://werkzeug.pocoo.org) +:copyright: (c) 2015 by Armin Ronacher. - :license: BSD-3-Clause +:license: BSD-3-Clause """ import multiprocessing @@ -19,14 +19,12 @@ import subprocess import sys import time - from itertools import chain from typing import Optional import click - -from ._internal import locate_app -from .server import run as _server_run +from emmett_core._internal import locate_app +from emmett_core.server import run as _server_run def _iter_module_files(): @@ -37,7 +35,7 @@ def _iter_module_files(): for module in list(sys.modules.values()): if module is None: continue - filename = getattr(module, '__file__', None) + filename = getattr(module, "__file__", None) if filename: old = None while not os.path.isfile(filename): @@ -58,15 +56,9 @@ def _get_args_for_reloading(): """ rv = [sys.executable] py_script = sys.argv[0] - if ( - os.name == 'nt' and not os.path.exists(py_script) and - os.path.exists(py_script + '.exe') - ): - py_script += '.exe' - if ( - os.path.splitext(rv[0])[1] == '.exe' and - os.path.splitext(py_script)[1] == '.exe' - ): + if os.name == "nt" and not os.path.exists(py_script) and os.path.exists(py_script + ".exe"): + py_script += ".exe" + if os.path.splitext(rv[0])[1] == ".exe" and os.path.splitext(py_script)[1] == ".exe": rv.pop(0) rv.append(py_script) rv.extend(sys.argv[1:]) @@ -82,8 +74,7 @@ class ReloaderLoop(object): _sleep = staticmethod(time.sleep) def __init__(self, extra_files=None, interval=1): - self.extra_files = set( - os.path.abspath(x) for x in extra_files or ()) + self.extra_files = {os.path.abspath(x) for x in extra_files or ()} self.interval = interval def run(self): @@ -94,10 +85,10 @@ def restart_with_reloader(self): but running the reloader thread. """ while 1: - click.secho('> Restarting (%s mode)' % self.name, fg='yellow') + click.secho("> Restarting (%s mode)" % self.name, fg="yellow") args = _get_args_for_reloading() new_environ = os.environ.copy() - new_environ['EMMETT_RUN_MAIN'] = 'true' + new_environ["EMMETT_RUN_MAIN"] = "true" # a weird bug on windows. sometimes unicode strings end up in the # environment and subprocess.call does not like this, encode them @@ -107,20 +98,20 @@ def restart_with_reloader(self): # if isinstance(value, unicode): # new_environ[key] = value.encode('iso-8859-1') - exit_code = subprocess.call(args, env=new_environ) + exit_code = subprocess.call(args, env=new_environ) # noqa: S603 if exit_code != 3: return exit_code def trigger_reload(self, process, filename): filename = os.path.abspath(filename) - click.secho('> Detected change in %r, reloading' % filename, fg='cyan') + click.secho("> Detected change in %r, reloading" % filename, fg="cyan") os.kill(process.pid, signal.SIGTERM) process.join() sys.exit(3) class StatReloaderLoop(ReloaderLoop): - name = 'stat' + name = "stat" def run(self, process): mtimes = {} @@ -141,18 +132,18 @@ def run(self, process): reloader_loops = { - 'stat': StatReloaderLoop, + "stat": StatReloaderLoop, } -reloader_loops['auto'] = reloader_loops['stat'] +reloader_loops["auto"] = reloader_loops["stat"] def run_with_reloader( interface, app_target, - host='127.0.0.1', + host="127.0.0.1", port=8000, - loop='auto', + loop="auto", log_level=None, log_access=False, threads=1, @@ -161,15 +152,17 @@ def run_with_reloader( ssl_keyfile: Optional[str] = None, extra_files=None, interval=1, - reloader_type='auto' + reloader_type="auto", ): reloader = reloader_loops[reloader_type](extra_files, interval) signal.signal(signal.SIGTERM, lambda *args: sys.exit(0)) try: - if os.environ.get('EMMETT_RUN_MAIN') == 'true': + if os.environ.get("EMMETT_RUN_MAIN") == "true": + from .app import App + # FIXME: find a better way to have app files in stat checker - locate_app(*app_target) + locate_app(App, *app_target) process = multiprocessing.Process( target=_server_run, @@ -184,7 +177,7 @@ def run_with_reloader( "threading_mode": threading_mode, "ssl_certfile": ssl_certfile, "ssl_keyfile": ssl_keyfile, - } + }, ) process.start() reloader.run(process) diff --git a/emmett/_shortcuts.py b/emmett/_shortcuts.py index e791e8a7..570e78df 100644 --- a/emmett/_shortcuts.py +++ b/emmett/_shortcuts.py @@ -1,31 +1,34 @@ # -*- coding: utf-8 -*- """ - emmett._shortcuts - ----------------- +emmett._shortcuts +----------------- - Some shortcuts +Some shortcuts - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import hashlib +from uuid import uuid4 -hashlib_md5 = lambda s: hashlib.md5(bytes(s, 'utf8')) -hashlib_sha1 = lambda s: hashlib.sha1(bytes(s, 'utf8')) +hashlib_md5 = lambda s: hashlib.md5(bytes(s, "utf8")) +hashlib_sha1 = lambda s: hashlib.sha1(bytes(s, "utf8")) +uuid = lambda: str(uuid4()) -def to_bytes(obj, charset='utf8', errors='strict'): + +def to_bytes(obj, charset="utf8", errors="strict"): if obj is None: return None if isinstance(obj, (bytes, bytearray, memoryview)): return bytes(obj) if isinstance(obj, str): return obj.encode(charset, errors) - raise TypeError('Expected bytes') + raise TypeError("Expected bytes") -def to_unicode(obj, charset='utf8', errors='strict'): +def to_unicode(obj, charset="utf8", errors="strict"): if obj is None: return None if not isinstance(obj, bytes): diff --git a/emmett/app.py b/emmett/app.py index 86c4b65b..0253e5f4 100644 --- a/emmett/app.py +++ b/emmett/app.py @@ -1,82 +1,51 @@ # -*- coding: utf-8 -*- """ - emmett.app - ---------- +emmett.app +---------- - Provides the central application object. +Provides the central application object. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations import os -import sys - -from logging import Logger -from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Type, Union +from typing import Any, Dict, List, Optional, Type, Union import click - +from emmett_core._internal import create_missing_app_folders, get_root_path +from emmett_core.app import App as _App, AppModule as _AppModule, AppModuleGroup as _AppModuleGroup, Config as _Config +from emmett_core.routing.cache import RouteCacheRule from yaml import SafeLoader as ymlLoader, load as ymlload -from ._internal import get_root_path, create_missing_app_folders, warn_of_deprecation -from .asgi import handlers as asgi_handlers -from .cache import RouteCacheRule +from .asgi.handlers import HTTPHandler as ASGIHTTPHandler, WSHandler as ASGIWSHandler from .ctx import current -from .datastructures import sdict, ConfigData -from .extensions import Extension, ExtensionType, Signals +from .extensions import Signals from .helpers import load_component from .html import asis from .language.helpers import Tstr from .language.translator import Translator -from .pipeline import Pipe, Injector -from .routing.router import HTTPRouter, WebsocketRouter, RoutingCtx, RoutingCtxGroup +from .pipeline import Injector, Pipe +from .routing.router import HTTPRouter, RoutingCtx, RoutingCtxGroup, WebsocketRouter from .routing.urls import url -from .rsgi import handlers as rsgi_handlers +from .rsgi.handlers import HTTPHandler as RSGIHTTPHandler, WSHandler as RSGIWSHandler from .templating.templater import Templater from .testing import EmmettTestClient -from .typing import ErrorHandlerType -from .utils import dict_to_sdict, cachedprop, read_file +from .utils import dict_to_sdict, read_file -class Config(ConfigData): +class Config(_Config): __slots__ = () def __init__(self, app: App): - self._app = app - super().__init__( - modules_class=AppModule, - hostname_default=None, - static_version=None, - static_version_urls=False, - url_default_namespace=None, - request_max_content_length=None, - request_body_timeout=None, - response_timeout=None - ) - self._handle_static = True + super().__init__(app) self._templates_auto_reload = app.debug or False - self._templates_encoding = 'utf8' - self._templates_escape = 'common' + self._templates_encoding = "utf8" + self._templates_escape = "common" self._templates_indent = False - def __setattr__(self, key, value): - obj = getattr(self.__class__, key, None) - if isinstance(obj, property): - return obj.fset(self, value) - return super().__setattr__(key, value) - - @property - def handle_static(self) -> bool: - return self._handle_static - - @handle_static.setter - def handle_static(self, value: bool): - self._handle_static = value - self._app._configure_asgi_handlers() - @property def templates_auto_reload(self) -> bool: return self._templates_auto_reload @@ -114,371 +83,7 @@ def templates_adjust_indent(self, value: bool): self._app.templater._set_indent(value) -class App: - __slots__ = [ - '__dict__', - '_asgi_handlers', - '_extensions_env', - '_extensions_listeners', - '_language_default', - '_language_force_on_url', - '_languages_set', - '_languages', - '_logger', - '_modules', - '_pipeline', - '_router_http', - '_router_ws', - 'cli', - 'config_path', - 'config', - 'error_handlers', - 'ext', - 'import_name', - 'logger_name', - 'root_path', - 'static_path', - 'template_default_extension', - 'template_path', - 'templater', - 'translator' - ] - - debug = None - test_client_class = None - - def __init__( - self, - import_name: str, - root_path: Optional[str] = None, - url_prefix: Optional[str] = None, - template_folder: str = 'templates', - config_folder: str = 'config' - ): - self.import_name = import_name - #: init debug var - self.debug = os.environ.get('EMMETT_RUN_ENV') == "true" - #: set paths for the application - if root_path is None: - root_path = get_root_path(self.import_name) - self.root_path = root_path - self.static_path = os.path.join(self.root_path, "static") - self.template_path = os.path.join(self.root_path, template_folder) - self.config_path = os.path.join(self.root_path, config_folder) - #: the click command line context for this application - self.cli = click.Group(self.import_name) - #: init the configuration - self.config = Config(self) - #: try to create needed folders - create_missing_app_folders(self) - #: init languages - self._languages: List[str] = [] - self._languages_set: Set[str] = set() - self._language_default: Optional[str] = None - self._language_force_on_url = False - self.translator = Translator( - os.path.join(self.root_path, 'languages'), - default_language=self.language_default or 'en', - watch_changes=self.debug, - str_class=Tstr - ) - #: init routing - self._pipeline: List[Pipe] = [] - self._router_http = HTTPRouter(self, url_prefix=url_prefix) - self._router_ws = WebsocketRouter(self, url_prefix=url_prefix) - self._asgi_handlers = { - 'http': asgi_handlers.HTTPHandler(self), - 'lifespan': asgi_handlers.LifeSpanHandler(self), - 'websocket': asgi_handlers.WSHandler(self) - } - self._rsgi_handlers = { - 'http': rsgi_handlers.HTTPHandler(self), - 'ws': rsgi_handlers.WSHandler(self) - } - self.error_handlers: Dict[int, Callable[[], Awaitable[str]]] = {} - self.template_default_extension = '.html' - #: init logger - self._logger = None - self.logger_name = self.import_name - #: init extensions - self.ext: sdict[str, Extension] = sdict() - self._extensions_env: sdict[str, Any] = sdict() - self._extensions_listeners: Dict[str, List[Callable[..., Any]]] = { - element.value: [] for element in Signals - } - #: init templater - self.templater: Templater = Templater( - path=self.template_path, - encoding=self.config.templates_encoding, - escape=self.config.templates_escape, - adjust_indent=self.config.templates_adjust_indent, - reload=self.config.templates_auto_reload - ) - #: finalise - self._modules: Dict[str, AppModule] = {} - current.app = self - - def _configure_asgi_handlers(self): - self._asgi_handlers['http']._configure_methods() - self._rsgi_handlers['http']._configure_methods() - - @cachedprop - def name(self): - if self.import_name == '__main__': - fn = getattr(sys.modules['__main__'], '__file__', None) - if fn is None: - rv = '__main__' - else: - rv = os.path.splitext(os.path.basename(fn))[0] - else: - rv = self.import_name - return rv - - @property - def languages(self) -> List[str]: - return self._languages - - @languages.setter - def languages(self, value: List[str]): - self._languages = value - self._languages_set = set(self._languages) - - @property - def language_default(self) -> Optional[str]: - return self._language_default - - @language_default.setter - def language_default(self, value: str): - self._language_default = value - self.translator._update_config(self._language_default or 'en') - - @property - def language_force_on_url(self) -> bool: - return self._language_force_on_url - - @language_force_on_url.setter - def language_force_on_url(self, value: bool): - self._language_force_on_url = value - self._router_http._set_language_handling() - self._router_ws._set_language_handling() - self._configure_asgi_handlers() - - @property - def pipeline(self) -> List[Pipe]: - return self._pipeline - - @pipeline.setter - def pipeline(self, pipes: List[Pipe]): - self._pipeline = pipes - self._router_http.pipeline = self._pipeline - self._router_ws.pipeline = self._pipeline - - @property - def injectors(self) -> List[Injector]: - return self._router_http.injectors - - @injectors.setter - def injectors(self, injectors: List[Injector]): - self._router_http.injectors = injectors - - def route( - self, - paths: Optional[Union[str, List[str]]] = None, - name: Optional[str] = None, - template: Optional[str] = None, - pipeline: Optional[List[Pipe]] = None, - injectors: Optional[List[Injector]] = None, - schemes: Optional[Union[str, List[str]]] = None, - hostname: Optional[str] = None, - methods: Optional[Union[str, List[str]]] = None, - prefix: Optional[str] = None, - template_folder: Optional[str] = None, - template_path: Optional[str] = None, - cache: Optional[RouteCacheRule] = None, - output: str = 'auto' - ) -> RoutingCtx: - if callable(paths): - raise SyntaxError('Use @route(), not @route.') - return self._router_http( - paths=paths, - name=name, - template=template, - pipeline=pipeline, - injectors=injectors, - schemes=schemes, - hostname=hostname, - methods=methods, - prefix=prefix, - template_folder=template_folder, - template_path=template_path, - cache=cache, - output=output - ) - - def websocket( - self, - paths: Optional[Union[str, List[str]]] = None, - name: Optional[str] = None, - pipeline: Optional[List[Pipe]] = None, - schemes: Optional[Union[str, List[str]]] = None, - hostname: Optional[str] = None, - prefix: Optional[str] = None - ) -> RoutingCtx: - if callable(paths): - raise SyntaxError('Use @websocket(), not @websocket.') - return self._router_ws( - paths=paths, - name=name, - pipeline=pipeline, - schemes=schemes, - hostname=hostname, - prefix=prefix - ) - - def on_error(self, code: int) -> Callable[[ErrorHandlerType], ErrorHandlerType]: - def decorator(f: ErrorHandlerType) -> ErrorHandlerType: - self.error_handlers[code] = f - return f - return decorator - - @property - def command(self): - return self.cli.command - - @property - def command_group(self): - return self.cli.group - - @property - def log(self) -> Logger: - if self._logger and self._logger.name == self.logger_name: - return self._logger - from .logger import _logger_lock, create_logger - with _logger_lock: - if self._logger and self._logger.name == self.logger_name: - return self._logger - self._logger = rv = create_logger(self) - return rv - - def render_template(self, filename: str) -> str: - ctx = { - 'current': current, 'url': url, 'asis': asis, - 'load_component': load_component - } - return self.templater.render(filename, ctx) - - def config_from_yaml(self, filename: str, namespace: Optional[str] = None): - #: import configuration from yaml files - rc = read_file(os.path.join(self.config_path, filename)) - rc = ymlload(rc, Loader=ymlLoader) - c = self.config if namespace is None else self.config[namespace] - for key, val in rc.items(): - c[key] = dict_to_sdict(val) - - #: Register modules - def _register_module(self, mod: AppModule): - self._modules[mod.name] = mod - - #: Creates the extensions' environments and configs - def __init_extension(self, ext): - if ext.namespace is None: - ext.namespace = ext.__name__ - if self._extensions_env[ext.namespace] is None: - self._extensions_env[ext.namespace] = sdict() - return self._extensions_env[ext.namespace], self.config[ext.namespace] - - #: Register extension listeners - def __register_extension_listeners(self, ext): - for signal, listener in ext._listeners_: - self._extensions_listeners[signal].append(listener) - - #: Add an extension to application - def use_extension(self, ext_cls: Type[ExtensionType]) -> ExtensionType: - if not issubclass(ext_cls, Extension): - raise RuntimeError( - f'{ext_cls.__name__} is an invalid Emmett extension' - ) - ext_env, ext_config = self.__init_extension(ext_cls) - ext = self.ext[ext_cls.__name__] = ext_cls(self, ext_env, ext_config) - self.__register_extension_listeners(ext) - ext.on_load() - return ext - - #: Add a template extension to application - def use_template_extension(self, ext_cls, **config): - return self.templater.use_extension(ext_cls, **config) - - def send_signal(self, signal: Union[str, Signals], *args, **kwargs): - if not isinstance(signal, Signals): - warn_of_deprecation( - "App.send_signal str argument", - "extensions.Signals as argument", - stack=3 - ) - try: - signal = Signals[signal] - except KeyError: - raise SyntaxError(f"{signal} is not a valid signal") - for listener in self._extensions_listeners[signal]: - listener(*args, **kwargs) - - def make_shell_context(self, context: Dict[str, Any] = {}) -> Dict[str, Any]: - context['app'] = self - return context - - def test_client(self, use_cookies: bool = True, **kwargs) -> EmmettTestClient: - tclass = self.test_client_class or EmmettTestClient - return tclass(self, use_cookies=use_cookies, **kwargs) - - def __call__(self, scope, receive, send): - return self._asgi_handlers[scope['type']](scope, receive, send) - - def __rsgi__(self, scope, protocol): - return self._rsgi_handlers[scope.proto](scope, protocol) - - def __rsgi_init__(self, loop): - self.send_signal(Signals.after_loop, loop=loop) - - def module( - self, - import_name: str, - name: str, - template_folder: Optional[str] = None, - template_path: Optional[str] = None, - static_folder: Optional[str] = None, - static_path: Optional[str] = None, - url_prefix: Optional[str] = None, - hostname: Optional[str] = None, - cache: Optional[RouteCacheRule] = None, - root_path: Optional[str] = None, - pipeline: Optional[List[Pipe]] = None, - injectors: Optional[List[Injector]] = None, - module_class: Optional[Type[AppModule]] = None, - **kwargs: Any - ) -> AppModule: - module_class = module_class or self.config.modules_class - return module_class.from_app( - self, - import_name, - name, - template_folder=template_folder, - template_path=template_path, - static_folder=static_folder, - static_path=static_path, - url_prefix=url_prefix, - hostname=hostname, - cache=cache, - root_path=root_path, - pipeline=pipeline or [], - injectors=injectors or [], - opts=kwargs - ) - - def module_group(self, *modules: AppModule) -> AppModuleGroup: - return AppModuleGroup(*modules) - - -class AppModule: +class AppModule(_AppModule): @classmethod def from_app( cls, @@ -495,7 +100,7 @@ def from_app( root_path: Optional[str], pipeline: List[Pipe], injectors: List[Injector], - opts: Dict[str, Any] = {} + opts: Dict[str, Any] = {}, ): return cls( app, @@ -511,7 +116,7 @@ def from_app( root_path=root_path, pipeline=pipeline, injectors=injectors, - **opts + **opts, ) @classmethod @@ -528,17 +133,14 @@ def from_module( hostname: Optional[str], cache: Optional[RouteCacheRule], root_path: Optional[str], - opts: Dict[str, Any] = {} + opts: Dict[str, Any] = {}, ): - if '.' in name: - raise RuntimeError( - "Nested app modules' names should not contains dots" - ) - name = appmod.name + '.' + name - if url_prefix and not url_prefix.startswith('/'): - url_prefix = '/' + url_prefix - module_url_prefix = (appmod.url_prefix + (url_prefix or '')) \ - if appmod.url_prefix else url_prefix + if "." in name: + raise RuntimeError("Nested app modules' names should not contains dots") + name = appmod.name + "." + name + if url_prefix and not url_prefix.startswith("/"): + url_prefix = "/" + url_prefix + module_url_prefix = (appmod.url_prefix + (url_prefix or "")) if appmod.url_prefix else url_prefix hostname = hostname or appmod.hostname cache = cache or appmod.cache return cls( @@ -555,7 +157,7 @@ def from_module( root_path=root_path, pipeline=appmod.pipeline, injectors=appmod.injectors, - **opts + **opts, ) @classmethod @@ -572,7 +174,7 @@ def from_module_group( hostname: Optional[str], cache: Optional[RouteCacheRule], root_path: Optional[str], - opts: Dict[str, Any] = {} + opts: Dict[str, Any] = {}, ) -> AppModulesGrouped: mods = [] for module in appmodgroup.modules: @@ -588,7 +190,7 @@ def from_module_group( hostname=hostname, cache=cache, root_path=root_path, - opts=opts + opts=opts, ) mods.append(mod) return AppModulesGrouped(*mods) @@ -606,7 +208,7 @@ def module( cache: Optional[RouteCacheRule] = None, root_path: Optional[str] = None, module_class: Optional[Type[AppModule]] = None, - **kwargs: Any + **kwargs: Any, ) -> AppModule: module_class = module_class or self.__class__ return module_class.from_module( @@ -621,7 +223,7 @@ def module( hostname=hostname, cache=cache, root_path=root_path, - opts=kwargs + opts=kwargs, ) def __init__( @@ -639,44 +241,29 @@ def __init__( root_path: Optional[str] = None, pipeline: Optional[List[Pipe]] = None, injectors: Optional[List[Injector]] = None, - **kwargs: Any + **kwargs: Any, ): - self.app = app - self.name = name - self.import_name = import_name - if root_path is None: - root_path = get_root_path(self.import_name) - self.root_path = root_path + super().__init__( + app=app, + name=name, + import_name=import_name, + static_folder=static_folder, + static_path=static_path, + url_prefix=url_prefix, + hostname=hostname, + cache=cache, + root_path=root_path, + pipeline=pipeline, + **kwargs, + ) #: - `template_folder` is referred to application `template_path` # - `template_path` is referred to module root_directory unless absolute self.template_folder = template_folder if template_path and not template_path.startswith("/"): template_path = os.path.join(self.root_path, template_path) self.template_path = template_path - #: - `static_folder` is referred to application `static_path` - # - `static_path` is referred to module root_directory unless absolute - if static_path and not static_path.startswith("/"): - static_path = os.path.join(self.root_path, static_path) - self._static_path = ( - os.path.join(self.app.static_path, static_folder) if static_folder else - (static_path or self.app.static_path) - ) - self.url_prefix = url_prefix - self.hostname = hostname - self.cache = cache - self._super_pipeline = pipeline or [] self._super_injectors = injectors or [] - self.pipeline = [] self.injectors = [] - self.app._register_module(self) - - @property - def pipeline(self) -> List[Pipe]: - return self._pipeline - - @pipeline.setter - def pipeline(self, pipeline: List[Pipe]): - self._pipeline = self._super_pipeline + pipeline @property def injectors(self) -> List[Injector]: @@ -691,22 +278,20 @@ def route( paths: Optional[Union[str, List[str]]] = None, name: Optional[str] = None, template: Optional[str] = None, - **kwargs + **kwargs, ) -> RoutingCtx: if name is not None and "." in name: - raise RuntimeError( - "App modules' route names should not contains dots" - ) + raise RuntimeError("App modules' route names should not contains dots") name = self.name + "." + (name or "") - pipeline = kwargs.get('pipeline', []) - injectors = kwargs.get('injectors', []) + pipeline = kwargs.get("pipeline", []) + injectors = kwargs.get("injectors", []) if self.pipeline: pipeline = self.pipeline + pipeline - kwargs['pipeline'] = pipeline + kwargs["pipeline"] = pipeline if self.injectors: injectors = self.injectors + injectors - kwargs['injectors'] = injectors - kwargs['cache'] = kwargs.get('cache', self.cache) + kwargs["injectors"] = injectors + kwargs["cache"] = kwargs.get("cache", self.cache) return self.app.route( paths=paths, name=name, @@ -715,37 +300,190 @@ def route( template_folder=self.template_folder, template_path=self.template_path, hostname=self.hostname, - **kwargs + **kwargs, ) - def websocket( + +class App(_App): + __slots__ = ["cli", "template_default_extension", "template_path", "templater", "translator"] + + config_class = Config + modules_class = AppModule + signals_class = Signals + test_client_class = EmmettTestClient + + def __init__( + self, + import_name: str, + root_path: Optional[str] = None, + url_prefix: Optional[str] = None, + template_folder: str = "templates", + config_folder: str = "config", + ): + super().__init__( + import_name=import_name, + root_path=root_path, + url_prefix=url_prefix, + config_folder=config_folder, + template_folder=template_folder, + ) + self.cli = click.Group(self.import_name) + self.translator = Translator( + os.path.join(self.root_path, "languages"), + default_language=self.language_default or "en", + watch_changes=self.debug, + str_class=Tstr, + ) + self.template_default_extension = ".html" + self.templater: Templater = Templater( + path=self.template_path, + encoding=self.config.templates_encoding, + escape=self.config.templates_escape, + adjust_indent=self.config.templates_adjust_indent, + reload=self.config.templates_auto_reload, + ) + + def _configure_paths(self, root_path, opts): + if root_path is None: + root_path = get_root_path(self.import_name) + self.root_path = root_path + self.static_path = os.path.join(self.root_path, "static") + self.template_path = os.path.join(self.root_path, opts["template_folder"]) + self.config_path = os.path.join(self.root_path, opts["config_folder"]) + create_missing_app_folders(self, ["languages", "logs", "static"]) + + def _init_routers(self, url_prefix): + self._router_http = HTTPRouter(self, current, url_prefix=url_prefix) + self._router_ws = WebsocketRouter(self, current, url_prefix=url_prefix) + + def _init_handlers(self): + self._asgi_handlers["http"] = ASGIHTTPHandler(self, current) + self._asgi_handlers["ws"] = ASGIWSHandler(self, current) + self._rsgi_handlers["http"] = RSGIHTTPHandler(self, current) + self._rsgi_handlers["ws"] = RSGIWSHandler(self, current) + + def _register_with_ctx(self): + current.app = self + + @property + def language_default(self) -> Optional[str]: + return self._language_default + + @language_default.setter + def language_default(self, value: str): + self._language_default = value + self.translator._update_config(self._language_default or "en") + + @property + def injectors(self) -> List[Injector]: + return self._router_http.injectors + + @injectors.setter + def injectors(self, injectors: List[Injector]): + self._router_http.injectors = injectors + + def route( self, paths: Optional[Union[str, List[str]]] = None, name: Optional[str] = None, - **kwargs + template: Optional[str] = None, + pipeline: Optional[List[Pipe]] = None, + injectors: Optional[List[Injector]] = None, + schemes: Optional[Union[str, List[str]]] = None, + hostname: Optional[str] = None, + methods: Optional[Union[str, List[str]]] = None, + prefix: Optional[str] = None, + template_folder: Optional[str] = None, + template_path: Optional[str] = None, + cache: Optional[RouteCacheRule] = None, + output: str = "auto", ) -> RoutingCtx: - if name is not None and "." in name: - raise RuntimeError( - "App modules' websocket names should not contains dots" - ) - name = self.name + "." + (name or "") - pipeline = kwargs.get('pipeline', []) - if self.pipeline: - pipeline = self.pipeline + pipeline - kwargs['pipeline'] = pipeline - return self.app.websocket( + if callable(paths): + raise SyntaxError("Use @route(), not @route.") + return self._router_http( paths=paths, name=name, - prefix=self.url_prefix, - hostname=self.hostname, - **kwargs + template=template, + pipeline=pipeline, + injectors=injectors, + schemes=schemes, + hostname=hostname, + methods=methods, + prefix=prefix, + template_folder=template_folder, + template_path=template_path, + cache=cache, + output=output, ) + @property + def command(self): + return self.cli.command -class AppModuleGroup: - def __init__(self, *modules: AppModule): - self.modules = modules + @property + def command_group(self): + return self.cli.group + def make_shell_context(self, context): + context["app"] = self + return context + + def render_template(self, filename: str) -> str: + ctx = {"current": current, "url": url, "asis": asis, "load_component": load_component} + return self.templater.render(filename, ctx) + + def config_from_yaml(self, filename: str, namespace: Optional[str] = None): + #: import configuration from yaml files + rc = read_file(os.path.join(self.config_path, filename)) + rc = ymlload(rc, Loader=ymlLoader) + c = self.config if namespace is None else self.config[namespace] + for key, val in rc.items(): + c[key] = dict_to_sdict(val) + + #: Add a template extension to application + def use_template_extension(self, ext_cls, **config): + return self.templater.use_extension(ext_cls, **config) + + def module( + self, + import_name: str, + name: str, + template_folder: Optional[str] = None, + template_path: Optional[str] = None, + static_folder: Optional[str] = None, + static_path: Optional[str] = None, + url_prefix: Optional[str] = None, + hostname: Optional[str] = None, + cache: Optional[RouteCacheRule] = None, + root_path: Optional[str] = None, + pipeline: Optional[List[Pipe]] = None, + injectors: Optional[List[Injector]] = None, + module_class: Optional[Type[AppModule]] = None, + **kwargs: Any, + ) -> AppModule: + module_class = module_class or self.modules_class + return module_class.from_app( + self, + import_name, + name, + template_folder=template_folder, + template_path=template_path, + static_folder=static_folder, + static_path=static_path, + url_prefix=url_prefix, + hostname=hostname, + cache=cache, + root_path=root_path, + pipeline=pipeline or [], + injectors=injectors or [], + opts=kwargs, + ) + + def module_group(self, *modules: AppModule) -> AppModuleGroup: + return AppModuleGroup(*modules) + + +class AppModuleGroup(_AppModuleGroup): def module( self, import_name: str, @@ -759,7 +497,7 @@ def module( cache: Optional[RouteCacheRule] = None, root_path: Optional[str] = None, module_class: Optional[Type[AppModule]] = None, - **kwargs: Any + **kwargs: Any, ) -> AppModulesGrouped: module_class = module_class or AppModule return module_class.from_module_group( @@ -774,7 +512,7 @@ def module( hostname=hostname, cache=cache, root_path=root_path, - opts=kwargs + opts=kwargs, ) def route( @@ -782,23 +520,9 @@ def route( paths: Optional[Union[str, List[str]]] = None, name: Optional[str] = None, template: Optional[str] = None, - **kwargs + **kwargs, ) -> RoutingCtxGroup: - return RoutingCtxGroup([ - mod.route(paths=paths, name=name, template=template, **kwargs) - for mod in self.modules - ]) - - def websocket( - self, - paths: Optional[Union[str, List[str]]] = None, - name: Optional[str] = None, - **kwargs - ): - return RoutingCtxGroup([ - mod.websocket(paths=paths, name=name, **kwargs) - for mod in self.modules - ]) + return RoutingCtxGroup([mod.route(paths=paths, name=name, template=template, **kwargs) for mod in self.modules]) class AppModulesGrouped(AppModuleGroup): diff --git a/emmett/asgi/handlers.py b/emmett/asgi/handlers.py index 2c93455a..0be20684 100644 --- a/emmett/asgi/handlers.py +++ b/emmett/asgi/handlers.py @@ -1,498 +1,76 @@ # -*- coding: utf-8 -*- """ - emmett.asgi.handlers - -------------------- +emmett.asgi.handlers +-------------------- - Provides ASGI handlers. +Provides ASGI handlers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -import asyncio -import os -import re -import time - -from collections import OrderedDict -from email.utils import formatdate from hashlib import md5 from importlib import resources -from typing import Any, Awaitable, Callable, Optional, Tuple, Union +from typing import Awaitable, Callable + +from emmett_core.http.response import HTTPBytesResponse, HTTPResponse +from emmett_core.protocols.asgi.handlers import HTTPHandler as _HTTPHandler, WSHandler as _WSHandler +from emmett_core.protocols.asgi.typing import Receive, Scope, Send +from emmett_core.utils import cachedprop -from ..ctx import RequestContext, WSContext, current -from ..debug import smart_traceback, debug_handler -from ..extensions import Signals -from ..http import HTTPBytes, HTTPResponse, HTTPFile, HTTP +from ..ctx import current +from ..debug import debug_handler, smart_traceback from ..libs.contenttype import contenttype -from ..utils import cachedprop from ..wrappers.response import Response -from .helpers import RequestCancelled -from .typing import Event, EventHandler, EventLooper, Receive, Scope, Send from .wrappers import Request, Websocket -REGEX_STATIC = re.compile( - r'^/static/(?P__[\w\-\.]+__/)?(?P_\d+\.\d+\.\d+/)?(?P.*?)$' -) -REGEX_STATIC_LANG = re.compile( - r'^/(?P\w{2}/)?static/(?P__[\w\-\.]__+/)?(?P_\d+\.\d+\.\d+/)?(?P.*?)$' -) - - -class EventHandlerWrapper: - __slots__ = ['event', 'f'] - - def __init__(self, event: str, f: EventHandler): - self.event = event - self.f = f - - async def __call__( - self, - handler: Handler, - scope: Scope, - receive: Receive, - send: Send, - event: Event - ) -> Tuple[Optional[EventHandler], None]: - task = await self.f(handler, scope, receive, send, event) - return task, None - - -class MetaHandler(type): - def __new__(cls, name, bases, attrs): - new_class = type.__new__(cls, name, bases, attrs) - declared_events = OrderedDict() - all_events = OrderedDict() - events = [] - for key, value in list(attrs.items()): - if isinstance(value, EventHandlerWrapper): - events.append((key, value)) - declared_events.update(events) - new_class._declared_events_ = declared_events - for base in reversed(new_class.__mro__[1:]): - if hasattr(base, '_declared_events_'): - all_events.update(base._declared_events_) - all_events.update(declared_events) - new_class._all_events_ = all_events - new_class._events_handlers_ = { - el.event: el for el in new_class._all_events_.values()} - return new_class - - -class Handler(metaclass=MetaHandler): - __slots__ = ['app'] - - def __init__(self, app): - self.app = app - - @classmethod - def on_event( - cls, event: str - ) -> Callable[[EventHandler], EventHandlerWrapper]: - def wrap(f: EventHandler) -> EventHandlerWrapper: - return EventHandlerWrapper(event, f) - return wrap - - def get_event_handler( - self, event_type: str - ) -> Union[EventHandler, EventHandlerWrapper]: - return self._events_handlers_.get(event_type, _event_missing) - - def __call__( - self, - scope: Scope, - receive: Receive, - send: Send - ) -> Awaitable[None]: - return self.handle_events(scope, receive, send) - - async def handle_events( - self, - scope: Scope, - receive: Receive, - send: Send - ): - task: Optional[EventLooper] = _event_looper - event = None - while task: - task, event = await task(self, scope, receive, send, event) - -class LifeSpanHandler(Handler): +class HTTPHandler(_HTTPHandler): __slots__ = [] - - @Handler.on_event('lifespan.startup') - async def event_startup( - self, - scope: Scope, - receive: Receive, - send: Send, - event: Event - ) -> EventLooper: - self.app.send_signal(Signals.after_loop, loop=asyncio.get_event_loop()) - await send({'type': 'lifespan.startup.complete'}) - return _event_looper - - @Handler.on_event('lifespan.shutdown') - async def event_shutdown( - self, - scope: Scope, - receive: Receive, - send: Send, - event: Event - ): - await send({'type': 'lifespan.shutdown.complete'}) - - -class RequestHandler(Handler): - __slots__ = ['router'] - - def __init__(self, app): - super().__init__(app) - self._bind_router() - self._configure_methods() - - def _bind_router(self): - raise NotImplementedError - - def _configure_methods(self): - raise NotImplementedError - - -class HTTPHandler(RequestHandler): - __slots__ = ['pre_handler', 'static_handler', 'static_matcher', '__dict__'] - - def _bind_router(self): - self.router = self.app._router_http - self._internal_assets_md = ( - str(int(time.time())), - formatdate(time.time(), usegmt=True) - ) - - def _configure_methods(self): - self.static_matcher = ( - self._static_lang_matcher if self.app.language_force_on_url else - self._static_nolang_matcher) - self.static_handler = ( - self._static_handler if self.app.config.handle_static else - self.dynamic_handler) - self.pre_handler = ( - self._prefix_handler if self.router._prefix_main else - self.static_handler) - - async def __call__( - self, - scope: Scope, - receive: Receive, - send: Send - ): - scope['emt.path'] = scope['path'] or '/' - try: - http = await self.pre_handler(scope, receive, send) - await asyncio.wait_for( - http.asgi(scope, send), - self.app.config.response_timeout - ) - except RequestCancelled: - return - except asyncio.TimeoutError: - self.app.log.warn( - f"Timeout sending response: ({scope['emt.path']})" - ) + wrapper_cls = Request + response_cls = Response @cachedprop def error_handler(self) -> Callable[[], Awaitable[str]]: - return ( - self._debug_handler if self.app.debug else self.exception_handler - ) + return self._debug_handler if self.app.debug else self.exception_handler - @cachedprop - def exception_handler(self) -> Callable[[], Awaitable[str]]: - return self.app.error_handlers.get(500, self._exception_handler) - - @staticmethod - async def _http_response(code: int) -> HTTPResponse: - return HTTP(code) - - def _prefix_handler( - self, - scope: Scope, - receive: Receive, - send: Send - ) -> Awaitable[HTTPResponse]: - path = scope['emt.path'] - if not path.startswith(self.router._prefix_main): - return self._http_response(404) - scope['emt.path'] = path[self.router._prefix_main_len:] or '/' - return self.static_handler(scope, receive, send) - - def _static_lang_matcher( - self, path: str - ) -> Tuple[Optional[str], Optional[str]]: - match = REGEX_STATIC_LANG.match(path) - if match: - lang, mname, version, file_name = match.group('l', 'm', 'v', 'f') - if mname: - mod = self.app._modules.get(mname) - spath = mod._static_path if mod else self.app.static_path - else: - spath = self.app.static_path - static_file = os.path.join(spath, file_name) - if lang: - lang_file = os.path.join(spath, lang, file_name) - if os.path.exists(lang_file): - static_file = lang_file - return static_file, version - return None, None - - def _static_nolang_matcher( - self, path: str - ) -> Tuple[Optional[str], Optional[str]]: - if path.startswith('/static/'): - mname, version, file_name = REGEX_STATIC.match(path).group('m', 'v', 'f') - if mname: - mod = self.app._modules.get(mname[2:-3]) - static_file = os.path.join(mod._static_path, file_name) if mod else None - elif file_name: - static_file = os.path.join(self.app.static_path, file_name) - else: - static_file = None - return static_file, version - return None, None - - async def _static_response(self, file_path: str) -> HTTPFile: - return HTTPFile(file_path) - - async def _static_content(self, content: bytes, content_type: str) -> HTTPBytes: + async def _static_content(self, content: bytes, content_type: str) -> HTTPBytesResponse: content_len = str(len(content)) - return HTTPBytes( + return HTTPBytesResponse( 200, content, headers={ - 'content-type': content_type, - 'content-length': content_len, - 'last-modified': self._internal_assets_md[1], - 'etag': md5( - f"{self._internal_assets_md[0]}_{content_len}".encode("utf8") - ).hexdigest() - } + "content-type": content_type, + "content-length": content_len, + "last-modified": self._internal_assets_md[1], + "etag": md5(f"{self._internal_assets_md[0]}_{content_len}".encode("utf8")).hexdigest(), + }, ) - def _static_handler( - self, - scope: Scope, - receive: Receive, - send: Send - ) -> Awaitable[HTTPResponse]: - path = scope['emt.path'] + def _static_handler(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[HTTPResponse]: + path = scope["emt.path"] #: handle internal assets - if path.startswith('/__emmett__'): + if path.startswith("/__emmett__"): file_name = path[12:] if not file_name or file_name.endswith(".html"): return self._http_response(404) pkg = None - if '/' in file_name: - pkg, file_name = file_name.split('/', 1) + if "/" in file_name: + pkg, file_name = file_name.split("/", 1) try: - file_contents = resources.read_binary( - f'emmett.assets.{pkg}' if pkg else 'emmett.assets', - file_name - ) + file_contents = resources.read_binary(f"emmett.assets.{pkg}" if pkg else "emmett.assets", file_name) except FileNotFoundError: return self._http_response(404) return self._static_content(file_contents, contenttype(file_name)) - #: handle app assets - static_file, _ = self.static_matcher(path) - if static_file: - return self._static_response(static_file) - return self.dynamic_handler(scope, receive, send) - - async def dynamic_handler( - self, - scope: Scope, - receive: Receive, - send: Send - ) -> HTTPResponse: - request = Request( - scope, - receive, - send, - max_content_length=self.app.config.request_max_content_length, - body_timeout=self.app.config.request_body_timeout - ) - response = Response() - ctx = RequestContext(self.app, request, response) - ctx_token = current._init_(ctx) - try: - http = await self.router.dispatch(request, response) - except HTTPResponse as http_exception: - http = http_exception - #: render error with handlers if in app - error_handler = self.app.error_handlers.get(http.status_code) - if error_handler: - http = HTTP( - http.status_code, - await error_handler(), - headers=response.headers, - cookies=response.cookies - ) - except RequestCancelled: - raise - except Exception: - self.app.log.exception('Application exception:') - http = HTTP( - 500, - await self.error_handler(), - headers=response.headers - ) - finally: - current._close_(ctx_token) - return http + return super()._static_handler(scope, receive, send) async def _debug_handler(self) -> str: - current.response.headers._data['content-type'] = ( - 'text/html; charset=utf-8' - ) + current.response.headers._data["content-type"] = "text/html; charset=utf-8" return debug_handler(smart_traceback(self.app)) - async def _exception_handler(self) -> str: - current.response.headers._data['content-type'] = 'text/plain' - return 'Internal error' - - -class WSHandler(RequestHandler): - __slots__ = ['pre_handler'] - def _bind_router(self): - self.router = self.app._router_ws - - def _configure_methods(self): - self.pre_handler = ( - self._prefix_handler if self.router._prefix_main else - self.dynamic_handler) - - async def __call__( - self, - scope: Scope, - receive: Receive, - send: Send - ): - scope['emt.input'] = asyncio.Queue() - task_events = asyncio.create_task(self.handle_events(scope, receive, send)) - task_request = asyncio.create_task(self.handle_request(scope, send)) - _, pending = await asyncio.wait( - [task_request, task_events], return_when=asyncio.FIRST_COMPLETED - ) - scope['emt._flow_cancel'] = True - _cancel_tasks(pending) - - @Handler.on_event('websocket.connect') - async def event_connect( - self, - scope: Scope, - receive: Receive, - send: Send, - event: Event - ) -> EventLooper: - return _event_looper - - @Handler.on_event('websocket.disconnect') - async def event_disconnect( - self, - scope: Scope, - receive: Receive, - send: Send, - event: Event - ): - return - - @Handler.on_event('websocket.receive') - async def event_receive( - self, - scope: Scope, - receive: Receive, - send: Send, - event: Event - ) -> EventLooper: - await scope['emt.input'].put(event.get('bytes') or event['text']) - return _event_looper - - async def handle_request( - self, - scope: Scope, - send: Send - ): - scope['emt.path'] = scope['path'] or '/' - scope['emt._ws_closed'] = False - try: - await self.pre_handler(scope, send) - except HTTPResponse: - if not scope['emt._ws_closed']: - await send({'type': 'websocket.close', 'code': 1006}) - except asyncio.CancelledError: - if not scope.get('emt._flow_cancel', False): - self.app.log.exception('Application exception:') - except Exception: - if not scope['emt._ws_closed']: - await send({'type': 'websocket.close', 'code': 1006}) - self.app.log.exception('Application exception:') - - def _prefix_handler( - self, - scope: Scope, - send: Send - ) -> Awaitable[None]: - path = scope['emt.path'] - if not path.startswith(self.router._prefix_main): - raise HTTP(404) - scope['emt.path'] = path[self.router._prefix_main_len:] or '/' - return self.dynamic_handler(scope, send) - - async def dynamic_handler( - self, - scope: Scope, - send: Send - ): - ctx = WSContext( - self.app, - Websocket(scope, scope['emt.input'].get, send) - ) - ctx_token = current._init_(ctx) - try: - await self.router.dispatch(ctx.websocket) - finally: - if ( - not scope.get('emt._flow_cancel', False) and - ctx.websocket._accepted - ): - await send({'type': 'websocket.close', 'code': 1000}) - scope['emt._ws_closed'] = True - current._close_(ctx_token) - - -async def _event_looper( - handler: Handler, - scope: Scope, - receive: Receive, - send: Send, - event: Any = None -) -> Tuple[Union[EventHandler, EventHandlerWrapper], Event]: - event = await receive() - return handler.get_event_handler(event['type']), event - - -async def _event_missing( - handler: Handler, - scope: Scope, - receive: Receive, - send: Send, - event: Event -): - raise RuntimeError(f"Event type '{event['type']}' not recognized") - - -def _cancel_tasks(tasks): - for task in tasks: - task.cancel() +class WSHandler(_WSHandler): + __slots__ = [] + wrapper_cls = Websocket diff --git a/emmett/asgi/helpers.py b/emmett/asgi/helpers.py deleted file mode 100644 index 1a41792a..00000000 --- a/emmett/asgi/helpers.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.asgi.helpers - ------------------- - - Provides ASGI helpers - - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause -""" - - -class RequestCancelled(Exception): - ... diff --git a/emmett/asgi/typing.py b/emmett/asgi/typing.py deleted file mode 100644 index 901c9cef..00000000 --- a/emmett/asgi/typing.py +++ /dev/null @@ -1,19 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.wrappers.typing - ---------------------- - - Provides typing helpers. - - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause -""" - -from typing import Any, Awaitable, Callable, Dict, Tuple - -Scope = Dict[str, Any] -Receive = Callable[[], Awaitable[Dict[str, Any]]] -Send = Callable[[Dict[str, Any]], Awaitable[None]] -Event = Dict[str, Any] -EventHandler = Callable[[Any, Scope, Receive, Send, Event], Awaitable[Any]] -EventLooper = Callable[..., Awaitable[Tuple[EventHandler, Event]]] diff --git a/emmett/asgi/workers.py b/emmett/asgi/workers.py deleted file mode 100644 index fc335cbc..00000000 --- a/emmett/asgi/workers.py +++ /dev/null @@ -1,106 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.asgi.workers - ------------------- - - Provides ASGI gunicorn workers. - - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause -""" - -import asyncio -import logging -import signal -import sys - -from gunicorn.arbiter import Arbiter -from gunicorn.workers.base import Worker as _Worker -from uvicorn.config import Config -from uvicorn.server import Server - - -class Worker(_Worker): - EMMETT_CONFIG = {} - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - logger = logging.getLogger("uvicorn.error") - logger.handlers = self.log.error_log.handlers - logger.setLevel(self.log.error_log.level) - logger.propagate = False - - logger = logging.getLogger("uvicorn.access") - logger.handlers = self.log.access_log.handlers - logger.setLevel(self.log.access_log.level) - logger.propagate = False - - config = { - "app": None, - "log_config": None, - "timeout_keep_alive": self.cfg.keepalive, - "timeout_notify": self.timeout, - "callback_notify": self.callback_notify, - "limit_max_requests": self.max_requests, - "forwarded_allow_ips": self.cfg.forwarded_allow_ips - } - - if self.cfg.is_ssl: - config.update( - ssl_keyfile=self.cfg.ssl_options.get("keyfile"), - ssl_certfile=self.cfg.ssl_options.get("certfile"), - ssl_keyfile_password=self.cfg.ssl_options.get("password"), - ssl_version=self.cfg.ssl_options.get("ssl_version"), - ssl_cert_reqs=self.cfg.ssl_options.get("cert_reqs"), - ssl_ca_certs=self.cfg.ssl_options.get("ca_certs"), - ssl_ciphers=self.cfg.ssl_options.get("ciphers") - ) - - if self.cfg.settings["backlog"].value: - config["backlog"] = self.cfg.settings["backlog"].value - - config.update(self.EMMETT_CONFIG) - - self.config = Config(**config) - - def init_process(self): - self.config.setup_event_loop() - super().init_process() - - def init_signals(self) -> None: - for s in self.SIGNALS: - signal.signal(s, signal.SIG_DFL) - signal.signal(signal.SIGUSR1, self.handle_usr1) - signal.siginterrupt(signal.SIGUSR1, False) - - async def _serve(self) -> None: - self.config.app = self.wsgi - server = Server(config=self.config) - await server.serve(sockets=self.sockets) - if not server.started: - sys.exit(Arbiter.WORKER_BOOT_ERROR) - - def run(self) -> None: - return asyncio.run(self._serve()) - - async def callback_notify(self) -> None: - self.notify() - - -class EmmettWorker(Worker): - EMMETT_CONFIG = { - "loop": "uvloop", - "http": "httptools", - "proxy_headers": False, - "interface": "asgi3" - } - - -class EmmettH11Worker(EmmettWorker): - EMMETT_CONFIG = { - "loop": "auto", - "http": "h11", - "proxy_headers": False, - "interface": "asgi3" - } diff --git a/emmett/asgi/wrappers.py b/emmett/asgi/wrappers.py index 90251d87..cf12204d 100644 --- a/emmett/asgi/wrappers.py +++ b/emmett/asgi/wrappers.py @@ -1,270 +1,26 @@ # -*- coding: utf-8 -*- """ - emmett.asgi.wrappers - -------------------- +emmett.asgi.wrappers +-------------------- - Provides ASGI request and websocket wrappers +Provides ASGI request and websocket wrappers - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -import asyncio +import pendulum +from emmett_core.protocols.asgi.wrappers import Request as _Request, Websocket as Websocket +from emmett_core.utils import cachedprop -from datetime import datetime -from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union -from urllib.parse import parse_qs -from ..datastructures import sdict -from ..http import HTTP -from ..utils import cachedprop -from ..wrappers.helpers import regex_client -from ..wrappers.request import Request as _Request -from ..wrappers.websocket import Websocket as _Websocket -from .helpers import RequestCancelled -from .typing import Scope, Receive, Send - -_push_headers = { - "accept", - "accept-encoding", - "accept-language", - "cache-control", - "user-agent" -} - - -class Headers(Mapping[str, str]): - __slots__ = ["_data"] - - def __init__(self, scope: Dict[str, Any]): - self._data: Dict[bytes, bytes] = { - key: val for key, val in scope["headers"] - } - - __hash__ = None # type: ignore - - def __getitem__(self, key: str) -> str: - return self._data[key.lower().encode("latin-1")].decode("latin-1") - - def __contains__(self, key: str) -> bool: # type: ignore - return key.lower().encode("latin-1") in self._data - - def __iter__(self) -> Iterator[str]: - for key in self._data.keys(): - yield key.decode("latin-1") - - def __len__(self) -> int: - return len(self._data) - - def get( - self, - key: str, - default: Optional[Any] = None, - cast: Optional[Callable[[Any], Any]] = None - ) -> Any: - rv = self._data.get(key.lower().encode("latin-1")) - rv = rv.decode() if rv is not None else default # type: ignore - if cast is None: - return rv - try: - return cast(rv) - except ValueError: - return default - - def items(self) -> Iterator[Tuple[str, str]]: # type: ignore - for key, value in self._data.items(): - yield key.decode("latin-1"), value.decode("latin-1") - - def keys(self) -> Iterator[str]: # type: ignore - for key in self._data.keys(): - yield key.decode("latin-1") - - def values(self) -> Iterator[str]: # type: ignore - for value in self._data.values(): - yield value.decode("latin-1") - - -class Body: - __slots__ = ('_data', '_receive', '_max_content_length') - - def __init__(self, receive, max_content_length=None): - self._data = bytearray() - self._receive = receive - self._max_content_length = max_content_length - - def append(self, data: bytes): - if data == b'': - return - self._data.extend(data) - if ( - self._max_content_length is not None and - len(self._data) > self._max_content_length - ): - raise HTTP(413, 'Request entity too large') - - async def __load(self) -> bytes: - while True: - event = await self._receive() - if event['type'] == 'http.request': - self.append(event['body']) - if not event.get('more_body', False): - break - elif event['type'] == 'http.disconnect': - raise RequestCancelled - return bytes(self._data) - - def __await__(self): - return self.__load().__await__() - - -class ASGIIngressMixin: - def __init__( - self, - scope: Scope, - receive: Receive, - send: Send - ): - self._scope = scope - self._receive = receive - self._send = send - self.scheme = scope['scheme'] - self.path = scope['emt.path'] - - @cachedprop - def headers(self) -> Headers: - return Headers(self._scope) - - @cachedprop - def query_params(self) -> sdict[str, Union[str, List[str]]]: - rv: sdict[str, Any] = sdict() - for key, values in parse_qs( - self._scope['query_string'].decode('latin-1'), keep_blank_values=True - ).items(): - if len(values) == 1: - rv[key] = values[0] - continue - rv[key] = values - return rv - - @cachedprop - def client(self) -> str: - g = regex_client.search(self.headers.get('x-forwarded-for', '')) - client = ( - (g.group() or '').split(',')[0] if g else ( - self._scope['client'][0] if self._scope['client'] else None - ) - ) - if client in (None, '', 'unknown', 'localhost'): - client = '::1' if self.host.startswith('[') else '127.0.0.1' - return client # type: ignore - - -class Request(ASGIIngressMixin, _Request): - __slots__ = ['_scope', '_receive', '_send'] - - def __init__( - self, - scope: Scope, - receive: Receive, - send: Send, - max_content_length: Optional[int] = None, - body_timeout: Optional[int] = None - ): - super().__init__(scope, receive, send) - self.max_content_length = max_content_length - self.body_timeout = body_timeout - self._now = datetime.utcnow() - self.method = scope['method'] +class Request(_Request): + __slots__ = [] @cachedprop - def _input(self): - return Body(self._receive, self.max_content_length) + def now(self) -> pendulum.DateTime: + return pendulum.instance(self._now) @cachedprop - async def body(self) -> bytes: - if ( - self.max_content_length and - self.content_length > self.max_content_length - ): - raise HTTP(413, 'Request entity too large') - try: - rv = await asyncio.wait_for(self._input, timeout=self.body_timeout) - except asyncio.TimeoutError: - raise HTTP(408, 'Request timeout') - return rv - - async def push_promise(self, path: str): - if "http.response.push" not in self._scope.get("extensions", {}): - return - await self._send({ - "type": "http.response.push", - "path": path, - "headers": [ - (key.encode("latin-1"), self.headers[key].encode("latin-1")) - for key in _push_headers & set(self.headers.keys()) - ] - }) - - -class Websocket(ASGIIngressMixin, _Websocket): - __slots__ = ['_scope', '_receive', '_send', '_accepted'] - - def __init__( - self, - scope: Scope, - receive: Receive, - send: Send - ): - super().__init__(scope, receive, send) - self._accepted = False - self._flow_receive = None - self._flow_send = None - self.receive = self._accept_and_receive - self.send = self._accept_and_send - - @property - def _asgi_spec_version(self) -> int: - return int(''.join( - self._scope.get('asgi', {}).get('spec_version', '2.0').split('.') - )) - - def _encode_headers( - self, - headers: Dict[str, str] - ) -> List[Tuple[bytes, bytes]]: - return [ - (key.encode('utf-8'), val.encode('utf-8')) - for key, val in headers.items() - ] - - async def accept( - self, - headers: Optional[Dict[str, str]] = None, - subprotocol: Optional[str] = None - ): - if self._accepted: - return - message: Dict[str, Any] = { - 'type': 'websocket.accept', - 'subprotocol': subprotocol - } - if headers and self._asgi_spec_version > 20: - message['headers'] = self._encode_headers(headers) - await self._send(message) - self._accepted = True - self.receive = self._wrapped_receive - self.send = self._wrapped_send - - async def _wrapped_receive(self) -> Any: - data = await self._receive() - for method in self._flow_receive: # type: ignore - data = method(data) - return data - - async def _wrapped_send(self, data: Any): - for method in self._flow_send: # type: ignore - data = method(data) - if isinstance(data, str): - await self._send({'type': 'websocket.send', 'text': data}) - else: - await self._send({'type': 'websocket.send', 'bytes': data}) + def now_local(self) -> pendulum.DateTime: + return self.now.in_timezone(pendulum.local_timezone()) # type: ignore diff --git a/emmett/cache.py b/emmett/cache.py index 026b4f01..e9bbebe0 100644 --- a/emmett/cache.py +++ b/emmett/cache.py @@ -1,329 +1,37 @@ # -*- coding: utf-8 -*- """ - emmett.cache - ------------ +emmett.cache +------------ - Provides a caching system. +Provides a caching system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -import asyncio -import heapq import os import pickle import tempfile import threading import time +from typing import Any, List, Optional -from collections import OrderedDict -from functools import wraps -from typing import ( - Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union, overload -) +from emmett_core.cache import Cache as Cache +from emmett_core.cache.handlers import CacheHandler, RamCache as RamCache, RedisCache as RedisCache from ._shortcuts import hashlib_sha1 from .ctx import current from .libs.portalocker import LockedFile -from .typing import T - -__all__ = ['Cache'] - - -class CacheHashMixin: - def __init__(self): - self.strategies = OrderedDict() - - def add_strategy( - self, - key: str, - method: Callable[..., Any] = lambda data: data - ): - self.strategies[key] = method - - def _hash_component(self, key: str, data: Any) -> str: - return ''.join([key, "{", repr(data), "}"]) - - def _build_hash(self, data: Dict[str, Any]) -> str: - components = [] - for key, strategy in self.strategies.items(): - components.append(self._hash_component(key, strategy(data[key]))) - return hashlib_sha1(':'.join(components)).hexdigest() - - def _build_ctx_key(self, **ctx) -> str: - return self.key + ":" + self._build_hash(ctx) # type: ignore # noqa - - @staticmethod - def dict_strategy(data: Dict[str, Any]) -> List[Tuple[str, Any]]: - return [(key, data[key]) for key in sorted(data)] - - -class CacheHandler: - def __init__(self, prefix: str = '', default_expire: int = 300): - self._default_expire = default_expire - self._prefix = prefix - - @staticmethod - def _key_prefix_(method: Callable[..., Any]) -> Callable[..., Any]: - @wraps(method) - def wrap(self, key: Optional[str] = None, *args, **kwargs) -> Any: - key = self._prefix + key if key is not None else key - return method(self, key, *args, **kwargs) - return wrap - - @staticmethod - def _convert_duration_(method: Callable[..., Any]) -> Callable[..., Any]: - @wraps(method) - def wrap( - self, - key: str, - value: Any, - duration: Union[int, str, None] = 'default' - ) -> Any: - if duration is None: - duration = 60 * 60 * 24 * 365 - if duration == "default": - duration = self._default_expire - now = time.time() - return method( - self, key, value, - now=now, - duration=duration, - expiration=now + duration # type: ignore - ) - return wrap - - @overload - def __call__( - self, - key: Optional[str] = None, - function: None = None, - duration: Union[int, str, None] = 'default' - ) -> CacheDecorator: - ... - - @overload - def __call__( - self, - key: str, - function: Optional[Callable[..., T]], - duration: Union[int, str, None] = 'default' - ) -> T: - ... - - def __call__( - self, - key: Optional[str] = None, - function: Optional[Callable[..., T]] = None, - duration: Union[int, str, None] = 'default' - ) -> Union[CacheDecorator, T]: - if function: - if asyncio.iscoroutinefunction(function): - return self.get_or_set_loop(key, function, duration) # type: ignore - return self.get_or_set(key, function, duration) # type: ignore - return CacheDecorator(self, key, duration) - - def get_or_set( - self, - key: str, - function: Callable[[], T], - duration: Union[int, str, None] = 'default' - ) -> T: - value = self.get(key) - if value is None: - value = function() - self.set(key, value, duration) - return value - - async def get_or_set_loop( - self, - key: str, - function: Callable[[], T], - duration: Union[int, str, None] = 'default' - ) -> T: - value = self.get(key) - if value is None: - value = await function() # type: ignore - self.set(key, value, duration) - return value - - def get(self, key: str) -> Any: - return None - - def set(self, key: str, value: Any, duration: Union[int, str, None]): - pass - - def clear(self, key: Optional[str] = None): - pass - - def response( - self, - duration: Union[int, str, None] = 'default', - query_params: bool = True, - language: bool = True, - hostname: bool = False, - headers: List[str] = [] - ) -> RouteCacheRule: - return RouteCacheRule( - self, query_params, language, hostname, headers, duration - ) - - -class CacheDecorator(CacheHashMixin): - def __init__( - self, - handler: CacheHandler, - key: Optional[str], - duration: Union[int, str, None] = 'default' - ): - super().__init__() - self._cache = handler - self.key = key - self.duration = duration - self.add_strategy('args') - self.add_strategy('kwargs', self.dict_strategy) - - def _key_from_wrapped(self, f: Callable[..., Any]) -> str: - return f.__module__ + '.' + f.__name__ - - def _wrap_sync(self, f: Callable[..., Any]) -> Callable[..., Any]: - @wraps(f) - def wrap(*args, **kwargs) -> Any: - if not args and not kwargs: - key = self.key or self._key_from_wrapped(f) - else: - key = self._build_ctx_key(args=args, kwargs=kwargs) - return self._cache.get_or_set( - key, lambda: f(*args, **kwargs), self.duration - ) - return wrap - - def _wrap_loop( - self, - f: Callable[..., Awaitable[Any]] - ) -> Callable[..., Awaitable[Any]]: - @wraps(f) - async def wrap(*args, **kwargs) -> Any: - if not args and not kwargs: - key = self.key or self._key_from_wrapped(f) - else: - key = self._build_ctx_key(args=args, kwargs=kwargs) - return await self._cache.get_or_set_loop( - key, lambda: f(*args, **kwargs), self.duration - ) - return wrap - - def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: - rv = ( - self._wrap_loop(f) if asyncio.iscoroutinefunction(f) else - self._wrap_sync(f) - ) - if not self.key: - self.key = f.__module__ + '.' + f.__name__ - return rv - - -class RamElement: - __slots__ = ('value', 'exp', 'acc') - - def __init__(self, value: Any, exp: int, acc: int): - self.value = value - self.exp = exp - self.acc = acc - - -class RamCache(CacheHandler): - lock = threading.RLock() - - def __init__( - self, - prefix: str = '', - threshold: int = 500, - default_expire: int = 300 - ): - super().__init__( - prefix=prefix, default_expire=default_expire) - self.data: Dict[str, Any] = {} - self._heap_exp: List[Tuple[int, str]] = [] - self._heap_acc: List[Tuple[float, str]] = [] - self._threshold = threshold - - def _prune(self, now): - # remove expired items - while self._heap_exp: - exp, rk = self._heap_exp[0] - if exp >= now: - break - self._heap_exp.remove((exp, rk)) - element = self.data.get(rk) - if element and element.exp == exp: - self._heap_acc.remove((self.data[rk].acc, rk)) - del self.data[rk] - # remove threshold exceding elements - while len(self.data) > self._threshold: - rk = heapq.heappop(self._heap_acc)[1] - element = self.data.get(rk) - if element: - self._heap_exp.remove((element.exp, rk)) - del self.data[rk] - - @CacheHandler._key_prefix_ - def get(self, key: str) -> Any: - try: - with self.lock: - element = self.data[key] - now = time.time() - if element.exp < now: - return None - self._heap_acc.remove((element.acc, key)) - element.acc = now - heapq.heappush(self._heap_acc, (element.acc, key)) - val = element.value - except KeyError: - return None - return val - - @CacheHandler._key_prefix_ - @CacheHandler._convert_duration_ - def set(self, key: str, value: Any, **kwargs): - with self.lock: - self._prune(kwargs['now']) - heapq.heappush(self._heap_exp, (kwargs['expiration'], key)) - heapq.heappush(self._heap_acc, (kwargs['now'], key)) - self.data[key] = RamElement( - value, kwargs['expiration'], kwargs['now']) - - @CacheHandler._key_prefix_ - def clear(self, key: Optional[str] = None): - with self.lock: - if key is not None: - try: - rv = self.data[key] - self._heap_acc.remove((rv.acc, key)) - self._heap_exp.remove((rv.exp, key)) - del self.data[key] - return - except Exception: - return - self.data.clear() - self._heap_acc = [] - self._heap_exp = [] class DiskCache(CacheHandler): lock = threading.RLock() - _fs_transaction_suffix = '.__mt_cache' + _fs_transaction_suffix = ".__mt_cache" _fs_mode = 0o600 - def __init__( - self, - cache_dir: str = 'cache', - threshold: int = 500, - default_expire: int = 300 - ): + def __init__(self, cache_dir: str = "cache", threshold: int = 500, default_expire: int = 300): super().__init__(default_expire=default_expire) self._threshold = threshold self._path = os.path.join(current.app.root_path, cache_dir) @@ -356,7 +64,7 @@ def _prune(self): try: for i, fpath in enumerate(entries): remove = False - f = LockedFile(fpath, 'rb') + f = LockedFile(fpath, "rb") exp = pickle.load(f.file) f.close() remove = exp <= now or i % 3 == 0 @@ -370,7 +78,7 @@ def get(self, key: str) -> Any: try: with self.lock: now = time.time() - f = LockedFile(filename, 'rb') + f = LockedFile(filename, "rb") exp = pickle.load(f.file) if exp < now: f.close() @@ -390,10 +98,9 @@ def set(self, key: str, value: Any, **kwargs): if os.path.exists(filepath): self._del_file(filepath) try: - fd, tmp = tempfile.mkstemp( - suffix=self._fs_transaction_suffix, dir=self._path) - with os.fdopen(fd, 'wb') as f: - pickle.dump(kwargs['expiration'], f, 1) + fd, tmp = tempfile.mkstemp(suffix=self._fs_transaction_suffix, dir=self._path) + with os.fdopen(fd, "wb") as f: + pickle.dump(kwargs["expiration"], f, 1) pickle.dump(value, f, pickle.HIGHEST_PROTOCOL) os.rename(tmp, filename) os.chmod(filename, self._fs_mode) @@ -411,209 +118,3 @@ def clear(self, key: Optional[str] = None): return for name in self._list_dir(): self._del_file(name) - - -class RedisCache(CacheHandler): - def __init__( - self, - host: str = 'localhost', - port: int = 6379, - password: Optional[str] = None, - db: int = 0, - prefix: str = 'cache:', - default_expire: int = 300, - **kwargs - ): - super().__init__( - prefix=prefix, default_expire=default_expire) - try: - import redis - except ImportError: - raise RuntimeError('no redis module found') - self._cache = redis.Redis( - host=host, port=port, password=password, db=db, **kwargs - ) - - def _dump_obj(self, value: Any) -> bytes: - if isinstance(value, int): - return str(value).encode('ascii') - return b'!' + pickle.dumps(value) - - def _load_obj(self, value: Any) -> Any: - if value is None: - return None - if value.startswith(b'!'): - try: - return pickle.loads(value[1:]) - except pickle.PickleError: - return None - try: - return int(value) - except ValueError: - return None - - @CacheHandler._key_prefix_ - def get(self, key: str) -> Any: - return self._load_obj(self._cache.get(key)) - - @CacheHandler._key_prefix_ - @CacheHandler._convert_duration_ - def set(self, key: str, value: Any, **kwargs): - dumped = self._dump_obj(value) - return self._cache.setex( - name=key, - time=kwargs['duration'], - value=dumped - ) - - @CacheHandler._key_prefix_ - def clear(self, key: Optional[str] = None): - if key is not None: - if key.endswith('*'): - keys = self._cache.delete(self._cache.keys(key)) - if keys: - self._cache.delete(*keys) - return - self._cache.delete(key) - return - if self._prefix: - keys = self._cache.keys(self._prefix + '*') - if keys: - self._cache.delete(*keys) - return - self._cache.flushdb() - - -class RouteCacheRule(CacheHashMixin): - def __init__( - self, - handler: CacheHandler, - query_params: bool = True, - language: bool = True, - hostname: bool = False, - headers: List[str] = [], - duration: Union[int, str, None] = 'default' - ): - super().__init__() - self.cache = handler - self.check_headers = headers - self.duration = duration - self.add_strategy('kwargs', self.dict_strategy) - self._ctx_builders = [] - if hostname: - self.add_strategy('hostname') - self._ctx_builders.append( - ('hostname', lambda route, current: route.hostname)) - if language: - self.add_strategy('language') - self._ctx_builders.append( - ('language', lambda route, current: current.language)) - if query_params: - self.add_strategy('query_params', self.dict_strategy) - self._ctx_builders.append( - ('query_params', lambda route, current: - current.request.query_params)) - if headers: - self.add_strategy('headers', self.headers_strategy) - self._ctx_builders.append( - ('headers', lambda route, current: current.request.headers)) - - def _build_ctx_key(self, route: Any, **ctx) -> str: # type: ignore - return route.name + ":" + self._build_hash(ctx) - - def _build_ctx( - self, - kwargs: Dict[str, Any], - route: Any, - current: Any - ) -> Dict[str, Any]: - rv = {'kwargs': kwargs} - for key, builder in self._ctx_builders: - rv[key] = builder(route, current) - return rv - - def headers_strategy(self, data: Dict[str, str]) -> List[str]: - return [data[key] for key in self.check_headers] - - def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: - from .routing.router import Router - obj = Router.exposing() - obj.cache_rule = self - return f - - -class Cache: - def __init__(self, **kwargs): - #: load handlers - handlers = [] - for key, val in kwargs.items(): - if key == "default": - continue - handlers.append((key, val)) - if not handlers: - handlers.append(('ram', RamCache())) - #: set handlers - for name, handler in handlers: - setattr(self, name, handler) - _default_handler_name = kwargs.get('default', handlers[0][0]) - self._default_handler = getattr(self, _default_handler_name) - - @overload - def __call__( - self, - key: Optional[str] = None, - function: None = None, - duration: Union[int, str, None] = 'default' - ) -> CacheDecorator: - ... - - @overload - def __call__( - self, - key: str, - function: Optional[Callable[..., T]], - duration: Union[int, str, None] = 'default' - ) -> T: - ... - - def __call__( - self, - key: Optional[str] = None, - function: Optional[Callable[..., T]] = None, - duration: Union[int, str, None] = 'default' - ) -> Union[CacheDecorator, T]: - return self._default_handler(key, function, duration) - - def get(self, key: str) -> Any: - return self._default_handler.get(key) - - def set( - self, - key: str, - value: Any, - duration: Union[int, str, None] = 'default' - ): - self._default_handler.set(key, value, duration) - - def get_or_set( - self, - key: str, - function: Callable[..., T], - duration: Union[int, str, None] = 'default' - ) -> T: - return self._default_handler.get_or_set(key, function, duration) - - def clear(self, key: Optional[str] = None): - self._default_handler.clear(key) - - def response( - self, - duration: Union[int, str, None] = 'default', - query_params: bool = True, - language: bool = True, - hostname: bool = False, - headers: List[str] = [] - ) -> RouteCacheRule: - return self._default_handler.response( - duration, query_params, language, hostname, headers - ) diff --git a/emmett/cli.py b/emmett/cli.py index 9cd977c2..0a64d7b6 100644 --- a/emmett/cli.py +++ b/emmett/cli.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.cli - --------- +emmett.cli +--------- - Provide command line tools for Emmett applications. +Provide command line tools for Emmett applications. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Flask (http://flask.pocoo.org) - :copyright: (c) 2014 by Armin Ronacher. +Based on the code of Flask (http://flask.pocoo.org) +:copyright: (c) 2014 by Armin Ronacher. - :license: BSD-3-Clause +:license: BSD-3-Clause """ import code @@ -20,11 +20,11 @@ import types import click +from emmett_core._internal import get_app_module, locate_app +from emmett_core.log import LOG_LEVELS +from emmett_core.server import run as sgi_run from .__version__ import __version__ as fw_version -from ._internal import locate_app, get_app_module -from .logger import LOG_LEVELS -from .server import run as sgi_run def find_app_module(): @@ -67,9 +67,7 @@ def find_db(module, var_name=None): from .orm import Database - matches = [ - v for k, v in module.__dict__.items() if isinstance(v, Database) - ] + matches = [v for k, v in module.__dict__.items() if isinstance(v, Database)] return matches @@ -129,8 +127,10 @@ def load_app(self): if self._loaded_app is not None: return self._loaded_app + from .app import App + import_name, app_name = self._get_import_name() - app = locate_app(import_name, app_name) if import_name else None + app = locate_app(App, import_name, app_name) if import_name else None if app is None: raise RuntimeError("Could not locate an Emmett application.") @@ -168,30 +168,19 @@ def set_app_value(ctx, param, value): ctx.ensure_object(ScriptInfo).app_import_path = value -app_option = click.Option( - ['-a', '--app'], - help='The application to run', - callback=set_app_value, - is_eager=True -) +app_option = click.Option(["-a", "--app"], help="The application to run", callback=set_app_value, is_eager=True) class EmmettGroup(click.Group): - def __init__( - self, - add_default_commands=True, - add_app_option=True, - add_debug_option=True, - **extra - ): - params = list(extra.pop('params', None) or ()) + def __init__(self, add_default_commands=True, add_app_option=True, add_debug_option=True, **extra): + params = list(extra.pop("params", None) or ()) if add_app_option: params.append(app_option) - #if add_debug_option: + # if add_debug_option: # params.append(debug_option) click.Group.__init__(self, params=params, **extra) - #self.create_app = create_app + # self.create_app = create_app if add_default_commands: self.add_command(develop_command) @@ -229,57 +218,47 @@ def get_command(self, ctx, name): pass def main(self, *args, **kwargs): - obj = kwargs.get('obj') + obj = kwargs.get("obj") if obj is None: obj = ScriptInfo() - kwargs['obj'] = obj + kwargs["obj"] = obj return super().main(*args, **kwargs) -@click.command('develop', short_help='Runs a development server.') -@click.option( - '--host', '-h', default='127.0.0.1', help='The interface to bind to.') -@click.option( - '--port', '-p', type=int, default=8000, help='The port to bind to.') -@click.option( - '--interface', type=click.Choice(['rsgi', 'asgi']), default='rsgi', - help='Application interface.') -@click.option( - '--loop', type=click.Choice(['auto', 'asyncio', 'uvloop']), default='auto', - help='Event loop implementation.') -@click.option( - '--ssl-certfile', type=str, default=None, help='SSL certificate file') +@click.command("develop", short_help="Runs a development server.") +@click.option("--host", "-h", default="127.0.0.1", help="The interface to bind to.") +@click.option("--port", "-p", type=int, default=8000, help="The port to bind to.") +@click.option("--interface", type=click.Choice(["rsgi", "asgi"]), default="rsgi", help="Application interface.") @click.option( - '--ssl-keyfile', type=str, default=None, help='SSL key file') -@click.option( - '--reloader/--no-reloader', is_flag=True, default=True, - help='Runs with reloader.') + "--loop", type=click.Choice(["auto", "asyncio", "uvloop"]), default="auto", help="Event loop implementation." +) +@click.option("--ssl-certfile", type=str, default=None, help="SSL certificate file") +@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file") +@click.option("--reloader/--no-reloader", is_flag=True, default=True, help="Runs with reloader.") @pass_script_info -def develop_command( - info, host, port, interface, loop, ssl_certfile, ssl_keyfile, reloader -): - os.environ["EMMETT_RUN_ENV"] = 'true' +def develop_command(info, host, port, interface, loop, ssl_certfile, ssl_keyfile, reloader): + os.environ["EMMETT_RUN_ENV"] = "true" app_target = info._get_import_name() - if os.environ.get('EMMETT_RUN_MAIN') != 'true': + if os.environ.get("EMMETT_RUN_MAIN") != "true": click.echo( - ' '.join([ - "> Starting Emmett development server on app", - click.style(app_target[0], fg="cyan", bold=True) - ]) + " ".join(["> Starting Emmett development server on app", click.style(app_target[0], fg="cyan", bold=True)]) ) click.echo( - ' '.join([ - click.style("> Emmett application", fg="green"), - click.style(app_target[0], fg="cyan", bold=True), - click.style("running on", fg="green"), - click.style(f"http://{host}:{port}", fg="cyan"), - click.style("(press CTRL+C to quit)", fg="green") - ]) + " ".join( + [ + click.style("> Emmett application", fg="green"), + click.style(app_target[0], fg="cyan", bold=True), + click.style("running on", fg="green"), + click.style(f"http://{host}:{port}", fg="cyan"), + click.style("(press CTRL+C to quit)", fg="green"), + ] + ) ) if reloader: from ._reloader import run_with_reloader + runner = run_with_reloader else: runner = sgi_run @@ -290,7 +269,7 @@ def develop_command( host=host, port=port, loop=loop, - log_level='debug', + log_level="debug", log_access=True, threading_mode="workers", ssl_certfile=ssl_certfile, @@ -298,54 +277,46 @@ def develop_command( ) -@click.command('serve', short_help='Serve the app.') -@click.option( - '--host', '-h', default='0.0.0.0', help='The interface to bind to.') -@click.option( - '--port', '-p', type=int, default=8000, help='The port to bind to.') -@click.option( - "--workers", '-w', type=int, default=1, - help="Number of worker processes. Defaults to 1.") -@click.option( - "--threads", type=int, default=1, help="Number of worker threads.") -@click.option( - "--threading-mode", type=click.Choice(['runtime', 'workers']), default='workers', - help="Server threading mode.") -@click.option( - '--interface', type=click.Choice(['rsgi', 'asgi']), default='rsgi', - help='Application interface.') -@click.option( - '--http', type=click.Choice(['auto', '1', '2']), default='auto', - help='HTTP version.') -@click.option( - '--ws/--no-ws', is_flag=True, default=True, - help='Enable websockets support.') -@click.option( - '--loop', type=click.Choice(['auto', 'asyncio', 'uvloop']), default='auto', - help='Event loop implementation.') -@click.option( - '--opt/--no-opt', is_flag=True, default=False, - help='Enable loop optimizations.') -@click.option( - '--log-level', type=click.Choice(LOG_LEVELS.keys()), default='info', - help='Logging level.') -@click.option( - '--access-log/--no-access-log', is_flag=True, default=False, - help='Enable access log.') -@click.option( - '--backlog', type=int, default=2048, - help='Maximum number of connections to hold in backlog') -@click.option( - '--backpressure', type=int, - help='Maximum number of requests to process concurrently (per worker)') +@click.command("serve", short_help="Serve the app.") +@click.option("--host", "-h", default="0.0.0.0", help="The interface to bind to.") +@click.option("--port", "-p", type=int, default=8000, help="The port to bind to.") +@click.option("--workers", "-w", type=int, default=1, help="Number of worker processes. Defaults to 1.") +@click.option("--threads", type=int, default=1, help="Number of worker threads.") @click.option( - '--ssl-certfile', type=str, default=None, help='SSL certificate file') + "--threading-mode", type=click.Choice(["runtime", "workers"]), default="workers", help="Server threading mode." +) +@click.option("--interface", type=click.Choice(["rsgi", "asgi"]), default="rsgi", help="Application interface.") +@click.option("--http", type=click.Choice(["auto", "1", "2"]), default="auto", help="HTTP version.") +@click.option("--ws/--no-ws", is_flag=True, default=True, help="Enable websockets support.") @click.option( - '--ssl-keyfile', type=str, default=None, help='SSL key file') + "--loop", type=click.Choice(["auto", "asyncio", "uvloop"]), default="auto", help="Event loop implementation." +) +@click.option("--opt/--no-opt", is_flag=True, default=False, help="Enable loop optimizations.") +@click.option("--log-level", type=click.Choice(LOG_LEVELS.keys()), default="info", help="Logging level.") +@click.option("--access-log/--no-access-log", is_flag=True, default=False, help="Enable access log.") +@click.option("--backlog", type=int, default=2048, help="Maximum number of connections to hold in backlog") +@click.option("--backpressure", type=int, help="Maximum number of requests to process concurrently (per worker)") +@click.option("--ssl-certfile", type=str, default=None, help="SSL certificate file") +@click.option("--ssl-keyfile", type=str, default=None, help="SSL key file") @pass_script_info def serve_command( - info, host, port, workers, threads, threading_mode, interface, http, ws, loop, opt, - log_level, access_log, backlog, backpressure, ssl_certfile, ssl_keyfile + info, + host, + port, + workers, + threads, + threading_mode, + interface, + http, + ws, + loop, + opt, + log_level, + access_log, + backlog, + backpressure, + ssl_certfile, + ssl_keyfile, ): app_target = info._get_import_name() sgi_run( @@ -369,31 +340,22 @@ def serve_command( ) -@click.command('shell', short_help='Runs a shell in the app context.') +@click.command("shell", short_help="Runs a shell in the app context.") @pass_script_info def shell_command(info): - os.environ['EMMETT_CLI_ENV'] = 'true' + os.environ["EMMETT_CLI_ENV"] = "true" ctx = info.load_appctx() app = info.load_app() - banner = 'Python %s on %s\nEmmett %s shell on app: %s' % ( - sys.version, - sys.platform, - fw_version, - app.import_name - ) + banner = "Python %s on %s\nEmmett %s shell on app: %s" % (sys.version, sys.platform, fw_version, app.import_name) code.interact(banner=banner, local=app.make_shell_context(ctx)) -@click.command('routes', short_help='Display the app routing table.') +@click.command("routes", short_help="Display the app routing table.") @pass_script_info def routes_command(info): app = info.load_app() click.echo( - "".join([ - "> Routing table for Emmett application ", - click.style(app.import_name, fg="cyan", bold=True), - ":" - ]) + "".join(["> Routing table for Emmett application ", click.style(app.import_name, fg="cyan", bold=True), ":"]) ) for route in app._router_http._routes_str.values(): click.echo(route) @@ -408,114 +370,100 @@ def set_db_value(ctx, param, value): ctx.ensure_object(ScriptInfo).db_var_name = value -@cli.group('migrations', short_help='Runs migration operations.') -@click.option( - '--db', help='The db instance to use', callback=set_db_value, is_eager=True -) +@cli.group("migrations", short_help="Runs migration operations.") +@click.option("--db", help="The db instance to use", callback=set_db_value, is_eager=True) def migrations_cli(db): pass -@migrations_cli.command('status', short_help='Shows current database revision.') -@click.option('--verbose', '-v', default=False, is_flag=True) +@migrations_cli.command("status", short_help="Shows current database revision.") +@click.option("--verbose", "-v", default=False, is_flag=True) @pass_script_info def migrations_status(info, verbose): from .orm.migrations.commands import status + app = info.load_app() dbs = info.load_db() status(app, dbs, verbose) -@migrations_cli.command('history', short_help="Shows migrations history.") -@click.option('--range', '-r', default=None) -@click.option('--verbose', '-v', default=False, is_flag=True) +@migrations_cli.command("history", short_help="Shows migrations history.") +@click.option("--range", "-r", default=None) +@click.option("--verbose", "-v", default=False, is_flag=True) @pass_script_info def migrations_history(info, range, verbose): from .orm.migrations.commands import history + app = info.load_app() dbs = info.load_db() history(app, dbs, range, verbose) -@migrations_cli.command( - 'generate', short_help='Generates a new migration from application models.' -) -@click.option( - '--message', '-m', default='Generated migration', - help='The description for the new migration.' -) -@click.option('-head', default='head', help='The migration to generate from') +@migrations_cli.command("generate", short_help="Generates a new migration from application models.") +@click.option("--message", "-m", default="Generated migration", help="The description for the new migration.") +@click.option("-head", default="head", help="The migration to generate from") @pass_script_info def migrations_generate(info, message, head): from .orm.migrations.commands import generate + app = info.load_app() dbs = info.load_db() generate(app, dbs, message, head) -@migrations_cli.command('new', short_help='Generates a new empty migration.') -@click.option( - '--message', '-m', default='New migration', - help='The description for the new migration.' -) -@click.option('-head', default='head', help='The migration to generate from') +@migrations_cli.command("new", short_help="Generates a new empty migration.") +@click.option("--message", "-m", default="New migration", help="The description for the new migration.") +@click.option("-head", default="head", help="The migration to generate from") @pass_script_info def migrations_new(info, message, head): from .orm.migrations.commands import new + app = info.load_app() dbs = info.load_db() new(app, dbs, message, head) -@migrations_cli.command( - 'up', short_help='Upgrades the database to the selected migration.' -) -@click.option('--revision', '-r', default='head', help='The migration to upgrade to.') +@migrations_cli.command("up", short_help="Upgrades the database to the selected migration.") +@click.option("--revision", "-r", default="head", help="The migration to upgrade to.") @click.option( - '--dry-run', + "--dry-run", default=False, is_flag=True, - help='Only print SQL instructions, without actually applying the migration.' + help="Only print SQL instructions, without actually applying the migration.", ) @pass_script_info def migrations_up(info, revision, dry_run): from .orm.migrations.commands import up + app = info.load_app() dbs = info.load_db() up(app, dbs, revision, dry_run) -@migrations_cli.command( - 'down', short_help='Downgrades the database to the selected migration.' -) -@click.option('--revision', '-r', required=True, help='The migration to downgrade to.') +@migrations_cli.command("down", short_help="Downgrades the database to the selected migration.") +@click.option("--revision", "-r", required=True, help="The migration to downgrade to.") @click.option( - '--dry-run', + "--dry-run", default=False, is_flag=True, - help='Only print SQL instructions, without actually applying the migration.' + help="Only print SQL instructions, without actually applying the migration.", ) @pass_script_info def migrations_down(info, revision, dry_run): from .orm.migrations.commands import down + app = info.load_app() dbs = info.load_db() down(app, dbs, revision, dry_run) -@migrations_cli.command( - 'set', short_help='Overrides database revision with selected migration.' -) -@click.option('--revision', '-r', default='head', help='The migration to set.') -@click.option( - '--auto-confirm', - default=False, - is_flag=True, - help='Skip asking confirmation.' -) +@migrations_cli.command("set", short_help="Overrides database revision with selected migration.") +@click.option("--revision", "-r", default="head", help="The migration to set.") +@click.option("--auto-confirm", default=False, is_flag=True, help="Skip asking confirmation.") @pass_script_info def migrations_set(info, revision, auto_confirm): from .orm.migrations.commands import set_revision + app = info.load_app() dbs = info.load_db() set_revision(app, dbs, revision, auto_confirm) @@ -525,5 +473,5 @@ def main(as_module=False): cli.main(prog_name="python -m emmett" if as_module else None) -if __name__ == '__main__': +if __name__ == "__main__": main(as_module=True) diff --git a/emmett/ctx.py b/emmett/ctx.py index 73bee75f..45dd62d9 100644 --- a/emmett/ctx.py +++ b/emmett/ctx.py @@ -1,119 +1,62 @@ # -*- coding: utf-8 -*- """ - emmett.ctx - ---------- +emmett.ctx +---------- - Provides the current object. - Used by application to deal with request related objects. +Provides the current object. +Used by application to deal with request related objects. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -import contextvars - from datetime import datetime -from typing import Any import pendulum +from emmett_core.ctx import ( + Context as _Context, + Current as _Current, + RequestContext as _RequestContext, + WSContext as _WsContext, + _ctxv, +) +from emmett_core.utils import cachedprop -from .utils import cachedprop - -_ctxv = contextvars.ContextVar("_emt_ctxv") - - -class Context: - __slots__ = ["app", "__dict__"] - def __init__(self): - self.language = None +class Context(_Context): + __slots__ = [] @property def now(self): return pendulum.instance(datetime.utcnow()) -class RequestContext(Context): - __slots__ = ["request", "response", "session"] - - def __init__( - self, - app, - request, - response - ): - self.app = app - self.request = request - self.response = response - self.session = None - - @property - def now(self): - return self.request.now +class RequestContext(_RequestContext): + __slots__ = [] @cachedprop def language(self): - return self.request.accept_language.best_match( - list(self.app.translator._langmap) - ) + return self.request.accept_language.best_match(list(self.app.translator._langmap)) -class WSContext(Context): - __slots__ = ["websocket", "session"] +class WSContext(_WsContext): + __slots__ = [] - def __init__(self, app, websocket): - self.app = app - self.websocket = websocket - self.session = None + @property + def now(self): + return pendulum.instance(datetime.utcnow()) @cachedprop def language(self): - return self.websocket.accept_language.best_match( - list(self.app.translator._langmap) - ) + return self.websocket.accept_language.best_match(list(self.app.translator._langmap)) -class Current: +class Current(_Current): __slots__ = [] - ctx = property(_ctxv.get) - def __init__(self): _ctxv.set(Context()) - def _init_(self, ctx): - return _ctxv.set(ctx) - - def _close_(self, token): - _ctxv.reset(token) - - def __getattr__(self, name: str) -> Any: - return getattr(self.ctx, name) - - def __setattr__(self, name: str, value: Any): - setattr(self.ctx, name, value) - - def __delattr__(self, name: str): - delattr(self.ctx, name) - - def __getitem__(self, name: str) -> Any: - try: - return getattr(self.ctx, name) - except AttributeError as e: - raise KeyError from e - - def __setitem__(self, name: str, value: Any): - setattr(self.ctx, name, value) - - def __delitem__(self, name: str): - delattr(self.ctx, name) - - def __contains__(self, name: str) -> bool: - return hasattr(self.ctx, name) - - def get(self, name: str, default: Any = None) -> Any: - return getattr(self.ctx, name, default) - @property def T(self): return self.ctx.app.translator diff --git a/emmett/datastructures.py b/emmett/datastructures.py index e83ea1e1..a40a0836 100644 --- a/emmett/datastructures.py +++ b/emmett/datastructures.py @@ -1,104 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.datastructures - --------------------- +emmett.datastructures +--------------------- - Provide some useful data structures. +Provide some useful data structures. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -import copy -import hashlib -import pickle - -from typing import Dict, Optional - -from ._internal import ImmutableList -from .typing import KT, VT - - -class sdict(Dict[KT, VT]): - #: like a dictionary except `obj.foo` can be used in addition to - # `obj['foo']`, and setting obj.foo = None deletes item foo. - __slots__ = () - - __setattr__ = dict.__setitem__ # type: ignore - __delattr__ = dict.__delitem__ # type: ignore - __getitem__ = dict.get # type: ignore - - # see http://stackoverflow.com/questions/10364332/how-to-pickle-python-object-derived-from-dict - def __getattr__(self, key: str) -> Optional[VT]: - if key.startswith('__'): - raise AttributeError - return self.get(key, None) # type: ignore - - __repr__ = lambda self: '' % dict.__repr__(self) - __getstate__ = lambda self: None - __copy__ = lambda self: sdict(self) - __deepcopy__ = lambda self, memo: sdict(copy.deepcopy(dict(self))) - - -class ConfigData(sdict[KT, VT]): - #: like sdict, except it autogrows creating sub-sdict attributes. - # Useful for configurations. - __slots__ = () - - def __getitem__(self, key): - if key not in self.keys(): - self[key] = sdict() - return super().__getitem__(key) - - __getattr__ = __getitem__ - - -class SessionData(sdict): - __slots__ = ('__sid', '__hash', '__expires', '__dump') - - def __init__(self, initial=None, sid=None, expires=None): - sdict.__init__(self, initial or ()) - object.__setattr__( - self, '_SessionData__dump', pickle.dumps(sdict(self))) - h = hashlib.md5(self._dump).hexdigest() - object.__setattr__(self, '_SessionData__sid', sid) - object.__setattr__(self, '_SessionData__hash', h) - object.__setattr__(self, '_SessionData__expires', expires) - - @property - def _sid(self): - return self.__sid - - @property - def _modified(self): - dump = pickle.dumps(sdict(self)) - h = hashlib.md5(dump).hexdigest() - if h != self.__hash: - object.__setattr__(self, '_SessionData__dump', dump) - return True - return False - - @property - def _expiration(self): - return self.__expires - - @property - def _dump(self): - # note: self.__dump is updated only on _modified call - return self.__dump - - def _expires_after(self, value): - object.__setattr__(self, '_SessionData__expires', value) - - -def _unique_list(seq, hashfunc=None): - seen = set() - seen_add = seen.add - if not hashfunc: - return [x for x in seq if x not in seen and not seen_add(x)] - return [ - x for x in seq if hashfunc(x) not in seen and not seen_add(hashfunc(x)) - ] +from emmett_core.datastructures import sdict as sdict class OrderedSet(set): @@ -144,7 +55,7 @@ def __add__(self, other): return self.union(other) def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self._list) + return "%s(%r)" % (self.__class__.__name__, self._list) __str__ = __repr__ @@ -208,90 +119,9 @@ def difference_update(self, other): __isub__ = difference_update -class Accept(ImmutableList): - def __init__(self, values=()): - if values is None: - list.__init__(self) - self.provided = False - elif isinstance(values, Accept): - self.provided = values.provided - list.__init__(self, values) - else: - self.provided = True - values = sorted(values, key=lambda x: (x[1], x[0]), reverse=True) - list.__init__(self, values) - - def _value_matches(self, value, item): - return item == '*' or item.lower() == value.lower() - - def __getitem__(self, key): - if isinstance(key, str): - return self.quality(key) - return list.__getitem__(self, key) - - def quality(self, key): - for item, quality in self: - if self._value_matches(key, item): - return quality - return 0 - - def __contains__(self, value): - for item, quality in self: - if self._value_matches(value, item): - return True - return False - - def __repr__(self): - return '%s([%s])' % ( - self.__class__.__name__, - ', '.join('(%r, %s)' % (x, y) for x, y in self) - ) - - def index(self, key): - if isinstance(key, str): - for idx, (item, quality) in enumerate(self): - if self._value_matches(key, item): - return idx - raise ValueError(key) - return list.index(self, key) - - def find(self, key): - try: - return self.index(key) - except ValueError: - return -1 - - def values(self): - for item in self: - yield item[0] - - def to_header(self): - result = [] - for value, quality in self: - if quality != 1: - value = '%s;q=%s' % (value, quality) - result.append(value) - return ','.join(result) - - def __str__(self): - return self.to_header() - - def best_match(self, matches, default=None): - best_quality = -1 - result = default - for server_item in matches: - for client_item, quality in self: - if quality <= best_quality: - break - if ( - self._value_matches(server_item, client_item) and - quality > 0 - ): - best_quality = quality - result = server_item - return result - - @property - def best(self): - if self: - return self[0][0] +def _unique_list(seq, hashfunc=None): + seen = set() + seen_add = seen.add + if not hashfunc: + return [x for x in seq if x not in seen and not seen_add(x)] + return [x for x in seq if hashfunc(x) not in seen and not seen_add(hashfunc(x))] diff --git a/emmett/debug.py b/emmett/debug.py index dc0735c0..278d8502 100644 --- a/emmett/debug.py +++ b/emmett/debug.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.debug - ------------ +emmett.debug +------------ - Provides debugging utilities. +Provides debugging utilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import inspect @@ -14,10 +14,9 @@ import sys import traceback +from emmett_core.utils import cachedprop from renoir import Renoir -from .utils import cachedprop - class Traceback: """Wraps a traceback.""" @@ -28,10 +27,8 @@ def __init__(self, app, exc_type, exc_value, tb): self.exc_value = exc_value if not isinstance(exc_type, str): exception_type = exc_type.__name__ - if exc_type.__module__ not in ( - '__builtin__', 'builtins', 'exceptions' - ): - exception_type = exc_type.__module__ + '.' + exception_type + if exc_type.__module__ not in ("__builtin__", "builtins", "exceptions"): + exception_type = exc_type.__module__ + "." + exception_type else: exception_type = exc_type self.exception_type = exception_type @@ -45,39 +42,31 @@ def __init__(self, app, exc_type, exc_value, tb): def exception(self): """String representation of the exception.""" buf = traceback.format_exception_only(self.exc_type, self.exc_value) - return ''.join(buf).strip() + return "".join(buf).strip() def generate_plaintext_traceback(self): """Like the plaintext attribute but returns a generator""" - yield u'Traceback (most recent call last):' + yield "Traceback (most recent call last):" for frame in self.frames: - yield u' File "%s", line %s, in %s' % ( - frame.filename, - frame.lineno, - frame.function_name - ) - yield u' ' + frame.current_line.strip() + yield ' File "%s", line %s, in %s' % (frame.filename, frame.lineno, frame.function_name) + yield " " + frame.current_line.strip() yield self.exception def generate_plain_tb_app(self): - yield u'Traceback (most recent call last):' + yield "Traceback (most recent call last):" for frame in self.frames: if frame.is_in_app: - yield u' File "%s", line %s, in %s' % ( - frame.filename, - frame.lineno, - frame.function_name - ) - yield u' ' + frame.current_line.strip() + yield ' File "%s", line %s, in %s' % (frame.filename, frame.lineno, frame.function_name) + yield " " + frame.current_line.strip() yield self.exception @property def full_tb(self): - return u'\n'.join(self.generate_plaintext_traceback()) + return "\n".join(self.generate_plaintext_traceback()) @property def app_tb(self): - return u'\n'.join(self.generate_plain_tb_app()) + return "\n".join(self.generate_plain_tb_app()) class Frame: @@ -91,50 +80,47 @@ def __init__(self, app, exc_type, exc_value, tb): self.globals = tb.tb_frame.f_globals fn = inspect.getsourcefile(tb) or inspect.getfile(tb) - if fn[-4:] in ('.pyo', '.pyc'): + if fn[-4:] in (".pyo", ".pyc"): fn = fn[:-1] # if it's a file on the file system resolve the real filename. if os.path.isfile(fn): fn = os.path.realpath(fn) self.filename = fn - self.module = self.globals.get('__name__') + self.module = self.globals.get("__name__") self.code = tb.tb_frame.f_code @property def is_in_fw(self): fw_path = os.path.dirname(__file__) - return self.filename[0:len(fw_path)] == fw_path + return self.filename[0 : len(fw_path)] == fw_path @property def is_in_app(self): - return self.filename[0:len(self.app.root_path)] == self.app.root_path + return self.filename[0 : len(self.app.root_path)] == self.app.root_path @property def rendered_filename(self): if self.is_in_app: - return self.filename[len(self.app.root_path) + 1:] + return self.filename[len(self.app.root_path) + 1 :] if self.is_in_fw: - return ''.join([ - "emmett.", - self.filename[ - len(os.path.dirname(__file__)) + 1: - ].replace("/", ".").split(".py")[0] - ]) + return "".join( + ["emmett.", self.filename[len(os.path.dirname(__file__)) + 1 :].replace("/", ".").split(".py")[0]] + ) return self.filename @cachedprop def sourcelines(self): try: - with open(self.filename, 'rb') as file: - source = file.read().decode('utf8') + with open(self.filename, "rb") as file: + source = file.read().decode("utf8") except IOError: - source = '' + source = "" return source.splitlines() @property def sourceblock(self): lmax = self.lineno + 4 - return u'\n'.join(self.sourcelines[self.first_line_no - 1:lmax]) + return "\n".join(self.sourcelines[self.first_line_no - 1 : lmax]) @property def first_line_no(self): @@ -152,22 +138,20 @@ def current_line(self): try: return self.sourcelines[self.lineno - 1] except IndexError: - return u'' + return "" @cachedprop def render_locals(self): - rv = dict() + rv = {} for k, v in self.locals.items(): try: rv[k] = str(v) except Exception: - rv[k] = '' + rv[k] = "" return rv -debug_templater = Renoir( - path=os.path.join(os.path.dirname(__file__), 'assets', 'debug') -) +debug_templater = Renoir(path=os.path.join(os.path.dirname(__file__), "assets", "debug")) def smart_traceback(app): @@ -176,4 +160,4 @@ def smart_traceback(app): def debug_handler(tb): - return debug_templater.render('view.html', {'tb': tb}) + return debug_templater.render("view.html", {"tb": tb}) diff --git a/emmett/extensions.py b/emmett/extensions.py index c1172e3f..8f9377cc 100644 --- a/emmett/extensions.py +++ b/emmett/extensions.py @@ -1,23 +1,19 @@ # -*- coding: utf-8 -*- """ - emmett.extensions - ----------------- +emmett.extensions +----------------- - Provides base classes to create extensions. +Provides base classes to create extensions. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -from collections import OrderedDict from enum import Enum -from functools import wraps -from typing import Any, Callable, Dict, Optional, TypeVar, Union -from ._internal import warn_of_deprecation -from .datastructures import sdict +from emmett_core.extensions import Extension as Extension, listen_signal as listen_signal class Signals(str, Enum): @@ -29,80 +25,3 @@ class Signals(str, Enum): before_database = "before_database" before_route = "before_route" before_routes = "before_routes" - - -class listen_signal: - _inst_count_ = 0 - - def __init__(self, signal: Union[Signals, str]): - if not isinstance(signal, Signals): - warn_of_deprecation( - "extensions.listen_signal str argument", - "extensions.Signals as argument", - stack=3 - ) - try: - signal = Signals[signal] - except KeyError: - raise SyntaxError(f"{signal} is not a valid signal") - self.signal = signal.value if isinstance(signal, Signals) else signal - self._inst_count_ = listen_signal._inst_count_ - listen_signal._inst_count_ += 1 - - def __call__(self, f: Callable[..., None]) -> listen_signal: - self.f = f - return self - - -class MetaExtension(type): - def __new__(cls, name, bases, attrs): - new_class = type.__new__(cls, name, bases, attrs) - declared_listeners = OrderedDict() - all_listeners = OrderedDict() - listeners = [] - for key, value in list(attrs.items()): - if isinstance(value, listen_signal): - listeners.append((key, value)) - listeners.sort(key=lambda x: x[1]._inst_count_) - declared_listeners.update(listeners) - new_class._declared_listeners_ = declared_listeners - for base in reversed(new_class.__mro__[1:]): - if hasattr(base, "_declared_listeners_"): - all_listeners.update(base._declared_listeners_) - all_listeners.update(declared_listeners) - new_class._all_listeners_ = all_listeners - return new_class - - -class Extension(metaclass=MetaExtension): - namespace: Optional[str] = None - default_config: Dict[str, Any] = {} - - def __init__(self, app, env: sdict, config: sdict): - self.app = app - self.env = env - self.config = config - self.__init_config() - self.__init_listeners() - - def __init_config(self): - for key, dval in self.default_config.items(): - self.config[key] = self.config.get(key, dval) - - def __init_listeners(self): - self._listeners_ = [] - for name, obj in self._all_listeners_.items(): - self._listeners_.append((obj.signal, _wrap_listener(self, obj.f))) - - def on_load(self): - pass - - -def _wrap_listener(ext, f): - @wraps(f) - def wrapped(*args, **kwargs): - return f(ext, *args, **kwargs) - return wrapped - - -ExtensionType = TypeVar("ExtensionType", bound=Extension) diff --git a/emmett/forms.py b/emmett/forms.py index 72cc69f1..fd0b87aa 100644 --- a/emmett/forms.py +++ b/emmett/forms.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.forms - ------------ +emmett.forms +------------ - Provides classes to create and style forms. +Provides classes to create and style forms. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -14,15 +14,17 @@ from functools import wraps from typing import Any, Callable, Dict, List, Optional, Type, Union +from emmett_core.http.wrappers.helpers import FileStorage +from emmett_core.utils import cachedprop + from .ctx import current from .datastructures import sdict -from .html import HtmlTag, tag, cat, asis +from .html import HtmlTag, asis, cat, tag from .orm import Field, Model from .orm.objects import Row, Table from .security import CSRFStorage -from .utils import cachedprop from .validators import isEmptyOr -from .wrappers.helpers import FileStorage + __all__ = ["Form", "ModelForm"] @@ -46,7 +48,7 @@ def __init__( _action: str = "", _enctype: str = "multipart/form-data", _method: str = "POST", - **kwargs: Any + **kwargs: Any, ): self._action = _action self._csrf = csrf @@ -61,9 +63,7 @@ def __init__( self.onvalidation = onvalidation self.writable_fields = writable_fields if not issubclass(self._formstyle, FormStyle): - raise RuntimeError( - "{!r} is an invalid Emmett form style".format(formstyle) - ) + raise RuntimeError("{!r} is an invalid Emmett form style".format(formstyle)) self._preprocess(**kwargs) def _preprocess(self, **kwargs): @@ -75,7 +75,7 @@ def _preprocess(self, **kwargs): "id_prefix": self._id_prefix, "hidden": {}, "submit": self._submit_text, - "upload": self._upload + "upload": self._upload, } self.attributes.update(kwargs) #: init the form @@ -90,9 +90,7 @@ def _preprocess(self, **kwargs): @property def csrf(self) -> bool: - return self._csrf is True or ( - self._csrf == "auto" and self._submit_method == "POST" - ) + return self._csrf is True or (self._csrf == "auto" and self._submit_method == "POST") def _load_csrf(self): if not self.csrf: @@ -175,16 +173,9 @@ async def _process(self, write_defaults=True): if self.csrf and not self.accepted: self.formkey = current.session._csrf.gen_token() # reset default values in form - if ( - write_defaults and ( - not self.processed or (self.accepted and not self.keepvalues) - ) - ): + if write_defaults and (not self.processed or (self.accepted and not self.keepvalues)): for field in self.fields: - self.input_params[field.name] = ( - field.default() if callable(field.default) else - field.default - ) + self.input_params[field.name] = field.default() if callable(field.default) else field.default return self def _render(self): @@ -229,12 +220,9 @@ def custom(self): custom.begin = asis(f"
") # add hidden stuffs to get process working hidden = cat() - hidden.append( - tag.input(_name="_csrf_token", _type="hidden", _value=self.formkey) - ) + hidden.append(tag.input(_name="_csrf_token", _type="hidden", _value=self.formkey)) for key, value in self.attributes["hidden"].items(): - hidden.append(tag.input(_name=key, _type="hidden", _value=value) - ) + hidden.append(tag.input(_name=key, _type="hidden", _value=value)) # provides end attribute custom.end = asis(f"{hidden.__html__()}
") return custom @@ -257,7 +245,7 @@ def __init__( _action: str = "", _enctype: str = "multipart/form-data", _method: str = "POST", - **kwargs: Any + **kwargs: Any, ): fields = fields or {} #: get fields from kwargs @@ -292,7 +280,7 @@ def __init__( upload=upload, _action=_action, _enctype=_enctype, - _method=_method + _method=_method, ) @@ -313,26 +301,23 @@ def __init__( _action: str = "", _enctype: str = "multipart/form-data", _method: str = "POST", - **attributes + **attributes, ): self.model = model._instance_() self.table: Table = self.model.table - self.record = record or ( - self.model.get(record_id) if record_id else - self.model.new() - ) + self.record = record or (self.model.get(record_id) if record_id else self.model.new()) #: build fields for form fields_list_all = [] fields_list_writable = [] if fields is not None: #: developer has selected specific fields if not isinstance(fields, dict): - fields = {'writable': fields, 'readable': fields} + fields = {"writable": fields, "readable": fields} for field in self.table: - if field.name not in fields['readable']: + if field.name not in fields["readable"]: continue fields_list_all.append(field) - if field.name in fields['writable']: + if field.name in fields["writable"]: fields_list_writable.append(field) else: #: use table fields @@ -359,7 +344,7 @@ def __init__( upload=upload, _action=_action, _enctype=_enctype, - _method=_method + _method=_method, ) def _get_id_value(self): @@ -368,16 +353,13 @@ def _get_id_value(self): return self.record[self.table._id.name] def _validate_input(self): - record, fields = self.record.clone(), { - field.name: self._get_input_val(field) - for field in self.writable_fields - } + record, fields = self.record.clone(), {field.name: self._get_input_val(field) for field in self.writable_fields} for field in filter(lambda f: f.type == "upload", self.writable_fields): val = fields[field.name] if ( - (val == b"" or val is None) and - not self.input_params.get(field.name + "__del", False) and - self.record[field.name] + (val == b"" or val is None) + and not self.input_params.get(field.name + "__del", False) + and self.record[field.name] ): fields.pop(field.name) record.update(fields) @@ -413,9 +395,7 @@ async def _process(self, **kwargs): continue else: source_file, original_filename = upload.stream, upload.filename - newfilename = field.store( - source_file, original_filename, field.uploadfolder - ) + newfilename = field.store(source_file, original_filename, field.uploadfolder) if isinstance(field.uploadfield, str): self.params[field.uploadfield] = source_file.read() self.params[field.name] = newfilename @@ -428,15 +408,11 @@ async def _process(self, **kwargs): #: cleanup inputs if not self.processed or (self.accepted and not self.keepvalues): for field in self.fields: - self.input_params[field.name] = field.formatter( - self.record[field.name] - ) + self.input_params[field.name] = field.formatter(self.record[field.name]) elif self.processed and not self.accepted and self.record._concrete: for field in self.writable_fields: if field.type == "upload" and field.name not in self.params: - self.input_params[field.name] = field.formatter( - self.record[field.name] - ) + self.input_params[field.name] = field.formatter(self.record[field.name]) return self @@ -456,17 +432,12 @@ def widget_string(attr, field, value, _class="string", _id=None): _name=field.name, _value=value if value is not None else "", _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod def widget_text(attr, field, value, _class="text", _id=None): - return tag.textarea( - value or "", - _name=field.name, - _class=_class, - _id=_id or field.name - ) + return tag.textarea(value or "", _name=field.name, _class=_class, _id=_id or field.name) @staticmethod def widget_int(attr, field, value, _class="int", _id=None): @@ -487,7 +458,7 @@ def widget_date(attr, field, value, _class="date", _id=None): _name=field.name, _value=value if value is not None else "", _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -497,7 +468,7 @@ def widget_time(attr, field, value, _class="time", _id=None): _name=field.name, _value=value if value is not None else "", _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -507,18 +478,12 @@ def widget_datetime(attr, field, value, _class="datetime", _id=None): _name=field.name, _value=value if value is not None else "", _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod def widget_password(attr, field, value, _class="password", _id=None): - return tag.input( - _type="password", - _name=field.name, - _value=value or "", - _class=_class, - _id=_id or field.name - ) + return tag.input(_type="password", _name=field.name, _value=value or "", _class=_class, _id=_id or field.name) @staticmethod def widget_bool(attr, field, value, _class="bool", _id=None): @@ -527,7 +492,7 @@ def widget_bool(attr, field, value, _class="bool", _id=None): _name=field.name, _checked="checked" if value else None, _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -537,16 +502,12 @@ def selected(k): options, multiple = FormStyle._field_options(field) if multiple: - return FormStyle.widget_multiple( - attr, field, value, options, _class=_class, _id=_id - ) + return FormStyle.widget_multiple(attr, field, value, options, _class=_class, _id=_id) return tag.select( - *[ - tag.option(n, _value=k, _selected=selected(k)) for k, n in options - ], + *[tag.option(n, _value=k, _selected=selected(k)) for k, n in options], _name=field.name, _class=_class, - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -556,13 +517,11 @@ def selected(k): values = values or [] return tag.select( - *[ - tag.option(n, _value=k, _selected=selected(k)) for k, n in options - ], + *[tag.option(n, _value=k, _selected=selected(k)) for k, n in options], _name=field.name, _class=_class, _multiple="multiple", - _id=_id or field.name + _id=_id or field.name, ) @staticmethod @@ -604,14 +563,10 @@ def _coerce_value(value): _class="checkbox", _id=_id + "__del", _name=field.name + "__del", - _style="display: inline;" - ), - tag.label( - "delete", - _for=_id + "__del", - _style="margin: 4px" + _style="display: inline;", ), - _style="white-space: nowrap;" + tag.label("delete", _for=_id + "__del", _style="margin: 4px"), + _style="white-space: nowrap;", ) ) return tag.div(*elements, _class="upload_wrap") @@ -627,19 +582,22 @@ def widget_jsonb(attr, field, value, _id=None): @staticmethod def widget_radio(field, value): options, _ = FormStyle._field_options(field) - return cat(*[ - tag.div( - tag.input( - _id=f"{field.name}_{k}", - _name=field.name, - _value=k, - _type="radio", - _checked=("checked" if str(k) == str(value) else None) - ), - tag.label(n, _for=f"{field.name}_{k}"), - _class="option_wrap" - ) for k, n in options - ]) + return cat( + *[ + tag.div( + tag.input( + _id=f"{field.name}_{k}", + _name=field.name, + _value=k, + _type="radio", + _checked=("checked" if str(k) == str(value) else None), + ), + tag.label(n, _for=f"{field.name}_{k}"), + _class="option_wrap", + ) + for k, n in options + ] + ) def __init__(self, attributes): self.attr = attributes @@ -667,16 +625,12 @@ def _get_widget(self, field, value): elif wtype.startswith("decimal"): wtype = "float" try: - widget = getattr(self, "widget_" + wtype)( - self.attr, field, value, _id=widget_id - ) + widget = getattr(self, "widget_" + wtype)(self.attr, field, value, _id=widget_id) if not field.writable: self._disable_widget(widget) return widget, False except AttributeError: - raise RuntimeError( - f"Missing form widget for field {field.name} of type {wtype}" - ) + raise RuntimeError(f"Missing form widget for field {field.name} of type {wtype}") def _disable_widget(self, widget): widget.attributes["_disabled"] = "disabled" @@ -743,7 +697,8 @@ def add_form_on_model(cls): @wraps(cls) def wrapped(model, *args, **kwargs): return cls(model, *args, **kwargs) + return wrapped -setattr(Model, "form", classmethod(add_form_on_model(ModelForm))) +Model.form = classmethod(add_form_on_model(ModelForm)) diff --git a/emmett/helpers.py b/emmett/helpers.py index bcf5824f..1603c8f5 100644 --- a/emmett/helpers.py +++ b/emmett/helpers.py @@ -1,78 +1,38 @@ # -*- coding: utf-8 -*- """ - emmett.helpers - -------------- +emmett.helpers +-------------- - Provides helping methods for applications. +Provides helping methods for applications. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -import os -import re - from typing import Any, List, Optional, Tuple, Union -from pydal.exceptions import NotAuthorizedException, NotFoundException +from emmett_core._internal import deprecated +from emmett_core.http.helpers import abort as _abort from .ctx import current from .html import HtmlTag, tag -from .http import HTTP, HTTPFile, HTTPIO - -_re_dbstream = re.compile(r'(?P.*?)\.(?P.*?)\..*') -def abort(code: int, body: str = ''): - response = current.response - response.status = code - raise HTTP( - code, - body=body, - cookies=response.cookies - ) +def abort(code: int, body: str = ""): + _abort(current, code, body) +@deprecated("stream_file", "Response.wrap_file") def stream_file(path: str): - full_path = os.path.join(current.app.root_path, path) - raise HTTPFile( - full_path, - headers=current.response.headers, - cookies=current.response.cookies - ) + raise current.response.wrap_file(path) +@deprecated("stream_dbfile", "Response.wrap_dbfile") def stream_dbfile(db: Any, name: str): - items = _re_dbstream.match(name) - if not items: - abort(404) - table_name, field_name = items.group('table'), items.group('field') - try: - field = db[table_name][field_name] - except AttributeError: - abort(404) - try: - filename, path_or_stream = field.retrieve(name, nameonly=True) - except NotAuthorizedException: - abort(403) - except NotFoundException: - abort(404) - except IOError: - abort(404) - if isinstance(path_or_stream, str): - raise HTTPFile( - path_or_stream, - headers=current.response.headers, - cookies=current.response.cookies - ) - raise HTTPIO( - path_or_stream, - headers=current.response.headers, - cookies=current.response.cookies - ) - - -def flash(message: str, category: str = 'message'): + raise current.response.wrap_dbfile(db, name) + + +def flash(message: str, category: str = "message"): #: Flashes a message to the next request. if current.session._flashes is None: current.session._flashes = [] @@ -80,8 +40,7 @@ def flash(message: str, category: str = 'message'): def get_flashed_messages( - with_categories: bool = False, - category_filter: Union[str, List[str]] = [] + with_categories: bool = False, category_filter: Union[str, List[str]] = [] ) -> Union[List[str], Tuple[str, str]]: #: Pulls flashed messages from the session and returns them. # By default just the messages are returned, but when `with_categories` @@ -102,13 +61,9 @@ def get_flashed_messages( return flashes -def load_component( - url: str, - target: Optional[str] = None, - content: str = 'loading...' -) -> HtmlTag: +def load_component(url: str, target: Optional[str] = None, content: str = "loading...") -> HtmlTag: attr = {} if target: - attr['_id'] = target - attr['_data-emt_remote'] = url + attr["_id"] = target + attr["_data-emt_remote"] = url return tag.div(content, **attr) diff --git a/emmett/html.py b/emmett/html.py index 25e11c97..38b41a89 100644 --- a/emmett/html.py +++ b/emmett/html.py @@ -1,230 +1,84 @@ # -*- coding: utf-8 -*- """ - emmett.html - ----------- +emmett.html +----------- - Provides html generation classes. +Provides html generation classes. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -import html import re -import threading - from functools import reduce -__all__ = ['tag', 'cat', 'asis'] - - -class TagStack(threading.local): - def __init__(self): - self.stack = [] - - def __getitem__(self, key): - return self.stack[key] - - def append(self, item): - self.stack.append(item) - - def pop(self, idx): - self.stack.pop(idx) - - def __bool__(self): - return len(self.stack) > 0 - +from emmett_core.html import ( + MetaHtmlTag as _MetaHtmlTag, + TagStack, + TreeHtmlTag, + _to_str, + cat as cat, + htmlescape as htmlescape, +) -class HtmlTag: - rules = { - 'ul': ['li'], - 'ol': ['li'], - 'table': ['tr', 'thead', 'tbody'], - 'thead': ['tr'], - 'tbody': ['tr'], - 'tr': ['td', 'th'], - 'select': ['option', 'optgroup'], - 'optgroup': ['optionp'] - } - _self_closed = {'br', 'col', 'embed', 'hr', 'img', 'input', 'link', 'meta'} - def __init__(self, name): - self.name = name - self.parent = None - self.components = [] - self.attributes = {} - if _stack: - _stack[-1].append(self) +__all__ = ["tag", "cat", "asis"] - def __enter__(self): - _stack.append(self) - return self +_re_tag = re.compile(r"^([\w\-\:]+)") +_re_id = re.compile(r"#([\w\-]+)") +_re_class = re.compile(r"\.([\w\-]+)") +_re_attr = re.compile(r"\[([\w\-\:]+)=(.*?)\]") - def __exit__(self, type, value, traceback): - _stack.pop(-1) - @staticmethod - def wrap(component, rules): - if rules and ( - not isinstance(component, HtmlTag) or component.name not in rules - ): - return HtmlTag(rules[0])(component) - return component +class HtmlTag(TreeHtmlTag): + __slots__ = [] def __call__(self, *components, **attributes): - rules = self.rules.get(self.name, []) - self.components = [self.wrap(comp, rules) for comp in components] - self.attributes = attributes - for component in self.components: - if isinstance(component, HtmlTag): - component.parent = self - return self - - def append(self, component): - self.components.append(component) - - def insert(self, i, component): - self.components.insert(i, component) - - def remove(self, component): - self.components.remove(component) - - def __getitem__(self, key): - if isinstance(key, int): - return self.components[key] - else: - return self.attributes.get(key) - - def __setitem__(self, key, value): - if isinstance(key, int): - self.components.insert(key, value) - else: - self.attributes[key] = value - - def __iter__(self): - for item in self.components: - yield item - - def __str__(self): - return self.__html__() - - def __add__(self, other): - return cat(self, other) - - def add_class(self, name): - """ add a class to _class attribute """ - c = self['_class'] - classes = (set(c.split()) if c else set()) | set(name.split()) - self['_class'] = ' '.join(classes) if classes else None - return self - - def remove_class(self, name): - """ remove a class from _class attribute """ - c = self['_class'] - classes = (set(c.split()) if c else set()) - set(name.split()) - self['_class'] = ' '.join(classes) if classes else None - return self - - regex_tag = re.compile(r'^([\w\-\:]+)') - regex_id = re.compile(r'#([\w\-]+)') - regex_class = re.compile(r'\.([\w\-]+)') - regex_attr = re.compile(r'\[([\w\-\:]+)=(.*?)\]') + # legacy "data" attribute + if _data := attributes.pop("data", None): + attributes["_data"] = _data + return super().__call__(*components, **attributes) def find(self, expr): union = lambda a, b: a.union(b) - if ',' in expr: - tags = reduce( - union, - [self.find(x.strip()) for x in expr.split(',')], - set()) - elif ' ' in expr: + if "," in expr: + tags = reduce(union, [self.find(x.strip()) for x in expr.split(",")], set()) + elif " " in expr: tags = [self] for k, item in enumerate(expr.split()): if k > 0: - children = [ - set([c for c in tag if isinstance(c, HtmlTag)]) - for tag in tags] + children = [{c for c in tag if isinstance(c, self.__class__)} for tag in tags] tags = reduce(union, children) tags = reduce(union, [tag.find(item) for tag in tags], set()) else: - tags = reduce( - union, - [c.find(expr) for c in self if isinstance(c, HtmlTag)], - set()) - tag = HtmlTag.regex_tag.match(expr) - id = HtmlTag.regex_id.match(expr) - _class = HtmlTag.regex_class.match(expr) - attr = HtmlTag.regex_attr.match(expr) + tags = reduce(union, [c.find(expr) for c in self if isinstance(c, self.__class__)], set()) + tag = _re_tag.match(expr) + id = _re_id.match(expr) + _class = _re_class.match(expr) + attr = _re_attr.match(expr) if ( - (tag is None or self.name == tag.group(1)) and - (id is None or self['_id'] == id.group(1)) and - (_class is None or _class.group(1) in - (self['_class'] or '').split()) and - (attr is None or self['_' + attr.group(1)] == attr.group(2)) + (tag is None or self.name == tag.group(1)) + and (id is None or self["_id"] == id.group(1)) + and (_class is None or _class.group(1) in (self["_class"] or "").split()) + and (attr is None or self["_" + attr.group(1)] == attr.group(2)) ): tags.add(self) return tags - def _build_html_attributes(self): - return ' '.join( - '%s="%s"' % (k[1:], k[1:] if v is True else htmlescape(v)) - for (k, v) in sorted(self.attributes.items()) - if k.startswith('_') and v is not None) - - def __html__(self): - name = self.name - attrs = self._build_html_attributes() - data = self.attributes.get('data', {}) - data_attrs = ' '.join( - 'data-%s="%s"' % (k, htmlescape(v)) for k, v in data.items()) - if data_attrs: - attrs = attrs + ' ' + data_attrs - attrs = ' ' + attrs if attrs else '' - if name in self._self_closed: - return '<%s%s />' % (name, attrs) - components = ''.join(htmlescape(v) for v in self.components) - return '<%s%s>%s' % (name, attrs, components, name) - - def __json__(self): - return str(self) - -class MetaHtmlTag: - def __getattr__(self, name): - return HtmlTag(name) - - def __getitem__(self, name): - return HtmlTag(name) - - -class cat(HtmlTag): - def __init__(self, *components): - self.components = [c for c in components] - self.attributes = {} - - def __html__(self): - return ''.join(htmlescape(v) for v in self.components) +class MetaHtmlTag(_MetaHtmlTag): + __slots__ = [] + _tag_cls = HtmlTag class asis(HtmlTag): - def __init__(self, text): - self.text = text - - def __html__(self): - return _to_str(self.text) + __slots__ = [] + def __init__(self, val): + self.name = val -def _to_str(obj): - if not isinstance(obj, str): - return str(obj) - return obj - - -def htmlescape(obj): - if hasattr(obj, '__html__'): - return obj.__html__() - return html.escape(_to_str(obj), True).replace("'", "'") + def __html__(self): + return _to_str(self.name) -_stack = TagStack() -tag = MetaHtmlTag() +tag = MetaHtmlTag(TagStack()) diff --git a/emmett/http.py b/emmett/http.py index c96aa3ea..e685be8e 100644 --- a/emmett/http.py +++ b/emmett/http.py @@ -1,368 +1,76 @@ # -*- coding: utf-8 -*- """ - emmett.http - ----------- +emmett.http +----------- - Provides the HTTP interfaces. +Provides the HTTP interfaces. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -import errno -import os -import stat +from emmett_core.http.helpers import redirect as _redirect +from emmett_core.http.response import ( + HTTPAsyncIterResponse as HTTPAsyncIter, + HTTPBytesResponse as HTTPBytes, + HTTPFileResponse as HTTPFile, + HTTPIOResponse as HTTPIO, + HTTPIterResponse as HTTPIter, + HTTPResponse as HTTPResponse, + HTTPStringResponse as HTTPStringResponse, +) -from email.utils import formatdate -from hashlib import md5 -from typing import Any, AsyncIterable, BinaryIO, Dict, Generator, Iterable, Tuple - -from granian.rsgi import HTTPProtocol - -from ._internal import loop_open_file from .ctx import current -from .libs.contenttype import contenttype +HTTP = HTTPStringResponse + status_codes = { - 100: '100 CONTINUE', - 101: '101 SWITCHING PROTOCOLS', - 200: '200 OK', - 201: '201 CREATED', - 202: '202 ACCEPTED', - 203: '203 NON-AUTHORITATIVE INFORMATION', - 204: '204 NO CONTENT', - 205: '205 RESET CONTENT', - 206: '206 PARTIAL CONTENT', - 207: '207 MULTI-STATUS', - 300: '300 MULTIPLE CHOICES', - 301: '301 MOVED PERMANENTLY', - 302: '302 FOUND', - 303: '303 SEE OTHER', - 304: '304 NOT MODIFIED', - 305: '305 USE PROXY', - 307: '307 TEMPORARY REDIRECT', - 400: '400 BAD REQUEST', - 401: '401 UNAUTHORIZED', - 403: '403 FORBIDDEN', - 404: '404 NOT FOUND', - 405: '405 METHOD NOT ALLOWED', - 406: '406 NOT ACCEPTABLE', - 407: '407 PROXY AUTHENTICATION REQUIRED', - 408: '408 REQUEST TIMEOUT', - 409: '409 CONFLICT', - 410: '410 GONE', - 411: '411 LENGTH REQUIRED', - 412: '412 PRECONDITION FAILED', - 413: '413 REQUEST ENTITY TOO LARGE', - 414: '414 REQUEST-URI TOO LONG', - 415: '415 UNSUPPORTED MEDIA TYPE', - 416: '416 REQUESTED RANGE NOT SATISFIABLE', - 417: '417 EXPECTATION FAILED', - 422: '422 UNPROCESSABLE ENTITY', - 500: '500 INTERNAL SERVER ERROR', - 501: '501 NOT IMPLEMENTED', - 502: '502 BAD GATEWAY', - 503: '503 SERVICE UNAVAILABLE', - 504: '504 GATEWAY TIMEOUT', - 505: '505 HTTP VERSION NOT SUPPORTED', + 100: "100 CONTINUE", + 101: "101 SWITCHING PROTOCOLS", + 200: "200 OK", + 201: "201 CREATED", + 202: "202 ACCEPTED", + 203: "203 NON-AUTHORITATIVE INFORMATION", + 204: "204 NO CONTENT", + 205: "205 RESET CONTENT", + 206: "206 PARTIAL CONTENT", + 207: "207 MULTI-STATUS", + 300: "300 MULTIPLE CHOICES", + 301: "301 MOVED PERMANENTLY", + 302: "302 FOUND", + 303: "303 SEE OTHER", + 304: "304 NOT MODIFIED", + 305: "305 USE PROXY", + 307: "307 TEMPORARY REDIRECT", + 400: "400 BAD REQUEST", + 401: "401 UNAUTHORIZED", + 403: "403 FORBIDDEN", + 404: "404 NOT FOUND", + 405: "405 METHOD NOT ALLOWED", + 406: "406 NOT ACCEPTABLE", + 407: "407 PROXY AUTHENTICATION REQUIRED", + 408: "408 REQUEST TIMEOUT", + 409: "409 CONFLICT", + 410: "410 GONE", + 411: "411 LENGTH REQUIRED", + 412: "412 PRECONDITION FAILED", + 413: "413 REQUEST ENTITY TOO LARGE", + 414: "414 REQUEST-URI TOO LONG", + 415: "415 UNSUPPORTED MEDIA TYPE", + 416: "416 REQUESTED RANGE NOT SATISFIABLE", + 417: "417 EXPECTATION FAILED", + 422: "422 UNPROCESSABLE ENTITY", + 500: "500 INTERNAL SERVER ERROR", + 501: "501 NOT IMPLEMENTED", + 502: "502 BAD GATEWAY", + 503: "503 SERVICE UNAVAILABLE", + 504: "504 GATEWAY TIMEOUT", + 505: "505 HTTP VERSION NOT SUPPORTED", } -class HTTPResponse(Exception): - def __init__( - self, - status_code: int, - *, - headers: Dict[str, str] = {'content-type': 'text/plain'}, - cookies: Dict[str, Any] = {} - ): - self.status_code: int = status_code - self._headers: Dict[str, str] = headers - self._cookies: Dict[str, Any] = cookies - - @property - def headers(self) -> Generator[Tuple[bytes, bytes], None, None]: - for key, val in self._headers.items(): - yield key.encode('latin-1'), val.encode('latin-1') - for cookie in self._cookies.values(): - yield b'set-cookie', str(cookie)[12:].encode('latin-1') - - @property - def rsgi_headers(self) -> Generator[Tuple[str, str], None, None]: - for key, val in self._headers.items(): - yield key, val - for cookie in self._cookies.values(): - yield 'set-cookie', str(cookie)[12:] - - async def _send_headers(self, send): - await send({ - 'type': 'http.response.start', - 'status': self.status_code, - 'headers': list(self.headers) - }) - - async def _send_body(self, send): - await send({'type': 'http.response.body'}) - - async def asgi(self, scope, send): - await self._send_headers(send) - await self._send_body(send) - - def rsgi(self, protocol: HTTPProtocol): - protocol.response_empty( - self.status_code, - list(self.rsgi_headers) - ) - - -class HTTPBytes(HTTPResponse): - def __init__( - self, - status_code: int, - body: bytes = b'', - headers: Dict[str, str] = {'content-type': 'text/plain'}, - cookies: Dict[str, Any] = {} - ): - super().__init__(status_code, headers=headers, cookies=cookies) - self.body = body - - async def _send_body(self, send): - await send({ - 'type': 'http.response.body', - 'body': self.body, - 'more_body': False - }) - - def rsgi(self, protocol: HTTPProtocol): - protocol.response_bytes( - self.status_code, - list(self.rsgi_headers), - self.body - ) - - -class HTTP(HTTPResponse): - def __init__( - self, - status_code: int, - body: str = '', - headers: Dict[str, str] = {'content-type': 'text/plain'}, - cookies: Dict[str, Any] = {} - ): - super().__init__(status_code, headers=headers, cookies=cookies) - self.body = body - - @property - def encoded_body(self): - return self.body.encode('utf-8') - - async def _send_body(self, send): - await send({ - 'type': 'http.response.body', - 'body': self.encoded_body, - 'more_body': False - }) - - def rsgi(self, protocol: HTTPProtocol): - protocol.response_str( - self.status_code, - list(self.rsgi_headers), - self.body - ) - - -class HTTPRedirect(HTTPResponse): - def __init__( - self, - status_code: int, - location: str, - cookies: Dict[str, Any] = {} - ): - location = location.replace('\r', '%0D').replace('\n', '%0A') - super().__init__( - status_code, - headers={'location': location}, - cookies=cookies - ) - - -class HTTPFile(HTTPResponse): - def __init__( - self, - file_path: str, - headers: Dict[str, str] = {}, - cookies: Dict[str, Any] = {}, - chunk_size: int = 4096 - ): - super().__init__(200, headers=headers, cookies=cookies) - self.file_path = file_path - self.chunk_size = chunk_size - - def _get_stat_headers(self, stat_data): - content_length = str(stat_data.st_size) - last_modified = formatdate(stat_data.st_mtime, usegmt=True) - etag_base = str(stat_data.st_mtime) + '_' + str(stat_data.st_size) - etag = md5(etag_base.encode('utf-8')).hexdigest() - return { - 'content-type': contenttype(self.file_path), - 'content-length': content_length, - 'last-modified': last_modified, - 'etag': etag - } - - async def asgi(self, scope, send): - try: - stat_data = os.stat(self.file_path) - if not stat.S_ISREG(stat_data.st_mode): - await HTTP(403).send(scope, send) - return - self._headers.update(self._get_stat_headers(stat_data)) - await self._send_headers(send) - if 'http.response.pathsend' in scope.get('extensions', {}): - await send({ - 'type': 'http.response.pathsend', - 'path': str(self.file_path) - }) - else: - await self._send_body(send) - except IOError as e: - if e.errno == errno.EACCES: - await HTTP(403).send(scope, send) - else: - await HTTP(404).send(scope, send) - - async def _send_body(self, send): - async with loop_open_file(self.file_path, mode='rb') as f: - more_body = True - while more_body: - chunk = await f.read(self.chunk_size) - more_body = len(chunk) == self.chunk_size - await send({ - 'type': 'http.response.body', - 'body': chunk, - 'more_body': more_body, - }) - - def rsgi(self, protocol: HTTPProtocol): - try: - stat_data = os.stat(self.file_path) - if not stat.S_ISREG(stat_data.st_mode): - return HTTP(403).rsgi(protocol) - self._headers.update(self._get_stat_headers(stat_data)) - except IOError as e: - if e.errno == errno.EACCES: - return HTTP(403).rsgi(protocol) - return HTTP(404).rsgi(protocol) - - protocol.response_file( - self.status_code, - list(self.rsgi_headers), - self.file_path - ) - - -class HTTPIO(HTTPResponse): - def __init__( - self, - io_stream: BinaryIO, - headers: Dict[str, str] = {}, - cookies: Dict[str, Any] = {}, - chunk_size: int = 4096 - ): - super().__init__(200, headers=headers, cookies=cookies) - self.io_stream = io_stream - self.chunk_size = chunk_size - - def _get_io_headers(self): - content_length = str(self.io_stream.getbuffer().nbytes) - return { - 'content-length': content_length - } - - async def asgi(self, scope, send): - self._headers.update(self._get_io_headers()) - await self._send_headers(send) - await self._send_body(send) - - async def _send_body(self, send): - more_body = True - while more_body: - chunk = self.io_stream.read(self.chunk_size) - more_body = len(chunk) == self.chunk_size - await send({ - 'type': 'http.response.body', - 'body': chunk, - 'more_body': more_body, - }) - - def rsgi(self, protocol: HTTPProtocol): - protocol.response_bytes( - self.status_code, - list(self.rsgi_headers), - self.io_stream.read() - ) - - -class HTTPIter(HTTPResponse): - def __init__( - self, - iter: Iterable[bytes], - headers: Dict[str, str] = {}, - cookies: Dict[str, Any] = {} - ): - super().__init__(200, headers=headers, cookies=cookies) - self.iter = iter - - async def _send_body(self, send): - for chunk in self.iter: - await send({ - 'type': 'http.response.body', - 'body': chunk, - 'more_body': True - }) - await send({'type': 'http.response.body', 'body': b'', 'more_body': False}) - - async def rsgi(self, protocol: HTTPProtocol): - trx = protocol.response_stream( - self.status_code, - list(self.rsgi_headers) - ) - for chunk in self.iter: - await trx.send_bytes(chunk) - - -class HTTPAiter(HTTPResponse): - def __init__( - self, - iter: AsyncIterable[bytes], - headers: Dict[str, str] = {}, - cookies: Dict[str, Any] = {} - ): - super().__init__(200, headers=headers, cookies=cookies) - self.iter = iter - - async def _send_body(self, send): - async for chunk in self.iter: - await send({ - 'type': 'http.response.body', - 'body': chunk, - 'more_body': True - }) - await send({'type': 'http.response.body', 'body': b'', 'more_body': False}) - - async def rsgi(self, protocol: HTTPProtocol): - trx = protocol.response_stream( - self.status_code, - list(self.rsgi_headers) - ) - async for chunk in self.iter: - await trx.send_bytes(chunk) - - def redirect(location: str, status_code: int = 303): - response = current.response - response.status = status_code - raise HTTPRedirect(status_code, location, response.cookies) + _redirect(current, location, status_code) diff --git a/emmett/language/helpers.py b/emmett/language/helpers.py index 271a60ee..7db1ab3c 100644 --- a/emmett/language/helpers.py +++ b/emmett/language/helpers.py @@ -1,37 +1,31 @@ # -*- coding: utf-8 -*- """ - emmett.language.helpers - ----------------------- +emmett.language.helpers +----------------------- - Translation helpers. +Translation helpers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import re +from emmett_core.http.headers import Accept from severus.datastructures import Tstr as _Tstr -from ..datastructures import Accept - class Tstr(_Tstr): __slots__ = [] def __getstate__(self): - return { - 'text': self.text, - 'lang': self.lang, - 'args': self.args, - 'kwargs': self.kwargs - } + return {"text": self.text, "lang": self.lang, "args": self.args, "kwargs": self.kwargs} def __setstate__(self, state): - self.text = state['text'] - self.lang = state['lang'] - self.args = state['args'] - self.kwargs = state['kwargs'] + self.text = state["text"] + self.lang = state["lang"] + self.args = state["args"] + self.kwargs = state["kwargs"] def __getattr__(self, name): return getattr(str(self), name) @@ -41,9 +35,10 @@ def __json__(self): class LanguageAccept(Accept): - regex_locale_delim = re.compile(r'[_-]') + regex_locale_delim = re.compile(r"[_-]") def _value_matches(self, value, item): def _normalize(language): return self.regex_locale_delim.split(language.lower())[0] - return item == '*' or _normalize(value) == _normalize(item) + + return item == "*" or _normalize(value) == _normalize(item) diff --git a/emmett/language/translator.py b/emmett/language/translator.py index 4ef6f5be..4fd602f2 100644 --- a/emmett/language/translator.py +++ b/emmett/language/translator.py @@ -1,17 +1,18 @@ # -*- coding: utf-8 -*- """ - emmett.language.translator - -------------------------- +emmett.language.translator +-------------------------- - Severus translator implementation for Emmett. +Severus translator implementation for Emmett. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations from typing import Optional + from severus.ctx import set_context from severus.translator import Translator as _Translator @@ -32,6 +33,4 @@ def _update_config(self, default_language: str): self._build_languages() def _get_best_language(self, lang: Optional[str] = None) -> str: - return self._langmap.get( - lang or current.language, self._default_language - ) + return self._langmap.get(lang or current.language, self._default_language) diff --git a/emmett/libs/contenttype.py b/emmett/libs/contenttype.py index 5f9a3b65..d30a939e 100644 --- a/emmett/libs/contenttype.py +++ b/emmett/libs/contenttype.py @@ -1,717 +1,717 @@ # -*- coding: utf-8 -*- """ - Content-type helper based on file extension. +Content-type helper based on file extension. - Original code from web2py - by Massimo Di Pierro +Original code from web2py +by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ -__all__ = ['contenttype'] +__all__ = ["contenttype"] CONTENT_TYPE = { - '.load': 'text/html', - '.123': 'application/vnd.lotus-1-2-3', - '.3ds': 'image/x-3ds', - '.3g2': 'video/3gpp', - '.3ga': 'video/3gpp', - '.3gp': 'video/3gpp', - '.3gpp': 'video/3gpp', - '.602': 'application/x-t602', - '.669': 'audio/x-mod', - '.7z': 'application/x-7z-compressed', - '.a': 'application/x-archive', - '.aac': 'audio/mp4', - '.abw': 'application/x-abiword', - '.abw.crashed': 'application/x-abiword', - '.abw.gz': 'application/x-abiword', - '.ac3': 'audio/ac3', - '.ace': 'application/x-ace', - '.adb': 'text/x-adasrc', - '.ads': 'text/x-adasrc', - '.afm': 'application/x-font-afm', - '.ag': 'image/x-applix-graphics', - '.ai': 'application/illustrator', - '.aif': 'audio/x-aiff', - '.aifc': 'audio/x-aiff', - '.aiff': 'audio/x-aiff', - '.al': 'application/x-perl', - '.alz': 'application/x-alz', - '.amr': 'audio/amr', - '.ani': 'application/x-navi-animation', - '.anim[1-9j]': 'video/x-anim', - '.anx': 'application/annodex', - '.ape': 'audio/x-ape', - '.arj': 'application/x-arj', - '.arw': 'image/x-sony-arw', - '.as': 'application/x-applix-spreadsheet', - '.asc': 'text/plain', - '.asf': 'video/x-ms-asf', - '.asp': 'application/x-asp', - '.ass': 'text/x-ssa', - '.asx': 'audio/x-ms-asx', - '.atom': 'application/atom+xml', - '.au': 'audio/basic', - '.avi': 'video/x-msvideo', - '.aw': 'application/x-applix-word', - '.awb': 'audio/amr-wb', - '.awk': 'application/x-awk', - '.axa': 'audio/annodex', - '.axv': 'video/annodex', - '.bak': 'application/x-trash', - '.bcpio': 'application/x-bcpio', - '.bdf': 'application/x-font-bdf', - '.bib': 'text/x-bibtex', - '.bin': 'application/octet-stream', - '.blend': 'application/x-blender', - '.blender': 'application/x-blender', - '.bmp': 'image/bmp', - '.bz': 'application/x-bzip', - '.bz2': 'application/x-bzip', - '.c': 'text/x-csrc', - '.c++': 'text/x-c++src', - '.cab': 'application/vnd.ms-cab-compressed', - '.cb7': 'application/x-cb7', - '.cbr': 'application/x-cbr', - '.cbt': 'application/x-cbt', - '.cbz': 'application/x-cbz', - '.cc': 'text/x-c++src', - '.cdf': 'application/x-netcdf', - '.cdr': 'application/vnd.corel-draw', - '.cer': 'application/x-x509-ca-cert', - '.cert': 'application/x-x509-ca-cert', - '.cgm': 'image/cgm', - '.chm': 'application/x-chm', - '.chrt': 'application/x-kchart', - '.class': 'application/x-java', - '.cls': 'text/x-tex', - '.cmake': 'text/x-cmake', - '.cpio': 'application/x-cpio', - '.cpio.gz': 'application/x-cpio-compressed', - '.cpp': 'text/x-c++src', - '.cr2': 'image/x-canon-cr2', - '.crt': 'application/x-x509-ca-cert', - '.crw': 'image/x-canon-crw', - '.cs': 'text/x-csharp', - '.csh': 'application/x-csh', - '.css': 'text/css', - '.cssl': 'text/css', - '.csv': 'text/csv', - '.cue': 'application/x-cue', - '.cur': 'image/x-win-bitmap', - '.cxx': 'text/x-c++src', - '.d': 'text/x-dsrc', - '.dar': 'application/x-dar', - '.dbf': 'application/x-dbf', - '.dc': 'application/x-dc-rom', - '.dcl': 'text/x-dcl', - '.dcm': 'application/dicom', - '.dcr': 'image/x-kodak-dcr', - '.dds': 'image/x-dds', - '.deb': 'application/x-deb', - '.der': 'application/x-x509-ca-cert', - '.desktop': 'application/x-desktop', - '.dia': 'application/x-dia-diagram', - '.diff': 'text/x-patch', - '.divx': 'video/x-msvideo', - '.djv': 'image/vnd.djvu', - '.djvu': 'image/vnd.djvu', - '.dng': 'image/x-adobe-dng', - '.doc': 'application/msword', - '.docbook': 'application/docbook+xml', - '.docm': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - '.dot': 'text/vnd.graphviz', - '.dsl': 'text/x-dsl', - '.dtd': 'application/xml-dtd', - '.dtx': 'text/x-tex', - '.dv': 'video/dv', - '.dvi': 'application/x-dvi', - '.dvi.bz2': 'application/x-bzdvi', - '.dvi.gz': 'application/x-gzdvi', - '.dwg': 'image/vnd.dwg', - '.dxf': 'image/vnd.dxf', - '.e': 'text/x-eiffel', - '.egon': 'application/x-egon', - '.eif': 'text/x-eiffel', - '.el': 'text/x-emacs-lisp', - '.emf': 'image/x-emf', - '.emp': 'application/vnd.emusic-emusic_package', - '.ent': 'application/xml-external-parsed-entity', - '.eps': 'image/x-eps', - '.eps.bz2': 'image/x-bzeps', - '.eps.gz': 'image/x-gzeps', - '.epsf': 'image/x-eps', - '.epsf.bz2': 'image/x-bzeps', - '.epsf.gz': 'image/x-gzeps', - '.epsi': 'image/x-eps', - '.epsi.bz2': 'image/x-bzeps', - '.epsi.gz': 'image/x-gzeps', - '.epub': 'application/epub+zip', - '.erl': 'text/x-erlang', - '.es': 'application/ecmascript', - '.etheme': 'application/x-e-theme', - '.etx': 'text/x-setext', - '.exe': 'application/x-ms-dos-executable', - '.exr': 'image/x-exr', - '.ez': 'application/andrew-inset', - '.f': 'text/x-fortran', - '.f90': 'text/x-fortran', - '.f95': 'text/x-fortran', - '.fb2': 'application/x-fictionbook+xml', - '.fig': 'image/x-xfig', - '.fits': 'image/fits', - '.fl': 'application/x-fluid', - '.flac': 'audio/x-flac', - '.flc': 'video/x-flic', - '.fli': 'video/x-flic', - '.flv': 'video/x-flv', - '.flw': 'application/x-kivio', - '.fo': 'text/x-xslfo', - '.for': 'text/x-fortran', - '.g3': 'image/fax-g3', - '.gb': 'application/x-gameboy-rom', - '.gba': 'application/x-gba-rom', - '.gcrd': 'text/directory', - '.ged': 'application/x-gedcom', - '.gedcom': 'application/x-gedcom', - '.gen': 'application/x-genesis-rom', - '.gf': 'application/x-tex-gf', - '.gg': 'application/x-sms-rom', - '.gif': 'image/gif', - '.glade': 'application/x-glade', - '.gmo': 'application/x-gettext-translation', - '.gnc': 'application/x-gnucash', - '.gnd': 'application/gnunet-directory', - '.gnucash': 'application/x-gnucash', - '.gnumeric': 'application/x-gnumeric', - '.gnuplot': 'application/x-gnuplot', - '.gp': 'application/x-gnuplot', - '.gpg': 'application/pgp-encrypted', - '.gplt': 'application/x-gnuplot', - '.gra': 'application/x-graphite', - '.gsf': 'application/x-font-type1', - '.gsm': 'audio/x-gsm', - '.gtar': 'application/x-tar', - '.gv': 'text/vnd.graphviz', - '.gvp': 'text/x-google-video-pointer', - '.gz': 'application/x-gzip', - '.h': 'text/x-chdr', - '.h++': 'text/x-c++hdr', - '.hdf': 'application/x-hdf', - '.hh': 'text/x-c++hdr', - '.hp': 'text/x-c++hdr', - '.hpgl': 'application/vnd.hp-hpgl', - '.hpp': 'text/x-c++hdr', - '.hs': 'text/x-haskell', - '.htm': 'text/html', - '.html': 'text/html', - '.hwp': 'application/x-hwp', - '.hwt': 'application/x-hwt', - '.hxx': 'text/x-c++hdr', - '.ica': 'application/x-ica', - '.icb': 'image/x-tga', - '.icns': 'image/x-icns', - '.ico': 'image/vnd.microsoft.icon', - '.ics': 'text/calendar', - '.idl': 'text/x-idl', - '.ief': 'image/ief', - '.iff': 'image/x-iff', - '.ilbm': 'image/x-ilbm', - '.ime': 'text/x-imelody', - '.imy': 'text/x-imelody', - '.ins': 'text/x-tex', - '.iptables': 'text/x-iptables', - '.iso': 'application/x-cd-image', - '.iso9660': 'application/x-cd-image', - '.it': 'audio/x-it', - '.j2k': 'image/jp2', - '.jad': 'text/vnd.sun.j2me.app-descriptor', - '.jar': 'application/x-java-archive', - '.java': 'text/x-java', - '.jng': 'image/x-jng', - '.jnlp': 'application/x-java-jnlp-file', - '.jp2': 'image/jp2', - '.jpc': 'image/jp2', - '.jpe': 'image/jpeg', - '.jpeg': 'image/jpeg', - '.jpf': 'image/jp2', - '.jpg': 'image/jpeg', - '.jpr': 'application/x-jbuilder-project', - '.jpx': 'image/jp2', - '.js': 'application/javascript', - '.json': 'application/json', - '.jsonp': 'application/jsonp', - '.k25': 'image/x-kodak-k25', - '.kar': 'audio/midi', - '.karbon': 'application/x-karbon', - '.kdc': 'image/x-kodak-kdc', - '.kdelnk': 'application/x-desktop', - '.kexi': 'application/x-kexiproject-sqlite3', - '.kexic': 'application/x-kexi-connectiondata', - '.kexis': 'application/x-kexiproject-shortcut', - '.kfo': 'application/x-kformula', - '.kil': 'application/x-killustrator', - '.kino': 'application/smil', - '.kml': 'application/vnd.google-earth.kml+xml', - '.kmz': 'application/vnd.google-earth.kmz', - '.kon': 'application/x-kontour', - '.kpm': 'application/x-kpovmodeler', - '.kpr': 'application/x-kpresenter', - '.kpt': 'application/x-kpresenter', - '.kra': 'application/x-krita', - '.ksp': 'application/x-kspread', - '.kud': 'application/x-kugar', - '.kwd': 'application/x-kword', - '.kwt': 'application/x-kword', - '.la': 'application/x-shared-library-la', - '.latex': 'text/x-tex', - '.ldif': 'text/x-ldif', - '.lha': 'application/x-lha', - '.lhs': 'text/x-literate-haskell', - '.lhz': 'application/x-lhz', - '.log': 'text/x-log', - '.ltx': 'text/x-tex', - '.lua': 'text/x-lua', - '.lwo': 'image/x-lwo', - '.lwob': 'image/x-lwo', - '.lws': 'image/x-lws', - '.ly': 'text/x-lilypond', - '.lyx': 'application/x-lyx', - '.lz': 'application/x-lzip', - '.lzh': 'application/x-lha', - '.lzma': 'application/x-lzma', - '.lzo': 'application/x-lzop', - '.m': 'text/x-matlab', - '.m15': 'audio/x-mod', - '.m2t': 'video/mpeg', - '.m3u': 'audio/x-mpegurl', - '.m3u8': 'audio/x-mpegurl', - '.m4': 'application/x-m4', - '.m4a': 'audio/mp4', - '.m4b': 'audio/x-m4b', - '.m4v': 'video/mp4', - '.mab': 'application/x-markaby', - '.man': 'application/x-troff-man', - '.mbox': 'application/mbox', - '.md': 'application/x-genesis-rom', - '.mdb': 'application/vnd.ms-access', - '.mdi': 'image/vnd.ms-modi', - '.me': 'text/x-troff-me', - '.med': 'audio/x-mod', - '.metalink': 'application/metalink+xml', - '.mgp': 'application/x-magicpoint', - '.mid': 'audio/midi', - '.midi': 'audio/midi', - '.mif': 'application/x-mif', - '.minipsf': 'audio/x-minipsf', - '.mka': 'audio/x-matroska', - '.mkv': 'video/x-matroska', - '.ml': 'text/x-ocaml', - '.mli': 'text/x-ocaml', - '.mm': 'text/x-troff-mm', - '.mmf': 'application/x-smaf', - '.mml': 'text/mathml', - '.mng': 'video/x-mng', - '.mo': 'application/x-gettext-translation', - '.mo3': 'audio/x-mo3', - '.moc': 'text/x-moc', - '.mod': 'audio/x-mod', - '.mof': 'text/x-mof', - '.moov': 'video/quicktime', - '.mov': 'video/quicktime', - '.movie': 'video/x-sgi-movie', - '.mp+': 'audio/x-musepack', - '.mp2': 'video/mpeg', - '.mp3': 'audio/mpeg', - '.mp4': 'video/mp4', - '.mpc': 'audio/x-musepack', - '.mpe': 'video/mpeg', - '.mpeg': 'video/mpeg', - '.mpg': 'video/mpeg', - '.mpga': 'audio/mpeg', - '.mpp': 'audio/x-musepack', - '.mrl': 'text/x-mrml', - '.mrml': 'text/x-mrml', - '.mrw': 'image/x-minolta-mrw', - '.ms': 'text/x-troff-ms', - '.msi': 'application/x-msi', - '.msod': 'image/x-msod', - '.msx': 'application/x-msx-rom', - '.mtm': 'audio/x-mod', - '.mup': 'text/x-mup', - '.mxf': 'application/mxf', - '.n64': 'application/x-n64-rom', - '.nb': 'application/mathematica', - '.nc': 'application/x-netcdf', - '.nds': 'application/x-nintendo-ds-rom', - '.nef': 'image/x-nikon-nef', - '.nes': 'application/x-nes-rom', - '.nfo': 'text/x-nfo', - '.not': 'text/x-mup', - '.nsc': 'application/x-netshow-channel', - '.nsv': 'video/x-nsv', - '.o': 'application/x-object', - '.obj': 'application/x-tgif', - '.ocl': 'text/x-ocl', - '.oda': 'application/oda', - '.odb': 'application/vnd.oasis.opendocument.database', - '.odc': 'application/vnd.oasis.opendocument.chart', - '.odf': 'application/vnd.oasis.opendocument.formula', - '.odg': 'application/vnd.oasis.opendocument.graphics', - '.odi': 'application/vnd.oasis.opendocument.image', - '.odm': 'application/vnd.oasis.opendocument.text-master', - '.odp': 'application/vnd.oasis.opendocument.presentation', - '.ods': 'application/vnd.oasis.opendocument.spreadsheet', - '.odt': 'application/vnd.oasis.opendocument.text', - '.oga': 'audio/ogg', - '.ogg': 'video/x-theora+ogg', - '.ogm': 'video/x-ogm+ogg', - '.ogv': 'video/ogg', - '.ogx': 'application/ogg', - '.old': 'application/x-trash', - '.oleo': 'application/x-oleo', - '.opml': 'text/x-opml+xml', - '.ora': 'image/openraster', - '.orf': 'image/x-olympus-orf', - '.otc': 'application/vnd.oasis.opendocument.chart-template', - '.otf': 'application/x-font-otf', - '.otg': 'application/vnd.oasis.opendocument.graphics-template', - '.oth': 'application/vnd.oasis.opendocument.text-web', - '.otp': 'application/vnd.oasis.opendocument.presentation-template', - '.ots': 'application/vnd.oasis.opendocument.spreadsheet-template', - '.ott': 'application/vnd.oasis.opendocument.text-template', - '.owl': 'application/rdf+xml', - '.oxt': 'application/vnd.openofficeorg.extension', - '.p': 'text/x-pascal', - '.p10': 'application/pkcs10', - '.p12': 'application/x-pkcs12', - '.p7b': 'application/x-pkcs7-certificates', - '.p7s': 'application/pkcs7-signature', - '.pack': 'application/x-java-pack200', - '.pak': 'application/x-pak', - '.par2': 'application/x-par2', - '.pas': 'text/x-pascal', - '.patch': 'text/x-patch', - '.pbm': 'image/x-portable-bitmap', - '.pcd': 'image/x-photo-cd', - '.pcf': 'application/x-cisco-vpn-settings', - '.pcf.gz': 'application/x-font-pcf', - '.pcf.z': 'application/x-font-pcf', - '.pcl': 'application/vnd.hp-pcl', - '.pcx': 'image/x-pcx', - '.pdb': 'chemical/x-pdb', - '.pdc': 'application/x-aportisdoc', - '.pdf': 'application/pdf', - '.pdf.bz2': 'application/x-bzpdf', - '.pdf.gz': 'application/x-gzpdf', - '.pef': 'image/x-pentax-pef', - '.pem': 'application/x-x509-ca-cert', - '.perl': 'application/x-perl', - '.pfa': 'application/x-font-type1', - '.pfb': 'application/x-font-type1', - '.pfx': 'application/x-pkcs12', - '.pgm': 'image/x-portable-graymap', - '.pgn': 'application/x-chess-pgn', - '.pgp': 'application/pgp-encrypted', - '.php': 'application/x-php', - '.php3': 'application/x-php', - '.php4': 'application/x-php', - '.pict': 'image/x-pict', - '.pict1': 'image/x-pict', - '.pict2': 'image/x-pict', - '.pickle': 'application/python-pickle', - '.pk': 'application/x-tex-pk', - '.pkipath': 'application/pkix-pkipath', - '.pkr': 'application/pgp-keys', - '.pl': 'application/x-perl', - '.pla': 'audio/x-iriver-pla', - '.pln': 'application/x-planperfect', - '.pls': 'audio/x-scpls', - '.pm': 'application/x-perl', - '.png': 'image/png', - '.pnm': 'image/x-portable-anymap', - '.pntg': 'image/x-macpaint', - '.po': 'text/x-gettext-translation', - '.por': 'application/x-spss-por', - '.pot': 'text/x-gettext-translation-template', - '.ppm': 'image/x-portable-pixmap', - '.pps': 'application/vnd.ms-powerpoint', - '.ppt': 'application/vnd.ms-powerpoint', - '.pptm': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - '.ppz': 'application/vnd.ms-powerpoint', - '.prc': 'application/x-palm-database', - '.ps': 'application/postscript', - '.ps.bz2': 'application/x-bzpostscript', - '.ps.gz': 'application/x-gzpostscript', - '.psd': 'image/vnd.adobe.photoshop', - '.psf': 'audio/x-psf', - '.psf.gz': 'application/x-gz-font-linux-psf', - '.psflib': 'audio/x-psflib', - '.psid': 'audio/prs.sid', - '.psw': 'application/x-pocket-word', - '.pw': 'application/x-pw', - '.py': 'text/x-python', - '.pyc': 'application/x-python-bytecode', - '.pyo': 'application/x-python-bytecode', - '.qif': 'image/x-quicktime', - '.qt': 'video/quicktime', - '.qtif': 'image/x-quicktime', - '.qtl': 'application/x-quicktime-media-link', - '.qtvr': 'video/quicktime', - '.ra': 'audio/vnd.rn-realaudio', - '.raf': 'image/x-fuji-raf', - '.ram': 'application/ram', - '.rar': 'application/x-rar', - '.ras': 'image/x-cmu-raster', - '.raw': 'image/x-panasonic-raw', - '.rax': 'audio/vnd.rn-realaudio', - '.rb': 'application/x-ruby', - '.rdf': 'application/rdf+xml', - '.rdfs': 'application/rdf+xml', - '.reg': 'text/x-ms-regedit', - '.rej': 'application/x-reject', - '.rgb': 'image/x-rgb', - '.rle': 'image/rle', - '.rm': 'application/vnd.rn-realmedia', - '.rmj': 'application/vnd.rn-realmedia', - '.rmm': 'application/vnd.rn-realmedia', - '.rms': 'application/vnd.rn-realmedia', - '.rmvb': 'application/vnd.rn-realmedia', - '.rmx': 'application/vnd.rn-realmedia', - '.roff': 'text/troff', - '.rp': 'image/vnd.rn-realpix', - '.rpm': 'application/x-rpm', - '.rss': 'application/rss+xml', - '.rt': 'text/vnd.rn-realtext', - '.rtf': 'application/rtf', - '.rtx': 'text/richtext', - '.rv': 'video/vnd.rn-realvideo', - '.rvx': 'video/vnd.rn-realvideo', - '.s3m': 'audio/x-s3m', - '.sam': 'application/x-amipro', - '.sami': 'application/x-sami', - '.sav': 'application/x-spss-sav', - '.scm': 'text/x-scheme', - '.sda': 'application/vnd.stardivision.draw', - '.sdc': 'application/vnd.stardivision.calc', - '.sdd': 'application/vnd.stardivision.impress', - '.sdp': 'application/sdp', - '.sds': 'application/vnd.stardivision.chart', - '.sdw': 'application/vnd.stardivision.writer', - '.sgf': 'application/x-go-sgf', - '.sgi': 'image/x-sgi', - '.sgl': 'application/vnd.stardivision.writer', - '.sgm': 'text/sgml', - '.sgml': 'text/sgml', - '.sh': 'application/x-shellscript', - '.shar': 'application/x-shar', - '.shn': 'application/x-shorten', - '.siag': 'application/x-siag', - '.sid': 'audio/prs.sid', - '.sik': 'application/x-trash', - '.sis': 'application/vnd.symbian.install', - '.sisx': 'x-epoc/x-sisx-app', - '.sit': 'application/x-stuffit', - '.siv': 'application/sieve', - '.sk': 'image/x-skencil', - '.sk1': 'image/x-skencil', - '.skr': 'application/pgp-keys', - '.slk': 'text/spreadsheet', - '.smaf': 'application/x-smaf', - '.smc': 'application/x-snes-rom', - '.smd': 'application/vnd.stardivision.mail', - '.smf': 'application/vnd.stardivision.math', - '.smi': 'application/x-sami', - '.smil': 'application/smil', - '.sml': 'application/smil', - '.sms': 'application/x-sms-rom', - '.snd': 'audio/basic', - '.so': 'application/x-sharedlib', - '.spc': 'application/x-pkcs7-certificates', - '.spd': 'application/x-font-speedo', - '.spec': 'text/x-rpm-spec', - '.spl': 'application/x-shockwave-flash', - '.spx': 'audio/x-speex', - '.sql': 'text/x-sql', - '.sr2': 'image/x-sony-sr2', - '.src': 'application/x-wais-source', - '.srf': 'image/x-sony-srf', - '.srt': 'application/x-subrip', - '.ssa': 'text/x-ssa', - '.stc': 'application/vnd.sun.xml.calc.template', - '.std': 'application/vnd.sun.xml.draw.template', - '.sti': 'application/vnd.sun.xml.impress.template', - '.stm': 'audio/x-stm', - '.stw': 'application/vnd.sun.xml.writer.template', - '.sty': 'text/x-tex', - '.sub': 'text/x-subviewer', - '.sun': 'image/x-sun-raster', - '.sv4cpio': 'application/x-sv4cpio', - '.sv4crc': 'application/x-sv4crc', - '.svg': 'image/svg+xml', - '.svgz': 'image/svg+xml-compressed', - '.swf': 'application/x-shockwave-flash', - '.sxc': 'application/vnd.sun.xml.calc', - '.sxd': 'application/vnd.sun.xml.draw', - '.sxg': 'application/vnd.sun.xml.writer.global', - '.sxi': 'application/vnd.sun.xml.impress', - '.sxm': 'application/vnd.sun.xml.math', - '.sxw': 'application/vnd.sun.xml.writer', - '.sylk': 'text/spreadsheet', - '.t': 'text/troff', - '.t2t': 'text/x-txt2tags', - '.tar': 'application/x-tar', - '.tar.bz': 'application/x-bzip-compressed-tar', - '.tar.bz2': 'application/x-bzip-compressed-tar', - '.tar.gz': 'application/x-compressed-tar', - '.tar.lzma': 'application/x-lzma-compressed-tar', - '.tar.lzo': 'application/x-tzo', - '.tar.xz': 'application/x-xz-compressed-tar', - '.tar.z': 'application/x-tarz', - '.tbz': 'application/x-bzip-compressed-tar', - '.tbz2': 'application/x-bzip-compressed-tar', - '.tcl': 'text/x-tcl', - '.tex': 'text/x-tex', - '.texi': 'text/x-texinfo', - '.texinfo': 'text/x-texinfo', - '.tga': 'image/x-tga', - '.tgz': 'application/x-compressed-tar', - '.theme': 'application/x-theme', - '.themepack': 'application/x-windows-themepack', - '.tif': 'image/tiff', - '.tiff': 'image/tiff', - '.tk': 'text/x-tcl', - '.tlz': 'application/x-lzma-compressed-tar', - '.tnef': 'application/vnd.ms-tnef', - '.tnf': 'application/vnd.ms-tnef', - '.toc': 'application/x-cdrdao-toc', - '.torrent': 'application/x-bittorrent', - '.tpic': 'image/x-tga', - '.tr': 'text/troff', - '.ts': 'application/x-linguist', - '.tsv': 'text/tab-separated-values', - '.tta': 'audio/x-tta', - '.ttc': 'application/x-font-ttf', - '.ttf': 'application/x-font-ttf', - '.ttx': 'application/x-font-ttx', - '.txt': 'text/plain', - '.txz': 'application/x-xz-compressed-tar', - '.tzo': 'application/x-tzo', - '.ufraw': 'application/x-ufraw', - '.ui': 'application/x-designer', - '.uil': 'text/x-uil', - '.ult': 'audio/x-mod', - '.uni': 'audio/x-mod', - '.uri': 'text/x-uri', - '.url': 'text/x-uri', - '.ustar': 'application/x-ustar', - '.vala': 'text/x-vala', - '.vapi': 'text/x-vala', - '.vcf': 'text/directory', - '.vcs': 'text/calendar', - '.vct': 'text/directory', - '.vda': 'image/x-tga', - '.vhd': 'text/x-vhdl', - '.vhdl': 'text/x-vhdl', - '.viv': 'video/vivo', - '.vivo': 'video/vivo', - '.vlc': 'audio/x-mpegurl', - '.vob': 'video/mpeg', - '.voc': 'audio/x-voc', - '.vor': 'application/vnd.stardivision.writer', - '.vst': 'image/x-tga', - '.wav': 'audio/x-wav', - '.wax': 'audio/x-ms-asx', - '.wb1': 'application/x-quattropro', - '.wb2': 'application/x-quattropro', - '.wb3': 'application/x-quattropro', - '.wbmp': 'image/vnd.wap.wbmp', - '.wcm': 'application/vnd.ms-works', - '.wdb': 'application/vnd.ms-works', - '.webm': 'video/webm', - '.wk1': 'application/vnd.lotus-1-2-3', - '.wk3': 'application/vnd.lotus-1-2-3', - '.wk4': 'application/vnd.lotus-1-2-3', - '.wks': 'application/vnd.ms-works', - '.wma': 'audio/x-ms-wma', - '.wmf': 'image/x-wmf', - '.wml': 'text/vnd.wap.wml', - '.wmls': 'text/vnd.wap.wmlscript', - '.wmv': 'video/x-ms-wmv', - '.wmx': 'audio/x-ms-asx', - '.wp': 'application/vnd.wordperfect', - '.wp4': 'application/vnd.wordperfect', - '.wp5': 'application/vnd.wordperfect', - '.wp6': 'application/vnd.wordperfect', - '.wpd': 'application/vnd.wordperfect', - '.wpg': 'application/x-wpg', - '.wpl': 'application/vnd.ms-wpl', - '.wpp': 'application/vnd.wordperfect', - '.wps': 'application/vnd.ms-works', - '.wri': 'application/x-mswrite', - '.wrl': 'model/vrml', - '.wv': 'audio/x-wavpack', - '.wvc': 'audio/x-wavpack-correction', - '.wvp': 'audio/x-wavpack', - '.wvx': 'audio/x-ms-asx', - '.x3f': 'image/x-sigma-x3f', - '.xac': 'application/x-gnucash', - '.xbel': 'application/x-xbel', - '.xbl': 'application/xml', - '.xbm': 'image/x-xbitmap', - '.xcf': 'image/x-xcf', - '.xcf.bz2': 'image/x-compressed-xcf', - '.xcf.gz': 'image/x-compressed-xcf', - '.xhtml': 'application/xhtml+xml', - '.xi': 'audio/x-xi', - '.xla': 'application/vnd.ms-excel', - '.xlc': 'application/vnd.ms-excel', - '.xld': 'application/vnd.ms-excel', - '.xlf': 'application/x-xliff', - '.xliff': 'application/x-xliff', - '.xll': 'application/vnd.ms-excel', - '.xlm': 'application/vnd.ms-excel', - '.xls': 'application/vnd.ms-excel', - '.xlsm': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - '.xlt': 'application/vnd.ms-excel', - '.xlw': 'application/vnd.ms-excel', - '.xm': 'audio/x-xm', - '.xmf': 'audio/x-xmf', - '.xmi': 'text/x-xmi', - '.xml': 'application/xml', - '.xpm': 'image/x-xpixmap', - '.xps': 'application/vnd.ms-xpsdocument', - '.xsl': 'application/xml', - '.xslfo': 'text/x-xslfo', - '.xslt': 'application/xml', - '.xspf': 'application/xspf+xml', - '.xul': 'application/vnd.mozilla.xul+xml', - '.xwd': 'image/x-xwindowdump', - '.xyz': 'chemical/x-pdb', - '.xz': 'application/x-xz', - '.w2p': 'application/w2p', - '.z': 'application/x-compress', - '.zabw': 'application/x-abiword', - '.zip': 'application/zip', - '.zoo': 'application/x-zoo', + ".load": "text/html", + ".123": "application/vnd.lotus-1-2-3", + ".3ds": "image/x-3ds", + ".3g2": "video/3gpp", + ".3ga": "video/3gpp", + ".3gp": "video/3gpp", + ".3gpp": "video/3gpp", + ".602": "application/x-t602", + ".669": "audio/x-mod", + ".7z": "application/x-7z-compressed", + ".a": "application/x-archive", + ".aac": "audio/mp4", + ".abw": "application/x-abiword", + ".abw.crashed": "application/x-abiword", + ".abw.gz": "application/x-abiword", + ".ac3": "audio/ac3", + ".ace": "application/x-ace", + ".adb": "text/x-adasrc", + ".ads": "text/x-adasrc", + ".afm": "application/x-font-afm", + ".ag": "image/x-applix-graphics", + ".ai": "application/illustrator", + ".aif": "audio/x-aiff", + ".aifc": "audio/x-aiff", + ".aiff": "audio/x-aiff", + ".al": "application/x-perl", + ".alz": "application/x-alz", + ".amr": "audio/amr", + ".ani": "application/x-navi-animation", + ".anim[1-9j]": "video/x-anim", + ".anx": "application/annodex", + ".ape": "audio/x-ape", + ".arj": "application/x-arj", + ".arw": "image/x-sony-arw", + ".as": "application/x-applix-spreadsheet", + ".asc": "text/plain", + ".asf": "video/x-ms-asf", + ".asp": "application/x-asp", + ".ass": "text/x-ssa", + ".asx": "audio/x-ms-asx", + ".atom": "application/atom+xml", + ".au": "audio/basic", + ".avi": "video/x-msvideo", + ".aw": "application/x-applix-word", + ".awb": "audio/amr-wb", + ".awk": "application/x-awk", + ".axa": "audio/annodex", + ".axv": "video/annodex", + ".bak": "application/x-trash", + ".bcpio": "application/x-bcpio", + ".bdf": "application/x-font-bdf", + ".bib": "text/x-bibtex", + ".bin": "application/octet-stream", + ".blend": "application/x-blender", + ".blender": "application/x-blender", + ".bmp": "image/bmp", + ".bz": "application/x-bzip", + ".bz2": "application/x-bzip", + ".c": "text/x-csrc", + ".c++": "text/x-c++src", + ".cab": "application/vnd.ms-cab-compressed", + ".cb7": "application/x-cb7", + ".cbr": "application/x-cbr", + ".cbt": "application/x-cbt", + ".cbz": "application/x-cbz", + ".cc": "text/x-c++src", + ".cdf": "application/x-netcdf", + ".cdr": "application/vnd.corel-draw", + ".cer": "application/x-x509-ca-cert", + ".cert": "application/x-x509-ca-cert", + ".cgm": "image/cgm", + ".chm": "application/x-chm", + ".chrt": "application/x-kchart", + ".class": "application/x-java", + ".cls": "text/x-tex", + ".cmake": "text/x-cmake", + ".cpio": "application/x-cpio", + ".cpio.gz": "application/x-cpio-compressed", + ".cpp": "text/x-c++src", + ".cr2": "image/x-canon-cr2", + ".crt": "application/x-x509-ca-cert", + ".crw": "image/x-canon-crw", + ".cs": "text/x-csharp", + ".csh": "application/x-csh", + ".css": "text/css", + ".cssl": "text/css", + ".csv": "text/csv", + ".cue": "application/x-cue", + ".cur": "image/x-win-bitmap", + ".cxx": "text/x-c++src", + ".d": "text/x-dsrc", + ".dar": "application/x-dar", + ".dbf": "application/x-dbf", + ".dc": "application/x-dc-rom", + ".dcl": "text/x-dcl", + ".dcm": "application/dicom", + ".dcr": "image/x-kodak-dcr", + ".dds": "image/x-dds", + ".deb": "application/x-deb", + ".der": "application/x-x509-ca-cert", + ".desktop": "application/x-desktop", + ".dia": "application/x-dia-diagram", + ".diff": "text/x-patch", + ".divx": "video/x-msvideo", + ".djv": "image/vnd.djvu", + ".djvu": "image/vnd.djvu", + ".dng": "image/x-adobe-dng", + ".doc": "application/msword", + ".docbook": "application/docbook+xml", + ".docm": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ".dot": "text/vnd.graphviz", + ".dsl": "text/x-dsl", + ".dtd": "application/xml-dtd", + ".dtx": "text/x-tex", + ".dv": "video/dv", + ".dvi": "application/x-dvi", + ".dvi.bz2": "application/x-bzdvi", + ".dvi.gz": "application/x-gzdvi", + ".dwg": "image/vnd.dwg", + ".dxf": "image/vnd.dxf", + ".e": "text/x-eiffel", + ".egon": "application/x-egon", + ".eif": "text/x-eiffel", + ".el": "text/x-emacs-lisp", + ".emf": "image/x-emf", + ".emp": "application/vnd.emusic-emusic_package", + ".ent": "application/xml-external-parsed-entity", + ".eps": "image/x-eps", + ".eps.bz2": "image/x-bzeps", + ".eps.gz": "image/x-gzeps", + ".epsf": "image/x-eps", + ".epsf.bz2": "image/x-bzeps", + ".epsf.gz": "image/x-gzeps", + ".epsi": "image/x-eps", + ".epsi.bz2": "image/x-bzeps", + ".epsi.gz": "image/x-gzeps", + ".epub": "application/epub+zip", + ".erl": "text/x-erlang", + ".es": "application/ecmascript", + ".etheme": "application/x-e-theme", + ".etx": "text/x-setext", + ".exe": "application/x-ms-dos-executable", + ".exr": "image/x-exr", + ".ez": "application/andrew-inset", + ".f": "text/x-fortran", + ".f90": "text/x-fortran", + ".f95": "text/x-fortran", + ".fb2": "application/x-fictionbook+xml", + ".fig": "image/x-xfig", + ".fits": "image/fits", + ".fl": "application/x-fluid", + ".flac": "audio/x-flac", + ".flc": "video/x-flic", + ".fli": "video/x-flic", + ".flv": "video/x-flv", + ".flw": "application/x-kivio", + ".fo": "text/x-xslfo", + ".for": "text/x-fortran", + ".g3": "image/fax-g3", + ".gb": "application/x-gameboy-rom", + ".gba": "application/x-gba-rom", + ".gcrd": "text/directory", + ".ged": "application/x-gedcom", + ".gedcom": "application/x-gedcom", + ".gen": "application/x-genesis-rom", + ".gf": "application/x-tex-gf", + ".gg": "application/x-sms-rom", + ".gif": "image/gif", + ".glade": "application/x-glade", + ".gmo": "application/x-gettext-translation", + ".gnc": "application/x-gnucash", + ".gnd": "application/gnunet-directory", + ".gnucash": "application/x-gnucash", + ".gnumeric": "application/x-gnumeric", + ".gnuplot": "application/x-gnuplot", + ".gp": "application/x-gnuplot", + ".gpg": "application/pgp-encrypted", + ".gplt": "application/x-gnuplot", + ".gra": "application/x-graphite", + ".gsf": "application/x-font-type1", + ".gsm": "audio/x-gsm", + ".gtar": "application/x-tar", + ".gv": "text/vnd.graphviz", + ".gvp": "text/x-google-video-pointer", + ".gz": "application/x-gzip", + ".h": "text/x-chdr", + ".h++": "text/x-c++hdr", + ".hdf": "application/x-hdf", + ".hh": "text/x-c++hdr", + ".hp": "text/x-c++hdr", + ".hpgl": "application/vnd.hp-hpgl", + ".hpp": "text/x-c++hdr", + ".hs": "text/x-haskell", + ".htm": "text/html", + ".html": "text/html", + ".hwp": "application/x-hwp", + ".hwt": "application/x-hwt", + ".hxx": "text/x-c++hdr", + ".ica": "application/x-ica", + ".icb": "image/x-tga", + ".icns": "image/x-icns", + ".ico": "image/vnd.microsoft.icon", + ".ics": "text/calendar", + ".idl": "text/x-idl", + ".ief": "image/ief", + ".iff": "image/x-iff", + ".ilbm": "image/x-ilbm", + ".ime": "text/x-imelody", + ".imy": "text/x-imelody", + ".ins": "text/x-tex", + ".iptables": "text/x-iptables", + ".iso": "application/x-cd-image", + ".iso9660": "application/x-cd-image", + ".it": "audio/x-it", + ".j2k": "image/jp2", + ".jad": "text/vnd.sun.j2me.app-descriptor", + ".jar": "application/x-java-archive", + ".java": "text/x-java", + ".jng": "image/x-jng", + ".jnlp": "application/x-java-jnlp-file", + ".jp2": "image/jp2", + ".jpc": "image/jp2", + ".jpe": "image/jpeg", + ".jpeg": "image/jpeg", + ".jpf": "image/jp2", + ".jpg": "image/jpeg", + ".jpr": "application/x-jbuilder-project", + ".jpx": "image/jp2", + ".js": "application/javascript", + ".json": "application/json", + ".jsonp": "application/jsonp", + ".k25": "image/x-kodak-k25", + ".kar": "audio/midi", + ".karbon": "application/x-karbon", + ".kdc": "image/x-kodak-kdc", + ".kdelnk": "application/x-desktop", + ".kexi": "application/x-kexiproject-sqlite3", + ".kexic": "application/x-kexi-connectiondata", + ".kexis": "application/x-kexiproject-shortcut", + ".kfo": "application/x-kformula", + ".kil": "application/x-killustrator", + ".kino": "application/smil", + ".kml": "application/vnd.google-earth.kml+xml", + ".kmz": "application/vnd.google-earth.kmz", + ".kon": "application/x-kontour", + ".kpm": "application/x-kpovmodeler", + ".kpr": "application/x-kpresenter", + ".kpt": "application/x-kpresenter", + ".kra": "application/x-krita", + ".ksp": "application/x-kspread", + ".kud": "application/x-kugar", + ".kwd": "application/x-kword", + ".kwt": "application/x-kword", + ".la": "application/x-shared-library-la", + ".latex": "text/x-tex", + ".ldif": "text/x-ldif", + ".lha": "application/x-lha", + ".lhs": "text/x-literate-haskell", + ".lhz": "application/x-lhz", + ".log": "text/x-log", + ".ltx": "text/x-tex", + ".lua": "text/x-lua", + ".lwo": "image/x-lwo", + ".lwob": "image/x-lwo", + ".lws": "image/x-lws", + ".ly": "text/x-lilypond", + ".lyx": "application/x-lyx", + ".lz": "application/x-lzip", + ".lzh": "application/x-lha", + ".lzma": "application/x-lzma", + ".lzo": "application/x-lzop", + ".m": "text/x-matlab", + ".m15": "audio/x-mod", + ".m2t": "video/mpeg", + ".m3u": "audio/x-mpegurl", + ".m3u8": "audio/x-mpegurl", + ".m4": "application/x-m4", + ".m4a": "audio/mp4", + ".m4b": "audio/x-m4b", + ".m4v": "video/mp4", + ".mab": "application/x-markaby", + ".man": "application/x-troff-man", + ".mbox": "application/mbox", + ".md": "application/x-genesis-rom", + ".mdb": "application/vnd.ms-access", + ".mdi": "image/vnd.ms-modi", + ".me": "text/x-troff-me", + ".med": "audio/x-mod", + ".metalink": "application/metalink+xml", + ".mgp": "application/x-magicpoint", + ".mid": "audio/midi", + ".midi": "audio/midi", + ".mif": "application/x-mif", + ".minipsf": "audio/x-minipsf", + ".mka": "audio/x-matroska", + ".mkv": "video/x-matroska", + ".ml": "text/x-ocaml", + ".mli": "text/x-ocaml", + ".mm": "text/x-troff-mm", + ".mmf": "application/x-smaf", + ".mml": "text/mathml", + ".mng": "video/x-mng", + ".mo": "application/x-gettext-translation", + ".mo3": "audio/x-mo3", + ".moc": "text/x-moc", + ".mod": "audio/x-mod", + ".mof": "text/x-mof", + ".moov": "video/quicktime", + ".mov": "video/quicktime", + ".movie": "video/x-sgi-movie", + ".mp+": "audio/x-musepack", + ".mp2": "video/mpeg", + ".mp3": "audio/mpeg", + ".mp4": "video/mp4", + ".mpc": "audio/x-musepack", + ".mpe": "video/mpeg", + ".mpeg": "video/mpeg", + ".mpg": "video/mpeg", + ".mpga": "audio/mpeg", + ".mpp": "audio/x-musepack", + ".mrl": "text/x-mrml", + ".mrml": "text/x-mrml", + ".mrw": "image/x-minolta-mrw", + ".ms": "text/x-troff-ms", + ".msi": "application/x-msi", + ".msod": "image/x-msod", + ".msx": "application/x-msx-rom", + ".mtm": "audio/x-mod", + ".mup": "text/x-mup", + ".mxf": "application/mxf", + ".n64": "application/x-n64-rom", + ".nb": "application/mathematica", + ".nc": "application/x-netcdf", + ".nds": "application/x-nintendo-ds-rom", + ".nef": "image/x-nikon-nef", + ".nes": "application/x-nes-rom", + ".nfo": "text/x-nfo", + ".not": "text/x-mup", + ".nsc": "application/x-netshow-channel", + ".nsv": "video/x-nsv", + ".o": "application/x-object", + ".obj": "application/x-tgif", + ".ocl": "text/x-ocl", + ".oda": "application/oda", + ".odb": "application/vnd.oasis.opendocument.database", + ".odc": "application/vnd.oasis.opendocument.chart", + ".odf": "application/vnd.oasis.opendocument.formula", + ".odg": "application/vnd.oasis.opendocument.graphics", + ".odi": "application/vnd.oasis.opendocument.image", + ".odm": "application/vnd.oasis.opendocument.text-master", + ".odp": "application/vnd.oasis.opendocument.presentation", + ".ods": "application/vnd.oasis.opendocument.spreadsheet", + ".odt": "application/vnd.oasis.opendocument.text", + ".oga": "audio/ogg", + ".ogg": "video/x-theora+ogg", + ".ogm": "video/x-ogm+ogg", + ".ogv": "video/ogg", + ".ogx": "application/ogg", + ".old": "application/x-trash", + ".oleo": "application/x-oleo", + ".opml": "text/x-opml+xml", + ".ora": "image/openraster", + ".orf": "image/x-olympus-orf", + ".otc": "application/vnd.oasis.opendocument.chart-template", + ".otf": "application/x-font-otf", + ".otg": "application/vnd.oasis.opendocument.graphics-template", + ".oth": "application/vnd.oasis.opendocument.text-web", + ".otp": "application/vnd.oasis.opendocument.presentation-template", + ".ots": "application/vnd.oasis.opendocument.spreadsheet-template", + ".ott": "application/vnd.oasis.opendocument.text-template", + ".owl": "application/rdf+xml", + ".oxt": "application/vnd.openofficeorg.extension", + ".p": "text/x-pascal", + ".p10": "application/pkcs10", + ".p12": "application/x-pkcs12", + ".p7b": "application/x-pkcs7-certificates", + ".p7s": "application/pkcs7-signature", + ".pack": "application/x-java-pack200", + ".pak": "application/x-pak", + ".par2": "application/x-par2", + ".pas": "text/x-pascal", + ".patch": "text/x-patch", + ".pbm": "image/x-portable-bitmap", + ".pcd": "image/x-photo-cd", + ".pcf": "application/x-cisco-vpn-settings", + ".pcf.gz": "application/x-font-pcf", + ".pcf.z": "application/x-font-pcf", + ".pcl": "application/vnd.hp-pcl", + ".pcx": "image/x-pcx", + ".pdb": "chemical/x-pdb", + ".pdc": "application/x-aportisdoc", + ".pdf": "application/pdf", + ".pdf.bz2": "application/x-bzpdf", + ".pdf.gz": "application/x-gzpdf", + ".pef": "image/x-pentax-pef", + ".pem": "application/x-x509-ca-cert", + ".perl": "application/x-perl", + ".pfa": "application/x-font-type1", + ".pfb": "application/x-font-type1", + ".pfx": "application/x-pkcs12", + ".pgm": "image/x-portable-graymap", + ".pgn": "application/x-chess-pgn", + ".pgp": "application/pgp-encrypted", + ".php": "application/x-php", + ".php3": "application/x-php", + ".php4": "application/x-php", + ".pict": "image/x-pict", + ".pict1": "image/x-pict", + ".pict2": "image/x-pict", + ".pickle": "application/python-pickle", + ".pk": "application/x-tex-pk", + ".pkipath": "application/pkix-pkipath", + ".pkr": "application/pgp-keys", + ".pl": "application/x-perl", + ".pla": "audio/x-iriver-pla", + ".pln": "application/x-planperfect", + ".pls": "audio/x-scpls", + ".pm": "application/x-perl", + ".png": "image/png", + ".pnm": "image/x-portable-anymap", + ".pntg": "image/x-macpaint", + ".po": "text/x-gettext-translation", + ".por": "application/x-spss-por", + ".pot": "text/x-gettext-translation-template", + ".ppm": "image/x-portable-pixmap", + ".pps": "application/vnd.ms-powerpoint", + ".ppt": "application/vnd.ms-powerpoint", + ".pptm": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation", + ".ppz": "application/vnd.ms-powerpoint", + ".prc": "application/x-palm-database", + ".ps": "application/postscript", + ".ps.bz2": "application/x-bzpostscript", + ".ps.gz": "application/x-gzpostscript", + ".psd": "image/vnd.adobe.photoshop", + ".psf": "audio/x-psf", + ".psf.gz": "application/x-gz-font-linux-psf", + ".psflib": "audio/x-psflib", + ".psid": "audio/prs.sid", + ".psw": "application/x-pocket-word", + ".pw": "application/x-pw", + ".py": "text/x-python", + ".pyc": "application/x-python-bytecode", + ".pyo": "application/x-python-bytecode", + ".qif": "image/x-quicktime", + ".qt": "video/quicktime", + ".qtif": "image/x-quicktime", + ".qtl": "application/x-quicktime-media-link", + ".qtvr": "video/quicktime", + ".ra": "audio/vnd.rn-realaudio", + ".raf": "image/x-fuji-raf", + ".ram": "application/ram", + ".rar": "application/x-rar", + ".ras": "image/x-cmu-raster", + ".raw": "image/x-panasonic-raw", + ".rax": "audio/vnd.rn-realaudio", + ".rb": "application/x-ruby", + ".rdf": "application/rdf+xml", + ".rdfs": "application/rdf+xml", + ".reg": "text/x-ms-regedit", + ".rej": "application/x-reject", + ".rgb": "image/x-rgb", + ".rle": "image/rle", + ".rm": "application/vnd.rn-realmedia", + ".rmj": "application/vnd.rn-realmedia", + ".rmm": "application/vnd.rn-realmedia", + ".rms": "application/vnd.rn-realmedia", + ".rmvb": "application/vnd.rn-realmedia", + ".rmx": "application/vnd.rn-realmedia", + ".roff": "text/troff", + ".rp": "image/vnd.rn-realpix", + ".rpm": "application/x-rpm", + ".rss": "application/rss+xml", + ".rt": "text/vnd.rn-realtext", + ".rtf": "application/rtf", + ".rtx": "text/richtext", + ".rv": "video/vnd.rn-realvideo", + ".rvx": "video/vnd.rn-realvideo", + ".s3m": "audio/x-s3m", + ".sam": "application/x-amipro", + ".sami": "application/x-sami", + ".sav": "application/x-spss-sav", + ".scm": "text/x-scheme", + ".sda": "application/vnd.stardivision.draw", + ".sdc": "application/vnd.stardivision.calc", + ".sdd": "application/vnd.stardivision.impress", + ".sdp": "application/sdp", + ".sds": "application/vnd.stardivision.chart", + ".sdw": "application/vnd.stardivision.writer", + ".sgf": "application/x-go-sgf", + ".sgi": "image/x-sgi", + ".sgl": "application/vnd.stardivision.writer", + ".sgm": "text/sgml", + ".sgml": "text/sgml", + ".sh": "application/x-shellscript", + ".shar": "application/x-shar", + ".shn": "application/x-shorten", + ".siag": "application/x-siag", + ".sid": "audio/prs.sid", + ".sik": "application/x-trash", + ".sis": "application/vnd.symbian.install", + ".sisx": "x-epoc/x-sisx-app", + ".sit": "application/x-stuffit", + ".siv": "application/sieve", + ".sk": "image/x-skencil", + ".sk1": "image/x-skencil", + ".skr": "application/pgp-keys", + ".slk": "text/spreadsheet", + ".smaf": "application/x-smaf", + ".smc": "application/x-snes-rom", + ".smd": "application/vnd.stardivision.mail", + ".smf": "application/vnd.stardivision.math", + ".smi": "application/x-sami", + ".smil": "application/smil", + ".sml": "application/smil", + ".sms": "application/x-sms-rom", + ".snd": "audio/basic", + ".so": "application/x-sharedlib", + ".spc": "application/x-pkcs7-certificates", + ".spd": "application/x-font-speedo", + ".spec": "text/x-rpm-spec", + ".spl": "application/x-shockwave-flash", + ".spx": "audio/x-speex", + ".sql": "text/x-sql", + ".sr2": "image/x-sony-sr2", + ".src": "application/x-wais-source", + ".srf": "image/x-sony-srf", + ".srt": "application/x-subrip", + ".ssa": "text/x-ssa", + ".stc": "application/vnd.sun.xml.calc.template", + ".std": "application/vnd.sun.xml.draw.template", + ".sti": "application/vnd.sun.xml.impress.template", + ".stm": "audio/x-stm", + ".stw": "application/vnd.sun.xml.writer.template", + ".sty": "text/x-tex", + ".sub": "text/x-subviewer", + ".sun": "image/x-sun-raster", + ".sv4cpio": "application/x-sv4cpio", + ".sv4crc": "application/x-sv4crc", + ".svg": "image/svg+xml", + ".svgz": "image/svg+xml-compressed", + ".swf": "application/x-shockwave-flash", + ".sxc": "application/vnd.sun.xml.calc", + ".sxd": "application/vnd.sun.xml.draw", + ".sxg": "application/vnd.sun.xml.writer.global", + ".sxi": "application/vnd.sun.xml.impress", + ".sxm": "application/vnd.sun.xml.math", + ".sxw": "application/vnd.sun.xml.writer", + ".sylk": "text/spreadsheet", + ".t": "text/troff", + ".t2t": "text/x-txt2tags", + ".tar": "application/x-tar", + ".tar.bz": "application/x-bzip-compressed-tar", + ".tar.bz2": "application/x-bzip-compressed-tar", + ".tar.gz": "application/x-compressed-tar", + ".tar.lzma": "application/x-lzma-compressed-tar", + ".tar.lzo": "application/x-tzo", + ".tar.xz": "application/x-xz-compressed-tar", + ".tar.z": "application/x-tarz", + ".tbz": "application/x-bzip-compressed-tar", + ".tbz2": "application/x-bzip-compressed-tar", + ".tcl": "text/x-tcl", + ".tex": "text/x-tex", + ".texi": "text/x-texinfo", + ".texinfo": "text/x-texinfo", + ".tga": "image/x-tga", + ".tgz": "application/x-compressed-tar", + ".theme": "application/x-theme", + ".themepack": "application/x-windows-themepack", + ".tif": "image/tiff", + ".tiff": "image/tiff", + ".tk": "text/x-tcl", + ".tlz": "application/x-lzma-compressed-tar", + ".tnef": "application/vnd.ms-tnef", + ".tnf": "application/vnd.ms-tnef", + ".toc": "application/x-cdrdao-toc", + ".torrent": "application/x-bittorrent", + ".tpic": "image/x-tga", + ".tr": "text/troff", + ".ts": "application/x-linguist", + ".tsv": "text/tab-separated-values", + ".tta": "audio/x-tta", + ".ttc": "application/x-font-ttf", + ".ttf": "application/x-font-ttf", + ".ttx": "application/x-font-ttx", + ".txt": "text/plain", + ".txz": "application/x-xz-compressed-tar", + ".tzo": "application/x-tzo", + ".ufraw": "application/x-ufraw", + ".ui": "application/x-designer", + ".uil": "text/x-uil", + ".ult": "audio/x-mod", + ".uni": "audio/x-mod", + ".uri": "text/x-uri", + ".url": "text/x-uri", + ".ustar": "application/x-ustar", + ".vala": "text/x-vala", + ".vapi": "text/x-vala", + ".vcf": "text/directory", + ".vcs": "text/calendar", + ".vct": "text/directory", + ".vda": "image/x-tga", + ".vhd": "text/x-vhdl", + ".vhdl": "text/x-vhdl", + ".viv": "video/vivo", + ".vivo": "video/vivo", + ".vlc": "audio/x-mpegurl", + ".vob": "video/mpeg", + ".voc": "audio/x-voc", + ".vor": "application/vnd.stardivision.writer", + ".vst": "image/x-tga", + ".wav": "audio/x-wav", + ".wax": "audio/x-ms-asx", + ".wb1": "application/x-quattropro", + ".wb2": "application/x-quattropro", + ".wb3": "application/x-quattropro", + ".wbmp": "image/vnd.wap.wbmp", + ".wcm": "application/vnd.ms-works", + ".wdb": "application/vnd.ms-works", + ".webm": "video/webm", + ".wk1": "application/vnd.lotus-1-2-3", + ".wk3": "application/vnd.lotus-1-2-3", + ".wk4": "application/vnd.lotus-1-2-3", + ".wks": "application/vnd.ms-works", + ".wma": "audio/x-ms-wma", + ".wmf": "image/x-wmf", + ".wml": "text/vnd.wap.wml", + ".wmls": "text/vnd.wap.wmlscript", + ".wmv": "video/x-ms-wmv", + ".wmx": "audio/x-ms-asx", + ".wp": "application/vnd.wordperfect", + ".wp4": "application/vnd.wordperfect", + ".wp5": "application/vnd.wordperfect", + ".wp6": "application/vnd.wordperfect", + ".wpd": "application/vnd.wordperfect", + ".wpg": "application/x-wpg", + ".wpl": "application/vnd.ms-wpl", + ".wpp": "application/vnd.wordperfect", + ".wps": "application/vnd.ms-works", + ".wri": "application/x-mswrite", + ".wrl": "model/vrml", + ".wv": "audio/x-wavpack", + ".wvc": "audio/x-wavpack-correction", + ".wvp": "audio/x-wavpack", + ".wvx": "audio/x-ms-asx", + ".x3f": "image/x-sigma-x3f", + ".xac": "application/x-gnucash", + ".xbel": "application/x-xbel", + ".xbl": "application/xml", + ".xbm": "image/x-xbitmap", + ".xcf": "image/x-xcf", + ".xcf.bz2": "image/x-compressed-xcf", + ".xcf.gz": "image/x-compressed-xcf", + ".xhtml": "application/xhtml+xml", + ".xi": "audio/x-xi", + ".xla": "application/vnd.ms-excel", + ".xlc": "application/vnd.ms-excel", + ".xld": "application/vnd.ms-excel", + ".xlf": "application/x-xliff", + ".xliff": "application/x-xliff", + ".xll": "application/vnd.ms-excel", + ".xlm": "application/vnd.ms-excel", + ".xls": "application/vnd.ms-excel", + ".xlsm": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + ".xlt": "application/vnd.ms-excel", + ".xlw": "application/vnd.ms-excel", + ".xm": "audio/x-xm", + ".xmf": "audio/x-xmf", + ".xmi": "text/x-xmi", + ".xml": "application/xml", + ".xpm": "image/x-xpixmap", + ".xps": "application/vnd.ms-xpsdocument", + ".xsl": "application/xml", + ".xslfo": "text/x-xslfo", + ".xslt": "application/xml", + ".xspf": "application/xspf+xml", + ".xul": "application/vnd.mozilla.xul+xml", + ".xwd": "image/x-xwindowdump", + ".xyz": "chemical/x-pdb", + ".xz": "application/x-xz", + ".w2p": "application/w2p", + ".z": "application/x-compress", + ".zabw": "application/x-abiword", + ".zip": "application/zip", + ".zoo": "application/x-zoo", } -def contenttype(filename, default='text/plain'): +def contenttype(filename, default="text/plain"): """ Returns the Content-Type string matching extension of the given filename. """ - i = filename.rfind('.') + i = filename.rfind(".") if i >= 0: default = CONTENT_TYPE.get(filename[i:].lower(), default) - j = filename.rfind('.', 0, i) + j = filename.rfind(".", 0, i) if j >= 0: default = CONTENT_TYPE.get(filename[j:].lower(), default) - if default.startswith('text/'): - default += '; charset=utf-8' + if default.startswith("text/"): + default += "; charset=utf-8" return default diff --git a/emmett/libs/portalocker.py b/emmett/libs/portalocker.py index 720963e3..480452bf 100644 --- a/emmett/libs/portalocker.py +++ b/emmett/libs/portalocker.py @@ -43,22 +43,23 @@ os_locking = None try: - import google.appengine - os_locking = 'gae' -except: + os_locking = "gae" +except Exception: try: import fcntl - os_locking = 'posix' - except: + + os_locking = "posix" + except Exception: try: + import pywintypes import win32con import win32file - import pywintypes - os_locking = 'windows' - except: + + os_locking = "windows" + except Exception: pass -if os_locking == 'windows': +if os_locking == "windows": LOCK_EX = win32con.LOCKFILE_EXCLUSIVE_LOCK LOCK_SH = 0 # the default LOCK_NB = win32con.LOCKFILE_FAIL_IMMEDIATELY @@ -69,14 +70,14 @@ def lock(file, flags): hfile = win32file._get_osfhandle(file.fileno()) - win32file.LockFileEx(hfile, flags, 0, 0x7fff0000, __overlapped) + win32file.LockFileEx(hfile, flags, 0, 0x7FFF0000, __overlapped) def unlock(file): hfile = win32file._get_osfhandle(file.fileno()) - win32file.UnlockFileEx(hfile, 0, 0x7fff0000, __overlapped) + win32file.UnlockFileEx(hfile, 0, 0x7FFF0000, __overlapped) -elif os_locking == 'posix': +elif os_locking == "posix": LOCK_EX = fcntl.LOCK_EX LOCK_SH = fcntl.LOCK_SH LOCK_NB = fcntl.LOCK_NB @@ -89,9 +90,9 @@ def unlock(file): else: - #if platform.system() == 'Windows': + # if platform.system() == 'Windows': # logger.error('no file locking, you must install the win32 extensions from: http://sourceforge.net/projects/pywin32/files/') - #elif os_locking != 'gae': + # elif os_locking != 'gae': # logger.debug('no file locking, this will cause problems') LOCK_EX = None @@ -106,18 +107,18 @@ def unlock(file): class LockedFile(object): - def __init__(self, filename, mode='rb'): + def __init__(self, filename, mode="rb"): self.filename = filename self.mode = mode self.file = None - if 'r' in mode: - kwargs = {'encoding': 'utf8'} if 'b' not in mode else {} + if "r" in mode: + kwargs = {"encoding": "utf8"} if "b" not in mode else {} self.file = open(filename, mode, **kwargs) lock(self.file, LOCK_SH) - elif 'w' in mode or 'a' in mode: - self.file = open(filename, mode.replace('w', 'a')) + elif "w" in mode or "a" in mode: + self.file = open(filename, mode.replace("w", "a")) lock(self.file, LOCK_EX) - if 'a' not in mode: + if "a" not in mode: self.file.seek(0) self.file.truncate() else: @@ -148,13 +149,13 @@ def __del__(self): def read_locked(filename): - fp = LockedFile(filename, 'r') + fp = LockedFile(filename, "r") data = fp.read() fp.close() return data def write_locked(filename, data): - fp = LockedFile(filename, 'wb') + fp = LockedFile(filename, "wb") data = fp.write(data) fp.close() diff --git a/emmett/locals.py b/emmett/locals.py index 1230ddf8..b99b44e5 100644 --- a/emmett/locals.py +++ b/emmett/locals.py @@ -1,19 +1,19 @@ # -*- coding: utf-8 -*- """ - emmett.locals - ------------- +emmett.locals +------------- - Provides shortcuts to `current` object. +Provides shortcuts to `current` object. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from typing import Optional, cast +from emmett_core._internal import ContextVarProxy as _VProxy, ObjectProxy as _OProxy from pendulum import DateTime -from ._internal import ContextVarProxy as _VProxy, ObjectProxy as _OProxy from .ctx import _ctxv, current from .datastructures import sdict from .language.translator import Translator @@ -21,11 +21,12 @@ from .wrappers.response import Response from .wrappers.websocket import Websocket -request = cast(Request, _VProxy[Request](_ctxv, 'request')) -response = cast(Response, _VProxy[Response](_ctxv, 'response')) -session = cast(Optional[sdict], _VProxy[Optional[sdict]](_ctxv, 'session')) -websocket = cast(Websocket, _VProxy[Websocket](_ctxv, 'websocket')) -T = cast(Translator, _OProxy[Translator](current, 'T')) + +request = cast(Request, _VProxy[Request](_ctxv, "request")) +response = cast(Response, _VProxy[Response](_ctxv, "response")) +session = cast(Optional[sdict], _VProxy[Optional[sdict]](_ctxv, "session")) +websocket = cast(Websocket, _VProxy[Websocket](_ctxv, "websocket")) +T = cast(Translator, _OProxy[Translator](current, "T")) def now() -> DateTime: diff --git a/emmett/logger.py b/emmett/logger.py deleted file mode 100644 index 95b2d568..00000000 --- a/emmett/logger.py +++ /dev/null @@ -1,101 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.logger - ------------- - - Provides logging utitilites for Emmett applications. - - :copyright: 2014 Giovanni Barillari - - Based on the code of Flask (http://flask.pocoo.org) - :copyright: (c) 2014 by Armin Ronacher. - - :license: BSD-3-Clause -""" - -import logging -import os - -from logging import StreamHandler, Formatter -from logging.handlers import RotatingFileHandler -from threading import Lock - -from .datastructures import sdict - -_logger_lock = Lock() - -LOG_LEVELS = { - 'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL -} - -_def_log_config = sdict( - production=sdict( - max_size=5 * 1024 * 1024, - file_no=4, - level='warning', - format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s', - on_app_debug=False)) - -_debug_log_format = ( - '> %(levelname)s in %(module)s [%(pathname)s:%(lineno)d]:\n' + - '%(message)s' -) - - -def create_logger(app): - Logger = logging.getLoggerClass() - - class DebugLogger(Logger): - def getEffectiveLevel(x): - if x.level == 0 and app.debug: - return logging.DEBUG - return Logger.getEffectiveLevel(x) - - class DebugHandler(StreamHandler): - def emit(x, record): - StreamHandler.emit(x, record) if app.debug else None - - class DebugRFHandler(RotatingFileHandler): - def emit(x, record): - RotatingFileHandler.emit(x, record) if app.debug else None - - class ProdRFHandler(RotatingFileHandler): - def emit(x, record): - RotatingFileHandler.emit(x, record) if not app.debug else None - - # init the console debug handler - debug_handler = DebugHandler() - debug_handler.setLevel(logging.DEBUG) - debug_handler.setFormatter(Formatter(_debug_log_format)) - logger = logging.getLogger(app.logger_name) - # just in case that was not a new logger, get rid of all the handlers - # already attached to it. - del logger.handlers[:] - logger.__class__ = DebugLogger - logger.addHandler(debug_handler) - # load application logging config - app_logs = app.config.logging - if not app_logs: - app_logs = _def_log_config - for lname, lconf in app_logs.items(): - lfile = os.path.join(app.root_path, 'logs', lname + '.log') - max_size = lconf.max_size or _def_log_config.production.max_size - file_no = lconf.file_no or _def_log_config.production.file_no - level = LOG_LEVELS.get( - lconf.level or 'warning', LOG_LEVELS.get('warning')) - lformat = lconf.format or _def_log_config.production.format - on_app_debug = lconf.on_app_debug - if on_app_debug: - handler = DebugRFHandler( - lfile, maxBytes=max_size, backupCount=file_no) - else: - handler = ProdRFHandler( - lfile, maxBytes=max_size, backupCount=file_no) - handler.setLevel(level) - handler.setFormatter(Formatter(lformat)) - logger.addHandler(handler) - return logger diff --git a/emmett/orm/__init__.py b/emmett/orm/__init__.py index 8a8ffa65..cefe76a6 100644 --- a/emmett/orm/__init__.py +++ b/emmett/orm/__init__.py @@ -1,16 +1,27 @@ from . import _patches from .adapters import adapters as adapters_registry -from .base import Database -from .objects import Field, TransactionOps -from .models import Model from .apis import ( - belongs_to, refers_to, has_one, has_many, - compute, rowattr, rowmethod, - before_insert, before_update, before_delete, - before_save, before_destroy, - before_commit, - after_insert, after_update, after_delete, - after_save, after_destroy, after_commit, - scope + after_delete, + after_destroy, + after_insert, + after_save, + after_update, + before_commit, + before_delete, + before_destroy, + before_insert, + before_save, + before_update, + belongs_to, + compute, + has_many, + has_one, + refers_to, + rowattr, + rowmethod, + scope, ) +from .base import Database +from .models import Model +from .objects import Field, TransactionOps diff --git a/emmett/orm/_patches.py b/emmett/orm/_patches.py index 91b1ea6f..8db05042 100644 --- a/emmett/orm/_patches.py +++ b/emmett/orm/_patches.py @@ -1,84 +1,68 @@ # -*- coding: utf-8 -*- """ - emmett.orm._patches - ------------------- +emmett.orm._patches +------------------- - Provides pyDAL patches. +Provides pyDAL patches. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from ..serializers import Serializers -from ..utils import cachedprop - +from emmett_core.utils import cachedprop from pydal.adapters.base import BaseAdapter from pydal.connection import ConnectionPool from pydal.helpers.classes import ConnectionConfigurationMixin from pydal.helpers.serializers import Serializers as _Serializers +from ..serializers import Serializers from .adapters import ( - _initialize, _begin, _in_transaction, - _push_transaction, + _initialize, _pop_transaction, + _push_transaction, + _top_transaction, _transaction_depth, - _top_transaction ) from .connection import ( - ConnectionManager, PooledConnectionManager, - _connection_init, - _connect_sync, - _connect_loop, - _close_sync, _close_loop, + _close_sync, + _connect_and_configure, + _connect_loop, + _connect_sync, _connection_getter, + _connection_init, _connection_setter, _cursors_getter, - _connect_and_configure ) -from .engines.sqlite import SQLite def _patch_adapter_cls(): - setattr(BaseAdapter, '_initialize_', _initialize) - setattr(BaseAdapter, 'in_transaction', _in_transaction) - setattr(BaseAdapter, 'push_transaction', _push_transaction) - setattr(BaseAdapter, 'pop_transaction', _pop_transaction) - setattr(BaseAdapter, 'transaction_depth', _transaction_depth) - setattr(BaseAdapter, 'top_transaction', _top_transaction) - setattr(BaseAdapter, '_connection_manager_cls', PooledConnectionManager) - setattr(BaseAdapter, 'begin', _begin) - setattr(SQLite, '_connection_manager_cls', ConnectionManager) + BaseAdapter._initialize_ = _initialize + BaseAdapter.in_transaction = _in_transaction + BaseAdapter.push_transaction = _push_transaction + BaseAdapter.pop_transaction = _pop_transaction + BaseAdapter.transaction_depth = _transaction_depth + BaseAdapter.top_transaction = _top_transaction + BaseAdapter._connection_manager_cls = PooledConnectionManager + BaseAdapter.begin = _begin def _patch_adapter_connection(): - setattr(ConnectionPool, '__init__', _connection_init) - setattr(ConnectionPool, 'reconnect', _connect_sync) - setattr(ConnectionPool, 'reconnect_loop', _connect_loop) - setattr(ConnectionPool, 'close', _close_sync) - setattr(ConnectionPool, 'close_loop', _close_loop) - setattr( - ConnectionPool, - 'connection', - property(_connection_getter, _connection_setter) - ) - setattr(ConnectionPool, 'cursors', property(_cursors_getter)) - setattr( - ConnectionConfigurationMixin, - '_reconnect_and_configure', - _connect_and_configure - ) + ConnectionPool.__init__ = _connection_init + ConnectionPool.reconnect = _connect_sync + ConnectionPool.reconnect_loop = _connect_loop + ConnectionPool.close = _close_sync + ConnectionPool.close_loop = _close_loop + ConnectionPool.connection = property(_connection_getter, _connection_setter) + ConnectionPool.cursors = property(_cursors_getter) + ConnectionConfigurationMixin._reconnect_and_configure = _connect_and_configure def _patch_serializers(): - setattr( - _Serializers, - 'json', - cachedprop(lambda _: Serializers.get_for('json'), name='json') - ) + _Serializers.json = cachedprop(lambda _: Serializers.get_for("json"), name="json") _patch_adapter_cls() diff --git a/emmett/orm/adapters.py b/emmett/orm/adapters.py index 089a49a5..9ac806b9 100644 --- a/emmett/orm/adapters.py +++ b/emmett/orm/adapters.py @@ -1,35 +1,20 @@ # -*- coding: utf-8 -*- """ - emmett.orm.adapters - ------------------- +emmett.orm.adapters +------------------- - Provides ORM adapters facilities. +Provides ORM adapters facilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import sys - from functools import wraps from pydal.adapters.base import SQLAdapter -from pydal.adapters.mssql import ( - MSSQL1, - MSSQL3, - MSSQL4, - MSSQL1N, - MSSQL3N, - MSSQL4N -) -from pydal.adapters.postgres import ( - Postgre, - PostgrePsyco, - PostgrePG8000, - PostgreNew, - PostgrePsycoNew, - PostgrePG8000New -) +from pydal.adapters.mssql import MSSQL1, MSSQL1N, MSSQL3, MSSQL3N, MSSQL4, MSSQL4N +from pydal.adapters.postgres import Postgre, PostgreNew, PostgrePG8000, PostgrePG8000New, PostgrePsyco, PostgrePsycoNew from pydal.helpers.classes import SQLALL from pydal.helpers.regex import REGEX_TABLE_DOT_FIELD from pydal.parsers import ParserMethodWrapper, for_type as _parser_for_type @@ -37,37 +22,38 @@ from .engines import adapters from .helpers import GeoFieldWrapper, PasswordFieldWrapper, typed_row_reference -from .objects import Expression, Field, Row, IterRows - - -adapters._registry_.update({ - 'mssql': MSSQL4, - 'mssql2': MSSQL1, - 'mssql3': MSSQL3, - 'mssqln': MSSQL4N, - 'mssqln2': MSSQL1N, - 'mssqln3': MSSQL3N, - 'postgres2': PostgreNew, - 'postgres2:psycopg2': PostgrePsycoNew, - 'postgres2:pg8000': PostgrePG8000New, - 'postgres3': Postgre, - 'postgres3:psycopg2': PostgrePsyco, - 'postgres3:pg8000': PostgrePG8000 -}) +from .objects import Expression, Field, IterRows, Row + + +adapters._registry_.update( + { + "mssql": MSSQL4, + "mssql2": MSSQL1, + "mssql3": MSSQL3, + "mssqln": MSSQL4N, + "mssqln2": MSSQL1N, + "mssqln3": MSSQL3N, + "postgres2": PostgreNew, + "postgres2:psycopg2": PostgrePsycoNew, + "postgres2:pg8000": PostgrePG8000New, + "postgres3": Postgre, + "postgres3:psycopg2": PostgrePsyco, + "postgres3:pg8000": PostgrePG8000, + } +) def _wrap_on_obj(f, adapter): @wraps(f) def wrapped(*args, **kwargs): return f(adapter, *args, **kwargs) + return wrapped def patch_adapter(adapter): #: BaseAdapter interfaces - adapter._expand_all_with_concrete_tables = _wrap_on_obj( - _expand_all_with_concrete_tables, adapter - ) + adapter._expand_all_with_concrete_tables = _wrap_on_obj(_expand_all_with_concrete_tables, adapter) adapter._parse = _wrap_on_obj(_parse, adapter) adapter._parse_expand_colnames = _wrap_on_obj(_parse_expand_colnames, adapter) adapter.iterparse = _wrap_on_obj(iterparse, adapter) @@ -87,44 +73,28 @@ def patch_adapter(adapter): def patch_dialect(dialect): - _create_table_map = { - 'mysql': _create_table_mysql, - 'firebird': _create_table_firebird - } - dialect.create_table = _wrap_on_obj( - _create_table_map.get(dialect.adapter.dbengine, _create_table), dialect - ) + _create_table_map = {"mysql": _create_table_mysql, "firebird": _create_table_firebird} + dialect.create_table = _wrap_on_obj(_create_table_map.get(dialect.adapter.dbengine, _create_table), dialect) dialect.add_foreign_key_constraint = _wrap_on_obj(_add_fk_constraint, dialect) dialect.drop_constraint = _wrap_on_obj(_drop_constraint, dialect) def patch_parser(dialect, parser): - parser.registered['password'] = ParserMethodWrapper( - parser, - _parser_for_type('password')(_parser_password).f + parser.registered["password"] = ParserMethodWrapper(parser, _parser_for_type("password")(_parser_password).f) + parser.registered["reference"] = ParserMethodWrapper( + parser, _parser_for_type("reference")(_parser_reference).f, parser._before_registry_["reference"] ) - parser.registered['reference'] = ParserMethodWrapper( - parser, - _parser_for_type('reference')(_parser_reference).f, - parser._before_registry_['reference'] - ) - if 'geography' in dialect.types: - parser.registered['geography'] = ParserMethodWrapper( - parser, - _parser_for_type('geography')(_parser_geo).f - ) - if 'geometry' in dialect.types: - parser.registered['geometry'] = ParserMethodWrapper( - parser, - _parser_for_type('geometry')(_parser_geo).f - ) + if "geography" in dialect.types: + parser.registered["geography"] = ParserMethodWrapper(parser, _parser_for_type("geography")(_parser_geo).f) + if "geometry" in dialect.types: + parser.registered["geometry"] = ParserMethodWrapper(parser, _parser_for_type("geometry")(_parser_geo).f) def patch_representer(representer): - representer.registered_t['reference'] = TReprMethodWrapper( + representer.registered_t["reference"] = TReprMethodWrapper( representer, - _representer_for_type('reference')(_representer_reference), - representer._tbefore_registry_['reference'] + _representer_for_type("reference")(_representer_reference), + representer._tbefore_registry_["reference"], ) @@ -132,17 +102,14 @@ def insert(adapter, table, fields): query = adapter._insert(table, fields) try: adapter.execute(query) - except: + except Exception: e = sys.exc_info()[1] - if hasattr(table, '_on_insert_error'): + if hasattr(table, "_on_insert_error"): return table._on_insert_error(table, fields, e) raise e if not table._id: - id = { - field.name: val for field, val in fields - if field.name in table._primarykey - } or None - elif table._id.type == 'id': + id = {field.name: val for field, val in fields if field.name in table._primarykey} or None + elif table._id.type == "id": id = adapter.lastrowid(table) else: id = {field.name: val for field, val in fields}.get(table._id.name) @@ -193,7 +160,7 @@ def _select_wcols( orderby_on_limitby=True, for_update=False, outer_scoped=[], - **kwargs + **kwargs, ): return adapter._select_wcols_inner( query, @@ -207,7 +174,7 @@ def _select_wcols( limitby=limitby, orderby_on_limitby=orderby_on_limitby, for_update=for_update, - outer_scoped=outer_scoped + outer_scoped=outer_scoped, ) @@ -215,42 +182,25 @@ def _select_aux(adapter, sql, fields, attributes, colnames): rows = adapter._select_aux_execute(sql) if isinstance(rows, tuple): rows = list(rows) - limitby = attributes.get('limitby', None) or (0,) + limitby = attributes.get("limitby", None) or (0,) rows = adapter.rowslice(rows, limitby[0], None) - return adapter.parse( - rows, - fields, - colnames, - concrete_tables=attributes.get('_concrete_tables', []) - ) + return adapter.parse(rows, fields, colnames, concrete_tables=attributes.get("_concrete_tables", [])) def parse(adapter, rows, fields, colnames, **options): fdata, tables = _parse_expand_colnames(adapter, fields) new_rows = [ _parse( - adapter, - row, - fdata, - tables, - options['concrete_tables'], - fields, - colnames, - options.get('blob_decode', True) - ) for row in rows + adapter, row, fdata, tables, options["concrete_tables"], fields, colnames, options.get("blob_decode", True) + ) + for row in rows ] rowsobj = adapter.db.Rows(adapter.db, new_rows, colnames, rawrows=rows) return rowsobj def iterparse(adapter, sql, fields, colnames, **options): - return IterRows( - adapter.db, - sql, - fields, - options.get('_concrete_tables', []), - colnames - ) + return IterRows(adapter.db, sql, fields, options.get("_concrete_tables", []), colnames) def _parse_expand_colnames(adapter, fieldlist): @@ -269,12 +219,10 @@ def _parse_expand_colnames(adapter, fieldlist): def _parse(adapter, row, fdata, tables, concrete_tables, fields, colnames, blob_decode): - new_row, rows_cls, rows_accum = _build_newrow_wtables( - adapter, tables, concrete_tables - ) + new_row, rows_cls, rows_accum = _build_newrow_wtables(adapter, tables, concrete_tables) extras = adapter.db.Row() #: let's loop over columns - for (idx, colname) in enumerate(colnames): + for idx, colname in enumerate(colnames): value = row[idx] fd = fdata[idx] tablename = None @@ -289,9 +237,7 @@ def _parse(adapter, row, fdata, tables, concrete_tables, fields, colnames, blob_ colset[fieldname] = value #: otherwise we set the value in extras else: - value = adapter.parse_value( - value, fields[idx]._itype, fields[idx].type, blob_decode - ) + value = adapter.parse_value(value, fields[idx]._itype, fields[idx].type, blob_decode) extras[colname] = value new_column_name = adapter._regex_select_as_parser(colname) if new_column_name is not None: @@ -301,13 +247,13 @@ def _parse(adapter, row, fdata, tables, concrete_tables, fields, colnames, blob_ new_row[key] = val._from_engine(rows_accum[key]) #: add extras if needed (eg. operations results) if extras: - new_row['_extra'] = extras + new_row["_extra"] = extras return new_row def _build_newrow_wtables(adapter, tables, concrete_tables): row, cls_map, accum = adapter.db.Row(), {}, {} - for name, table in tables.items(): + for name, _ in tables.items(): cls_map[name] = adapter.db.Row accum[name] = {} for table in concrete_tables: @@ -317,14 +263,14 @@ def _build_newrow_wtables(adapter, tables, concrete_tables): def _create_table(dialect, tablename, fields): - return [ - "CREATE TABLE %s(\n %s\n);" % (dialect.quote(tablename), fields)] + return ["CREATE TABLE %s(\n %s\n);" % (dialect.quote(tablename), fields)] def _create_table_mysql(dialect, tablename, fields): - return ["CREATE TABLE %s(\n %s\n) ENGINE=%s CHARACTER SET utf8;" % ( - dialect.quote(tablename), fields, - dialect.adapter.adapter_args.get('engine', 'InnoDB'))] + return [ + "CREATE TABLE %s(\n %s\n) ENGINE=%s CHARACTER SET utf8;" + % (dialect.quote(tablename), fields, dialect.adapter.adapter_args.get("engine", "InnoDB")) + ] def _create_table_firebird(dialect, tablename, fields): @@ -332,30 +278,25 @@ def _create_table_firebird(dialect, tablename, fields): sequence_name = dialect.sequence_name(tablename) trigger_name = dialect.trigger_name(tablename) trigger_sql = ( - 'create trigger %s for %s active before insert position 0 as\n' - 'begin\n' + "create trigger %s for %s active before insert position 0 as\n" + "begin\n" 'if(new."id" is null) then\n' - 'begin\n' + "begin\n" 'new."id" = gen_id(%s, 1);\n' - 'end\n' - 'end;') - rv.extend([ - 'create generator %s;' % sequence_name, - 'set generator %s to 0;' % sequence_name, - trigger_sql % (trigger_name, dialect.quote(tablename), sequence_name) - ]) + "end\n" + "end;" + ) + rv.extend( + [ + "create generator %s;" % sequence_name, + "set generator %s to 0;" % sequence_name, + trigger_sql % (trigger_name, dialect.quote(tablename), sequence_name), + ] + ) return rv -def _add_fk_constraint( - dialect, - name, - table_local, - table_foreign, - columns_local, - columns_foreign, - on_delete -): +def _add_fk_constraint(dialect, name, table_local, table_foreign, columns_local, columns_foreign, on_delete): return ( f"ALTER TABLE {dialect.quote(table_local)} " f"ADD CONSTRAINT {dialect.quote(name)} " @@ -371,7 +312,7 @@ def _drop_constraint(dialect, name, table): def _parser_reference(parser, value, referee): - if '.' not in referee: + if "." not in referee: value = typed_row_reference(value, parser.adapter.db[referee]) return value @@ -385,7 +326,7 @@ def _parser_password(parser, value): def _representer_reference(representer, value, referenced): - rtname, _, rfname = referenced.partition('.') + rtname, _, rfname = referenced.partition(".") rtable = representer.adapter.db[rtname] if not rfname and rtable._id: rfname = rtable._id.name @@ -394,9 +335,9 @@ def _representer_reference(representer, value, referenced): rtype = rtable[rfname].type if isinstance(value, Row) and getattr(value, "_concrete", False): value = value[(value._model.primary_keys or ["id"])[0]] - if rtype in ('id', 'integer'): + if rtype in ("id", "integer"): return str(int(value)) - if rtype == 'string': + if rtype == "string": return str(value) return representer.adapter.represent(value, rtype) @@ -406,7 +347,8 @@ def _initialize(adapter, *args, **kwargs): adapter._connection_manager.configure( max_connections=adapter.db._pool_size, connect_timeout=adapter.db._connect_timeout, - stale_timeout=adapter.db._keep_alive_timeout) + stale_timeout=adapter.db._keep_alive_timeout, + ) def _begin(adapter): diff --git a/emmett/orm/apis.py b/emmett/orm/apis.py index cf5f6a9f..523e839a 100644 --- a/emmett/orm/apis.py +++ b/emmett/orm/apis.py @@ -1,20 +1,19 @@ # -*- coding: utf-8 -*- """ - emmett.orm.apis - --------------- +emmett.orm.apis +--------------- - Provides ORM apis. +Provides ORM apis. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from collections import OrderedDict -from enum import Enum from typing import List from .errors import MissingFieldsForCompute -from .helpers import Reference, Callback +from .helpers import Callback, Reference class belongs_to(Reference): @@ -94,61 +93,62 @@ class rowmethod(rowattr): def before_insert(f): - return Callback(f, '_before_insert') + return Callback(f, "_before_insert") def after_insert(f): - return Callback(f, '_after_insert') + return Callback(f, "_after_insert") def before_update(f): - return Callback(f, '_before_update') + return Callback(f, "_before_update") def after_update(f): - return Callback(f, '_after_update') + return Callback(f, "_after_update") def before_delete(f): - return Callback(f, '_before_delete') + return Callback(f, "_before_delete") def after_delete(f): - return Callback(f, '_after_delete') + return Callback(f, "_after_delete") def before_save(f): - return Callback(f, '_before_save') + return Callback(f, "_before_save") def after_save(f): - return Callback(f, '_after_save') + return Callback(f, "_after_save") def before_destroy(f): - return Callback(f, '_before_destroy') + return Callback(f, "_before_destroy") def after_destroy(f): - return Callback(f, '_after_destroy') + return Callback(f, "_after_destroy") def before_commit(f): - return Callback(f, '_before_commit') + return Callback(f, "_before_commit") def after_commit(f): - return Callback(f, '_after_commit') + return Callback(f, "_after_commit") def _commit_callback_op(kind, op): def _deco(f): - return Callback(f, f'_{kind}_commit_{op}') + return Callback(f, f"_{kind}_commit_{op}") + return _deco -before_commit.operation = lambda op: _commit_callback_op('before', op) -after_commit.operation = lambda op: _commit_callback_op('after', op) +before_commit.operation = lambda op: _commit_callback_op("before", op) +after_commit.operation = lambda op: _commit_callback_op("after", op) class scope(object): diff --git a/emmett/orm/base.py b/emmett/orm/base.py index 7c235acd..fa8729e9 100644 --- a/emmett/orm/base.py +++ b/emmett/orm/base.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.base - --------------- +emmett.orm.base +--------------- - Provides base pyDAL implementation for Emmett. +Provides base pyDAL implementation for Emmett. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -14,21 +14,22 @@ import copyreg import os import threading - from functools import wraps + +from emmett_core.serializers import _json_default from pydal import DAL as _pyDAL from pydal._globals import THREAD_LOCAL +from .._shortcuts import uuid as _uuid from ..datastructures import sdict from ..extensions import Signals from ..pipeline import Pipe -from ..security import uuid as _uuid -from ..serializers import _json_default, xml +from ..serializers import xml from .adapters import patch_adapter -from .objects import Table, Field, Set, Row, Rows from .helpers import ConnectionContext, TimingHandler from .models import MetaModel, Model -from .transactions import _atomic, _transaction, _savepoint +from .objects import Field, Row, Rows, Set, Table +from .transactions import _atomic, _savepoint, _transaction class DatabasePipe(Pipe): @@ -49,7 +50,7 @@ async def close(self): class Database(_pyDAL): - serializers = {'json': _json_default, 'xml': xml} + serializers = {"json": _json_default, "xml": xml} logger = None uuid = lambda x: _uuid() @@ -78,19 +79,12 @@ def uri_from_config(config=None): return uri def __new__(cls, app, *args, **kwargs): - config = kwargs.get('config', sdict()) or app.config.db + config = kwargs.get("config", sdict()) or app.config.db uri = config.uri or Database.uri_from_config(config) return super(Database, cls).__new__(cls, uri, *args, **kwargs) def __init__( - self, - app, - config=None, - pool_size=None, - keep_alive_timeout=3600, - connect_timeout=60, - folder=None, - **kwargs + self, app, config=None, pool_size=None, keep_alive_timeout=3600, connect_timeout=60, folder=None, **kwargs ): app.send_signal(Signals.before_database) self.logger = app.log @@ -98,30 +92,24 @@ def __init__( if not config.uri: config.uri = self.uri_from_config(config) if not config.migrations_folder: - config.migrations_folder = 'migrations' + config.migrations_folder = "migrations" self.config = config - self._auto_migrate = self.config.get( - 'auto_migrate', kwargs.pop('auto_migrate', False)) - self._auto_connect = self.config.get( - 'auto_connect', kwargs.pop('auto_connect', None)) - self._use_bigint_on_id_fields = self.config.get( - 'big_id_fields', kwargs.pop('big_id_fields', False)) + self._auto_migrate = self.config.get("auto_migrate", kwargs.pop("auto_migrate", False)) + self._auto_connect = self.config.get("auto_connect", kwargs.pop("auto_connect", None)) + self._use_bigint_on_id_fields = self.config.get("big_id_fields", kwargs.pop("big_id_fields", False)) #: load config data - kwargs['check_reserved'] = self.config.check_reserved or \ - kwargs.get('check_reserved', None) - kwargs['migrate'] = self._auto_migrate - kwargs['driver_args'] = self.config.driver_args or \ - kwargs.get('driver_args', None) - kwargs['adapter_args'] = self.config.adapter_args or \ - kwargs.get('adapter_args', None) + kwargs["check_reserved"] = self.config.check_reserved or kwargs.get("check_reserved", None) + kwargs["migrate"] = self._auto_migrate + kwargs["driver_args"] = self.config.driver_args or kwargs.get("driver_args", None) + kwargs["adapter_args"] = self.config.adapter_args or kwargs.get("adapter_args", None) if self._auto_connect is not None: - kwargs['do_connect'] = self._auto_connect + kwargs["do_connect"] = self._auto_connect else: - kwargs['do_connect'] = os.environ.get('EMMETT_CLI_ENV') == 'true' + kwargs["do_connect"] = os.environ.get("EMMETT_CLI_ENV") == "true" if self._use_bigint_on_id_fields: - kwargs['bigint_id'] = True + kwargs["bigint_id"] = True #: set directory - folder = folder or 'databases' + folder = folder or "databases" folder = os.path.join(app.root_path, folder) if self._auto_migrate: with self._cls_global_lock_: @@ -130,17 +118,14 @@ def __init__( #: set pool_size pool_size = self.config.pool_size or pool_size or 5 self._keep_alive_timeout = ( - keep_alive_timeout if self.config.keep_alive_timeout is None - else self.config.keep_alive_timeout) - self._connect_timeout = ( - connect_timeout if self.config.connect_timeout is None - else self.config.connect_timeout) + keep_alive_timeout if self.config.keep_alive_timeout is None else self.config.keep_alive_timeout + ) + self._connect_timeout = connect_timeout if self.config.connect_timeout is None else self.config.connect_timeout #: add timings storage if requested if config.store_execution_timings: self.execution_handlers.append(TimingHandler) #: finally setup pyDAL instance - super(Database, self).__init__( - self.config.uri, pool_size, folder, **kwargs) + super(Database, self).__init__(self.config.uri, pool_size, folder, **kwargs) patch_adapter(self._adapter) Model._init_inheritable_dicts_() app.send_signal(Signals.after_database, database=self) @@ -151,29 +136,19 @@ def pipe(self): @property def execution_timings(self): - return getattr(THREAD_LOCAL, '_emtdal_timings_', []) + return getattr(THREAD_LOCAL, "_emtdal_timings_", []) def connection_open(self, with_transaction=True, reuse_if_open=True): - return self._adapter.reconnect( - with_transaction=with_transaction, reuse_if_open=reuse_if_open) + return self._adapter.reconnect(with_transaction=with_transaction, reuse_if_open=reuse_if_open) def connection_close(self): self._adapter.close() - def connection( - self, - with_transaction: bool = True, - reuse_if_open: bool = True - ) -> ConnectionContext: - return ConnectionContext( - self, - with_transaction=with_transaction, - reuse_if_open=reuse_if_open - ) + def connection(self, with_transaction: bool = True, reuse_if_open: bool = True) -> ConnectionContext: + return ConnectionContext(self, with_transaction=with_transaction, reuse_if_open=reuse_if_open) def connection_open_loop(self, with_transaction=True, reuse_if_open=True): - return self._adapter.reconnect_loop( - with_transaction=with_transaction, reuse_if_open=reuse_if_open) + return self._adapter.reconnect_loop(with_transaction=with_transaction, reuse_if_open=reuse_if_open) def connection_close_loop(self): return self._adapter.close_loop() @@ -193,15 +168,13 @@ def define_models(self, *models): obj._define_relations_() obj._define_virtuals_() # define table and store in model - args = dict( - migrate=obj.migrate, - format=obj.format, - table_class=Table, - primarykey=obj.primary_keys or ['id'] - ) - model.table = self.define_table( - obj.tablename, *obj.fields, **args - ) + args = { + "migrate": obj.migrate, + "format": obj.format, + "table_class": Table, + "primarykey": obj.primary_keys or ["id"], + } + model.table = self.define_table(obj.tablename, *obj.fields, **args) model.table._model_ = obj # set reference in db for model name self.__setattr__(model.__name__, obj.table) @@ -217,7 +190,7 @@ def where(self, query=None, ignore_common_filters=None, model=None): if isinstance(query, Table): q = self._adapter.id_query(query) elif isinstance(query, Field): - q = (query != None) + q = query != None # noqa: E711 elif isinstance(query, dict): icf = query.get("ignore_common_filters") if icf: @@ -227,14 +200,14 @@ def where(self, query=None, ignore_common_filters=None, model=None): q = self._adapter.id_query(query.table) else: q = query - return Set( - self, q, ignore_common_filters=ignore_common_filters, model=model) + return Set(self, q, ignore_common_filters=ignore_common_filters, model=model) def with_connection(self, f): @wraps(f) def wrapped(*args, **kwargs): with self.connection(): f(*args, **kwargs) + return wrapped def atomic(self): @@ -259,7 +232,7 @@ def rollback(self): def _Database_unpickler(db_uid): fake_app_obj = sdict(config=sdict(db=sdict())) - fake_app_obj.config.db.adapter = '' + fake_app_obj.config.db.adapter = "" return Database(fake_app_obj, db_uid=db_uid) diff --git a/emmett/orm/connection.py b/emmett/orm/connection.py index 2c253109..a3b8a483 100644 --- a/emmett/orm/connection.py +++ b/emmett/orm/connection.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.orm.connection - --------------------- +emmett.orm.connection +--------------------- - Provides pyDAL connection implementation for Emmett. +Provides pyDAL connection implementation for Emmett. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Parts of this code are inspired to peewee - :copyright: (c) 2010 by Charles Leifer +Parts of this code are inspired to peewee +:copyright: (c) 2010 by Charles Leifer - :license: BSD-3-Clause +:license: BSD-3-Clause """ import asyncio @@ -18,24 +18,24 @@ import heapq import threading import time - from collections import OrderedDict from functools import partial +from emmett_core.utils import cachedprop + from ..ctx import current -from ..utils import cachedprop from .errors import MaxConnectionsExceeded from .transactions import _transaction class ConnectionStateCtxVars: - __slots__ = ('_connection', '_transactions', '_cursors', '_closed') + __slots__ = ("_connection", "_transactions", "_cursors", "_closed") def __init__(self): - self._connection = contextvars.ContextVar('_emt_orm_cs_connection') - self._transactions = contextvars.ContextVar('_emt_orm_cs_transactions') - self._cursors = contextvars.ContextVar('_emt_orm_cs_cursors') - self._closed = contextvars.ContextVar('_emt_orm_cs_closed') + self._connection = contextvars.ContextVar("_emt_orm_cs_connection") + self._transactions = contextvars.ContextVar("_emt_orm_cs_transactions") + self._cursors = contextvars.ContextVar("_emt_orm_cs_cursors") + self._closed = contextvars.ContextVar("_emt_orm_cs_closed") self.reset() @property @@ -68,7 +68,7 @@ def reset(self): class ConnectionState: - __slots__ = ('_connection', '_transactions', '_cursors', '_closed') + __slots__ = ("_connection", "_transactions", "_cursors", "_closed") def __init__(self, connection=None): self.connection = connection @@ -86,14 +86,14 @@ def connection(self, value): class ConnectionStateCtl: - __slots__ = ['_state_obj_var', '_state_load_var'] + __slots__ = ["_state_obj_var", "_state_load_var"] state_cls = ConnectionState def __init__(self): inst_id = id(self) - self._state_obj_var = f'__emt_orm_state_{inst_id}__' - self._state_load_var = f'__emt_orm_state_loaded_{inst_id}__' + self._state_obj_var = f"__emt_orm_state_{inst_id}__" + self._state_load_var = f"__emt_orm_state_loaded_{inst_id}__" @property def _has_ctx(self): @@ -132,7 +132,7 @@ def reset(self): class ConnectionManager: - __slots__ = ['adapter', 'state', '__dict__'] + __slots__ = ["adapter", "state", "__dict__"] state_cls = ConnectionStateCtl def __init__(self, adapter, **kwargs): @@ -156,12 +156,7 @@ def _connection_open_sync(self): return self._connector_sync(), True async def _connection_open_loop(self): - return ( - await self._loop.run_in_executor( - None, self._connector_loop - ), - True - ) + return (await self._loop.run_in_executor(None, self._connector_loop), True) def _connection_close_sync(self, connection, *args, **kwargs): try: @@ -170,9 +165,7 @@ def _connection_close_sync(self, connection, *args, **kwargs): pass async def _connection_close_loop(self, connection, *args, **kwargs): - return await self._loop.run_in_executor( - None, partial(self._connection_close_sync, connection) - ) + return await self._loop.run_in_executor(None, partial(self._connection_close_sync, connection)) connect_sync = _connection_open_sync connect_loop = _connection_open_loop @@ -187,18 +180,16 @@ def __del__(self): class PooledConnectionManager(ConnectionManager): __slots__ = [ - 'max_connections', 'connect_timeout', 'stale_timeout', - 'connections_map', 'connections_sync', - 'in_use', '_lock_sync' + "max_connections", + "connect_timeout", + "stale_timeout", + "connections_map", + "connections_sync", + "in_use", + "_lock_sync", ] - def __init__( - self, - adapter, - max_connections=5, - connect_timeout=0, - stale_timeout=0 - ): + def __init__(self, adapter, max_connections=5, connect_timeout=0, stale_timeout=0): super().__init__(adapter) self.max_connections = max(max_connections, 1) self.connect_timeout = connect_timeout @@ -233,9 +224,7 @@ def connect_sync(self): raise MaxConnectionsExceeded() async def connect_loop(self): - return await asyncio.wait_for( - self._acquire_loop(), self.connect_timeout or None - ) + return await asyncio.wait_for(self._acquire_loop(), self.connect_timeout or None) def _acquire_sync(self): _opened = False @@ -273,9 +262,7 @@ async def _acquire_loop(self): break ts, key = await self.connections_loop.get() if self.stale_timeout and self.is_stale(ts): - await self._connection_close_loop( - self.connections_map.pop(key) - ) + await self._connection_close_loop(self.connections_map.pop(key)) else: conn = self.connections_map[key] break @@ -325,7 +312,7 @@ def _connect_sync(self, with_transaction=True, reuse_if_open=False): if not self._connection_manager.state.closed: if reuse_if_open: return False - raise RuntimeError('Connection already opened.') + raise RuntimeError("Connection already opened.") self.connection, _opened = self._connection_manager.connect_sync() if _opened: self.after_connection_hook() @@ -339,7 +326,7 @@ async def _connect_loop(self, with_transaction=True, reuse_if_open=False): if not self._connection_manager.state.closed: if reuse_if_open: return False - raise RuntimeError('Connection already opened.') + raise RuntimeError("Connection already opened.") self.connection, _opened = await self._connection_manager.connect_loop() if _opened: self.after_connection_hook() @@ -349,7 +336,7 @@ async def _connect_loop(self, with_transaction=True, reuse_if_open=False): return True -def _close_sync(self, action='commit', really=True): +def _close_sync(self, action="commit", really=True): is_open = not self._connection_manager.state.closed if not is_open: return is_open @@ -368,7 +355,7 @@ def _close_sync(self, action='commit', really=True): return is_open -async def _close_loop(self, action='commit', really=True): +async def _close_loop(self, action="commit", really=True): is_open = not self._connection_manager.state.closed if not is_open: return is_open diff --git a/emmett/orm/engines/__init__.py b/emmett/orm/engines/__init__.py index 1a1702a6..a612ca93 100644 --- a/emmett/orm/engines/__init__.py +++ b/emmett/orm/engines/__init__.py @@ -1,6 +1,3 @@ from pydal.adapters import adapters -from . import ( - postgres, - sqlite -) +from . import postgres, sqlite diff --git a/emmett/orm/engines/postgres.py b/emmett/orm/engines/postgres.py index dcc47f88..b6bbc174 100644 --- a/emmett/orm/engines/postgres.py +++ b/emmett/orm/engines/postgres.py @@ -1,19 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.orm.engines.postgres - --------------------------- +emmett.orm.engines.postgres +--------------------------- - Provides ORM PostgreSQL engine specific features. +Provides ORM PostgreSQL engine specific features. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from pydal.adapters.postgres import ( - PostgreBoolean, - PostgrePsycoBoolean, - PostgrePG8000Boolean -) +from pydal.adapters.postgres import PostgreBoolean, PostgrePG8000Boolean, PostgrePsycoBoolean from pydal.dialects import register_expression, sqltype_for from pydal.dialects.postgre import PostgreDialectBooleanJSON from pydal.helpers.serializers import serializers @@ -27,98 +23,94 @@ class JSONBPostgreDialect(PostgreDialectBooleanJSON): - @sqltype_for('jsonb') + @sqltype_for("jsonb") def type_jsonb(self): - return 'jsonb' + return "jsonb" def _jcontains(self, field, data, query_env={}): - return '(%s @> %s)' % ( + return "(%s @> %s)" % ( self.expand(field, query_env=query_env), - self.expand(data, field.type, query_env=query_env) + self.expand(data, field.type, query_env=query_env), ) - @register_expression('jcontains') + @register_expression("jcontains") def _jcontains_expr(self, expr, data): return Query(expr.db, self._jcontains, expr, data) def _jin(self, field, data, query_env={}): - return '(%s <@ %s)' % ( + return "(%s <@ %s)" % ( self.expand(field, query_env=query_env), - self.expand(data, field.type, query_env=query_env) + self.expand(data, field.type, query_env=query_env), ) - @register_expression('jin') + @register_expression("jin") def _jin_expr(self, expr, data): return Query(expr.db, self._jin, expr, data) def _jget_common(self, op, field, data, query_env): if not isinstance(data, int): - _dtype = field.type if isinstance(data, (dict, list)) else 'string' + _dtype = field.type if isinstance(data, (dict, list)) else "string" data = self.expand(data, field_type=_dtype, query_env=query_env) - return '%s %s %s' % ( - self.expand(field, query_env=query_env), - op, - str(data) - ) + return "%s %s %s" % (self.expand(field, query_env=query_env), op, str(data)) def _jget(self, field, data, query_env={}): - return self._jget_common('->', field, data, query_env=query_env) + return self._jget_common("->", field, data, query_env=query_env) - @register_expression('jget') + @register_expression("jget") def _jget_expr(self, expr, data): return Expression(expr.db, self._jget, expr, data, expr.type) def _jgetv(self, field, data, query_env={}): - return self._jget_common('->>', field, data, query_env=query_env) + return self._jget_common("->>", field, data, query_env=query_env) - @register_expression('jgetv') + @register_expression("jgetv") def _jgetv_expr(self, expr, data): - return Expression(expr.db, self._jgetv, expr, data, 'string') + return Expression(expr.db, self._jgetv, expr, data, "string") def _jpath(self, field, data, query_env={}): - return '%s #> %s' % ( + return "%s #> %s" % ( self.expand(field, query_env=query_env), - self.expand(data, field_type='string', query_env=query_env) + self.expand(data, field_type="string", query_env=query_env), ) - @register_expression('jpath') + @register_expression("jpath") def _jpath_expr(self, expr, data): return Expression(expr.db, self._jpath, expr, data, expr.type) def _jpathv(self, field, data, query_env={}): - return '%s #>> %s' % ( + return "%s #>> %s" % ( self.expand(field, query_env=query_env), - self.expand(data, field_type='string', query_env=query_env) + self.expand(data, field_type="string", query_env=query_env), ) - @register_expression('jpathv') + @register_expression("jpathv") def _jpathv_expr(self, expr, data): - return Expression(expr.db, self._jpathv, expr, data, 'string') + return Expression(expr.db, self._jpathv, expr, data, "string") def _jhas(self, field, data, all=False, query_env={}): - _op, _ftype = '?', 'string' + _op, _ftype = "?", "string" if isinstance(data, list): - _op = '?&' if all else '?|' - _ftype = 'list:string' - return '%s %s %s' % ( + _op = "?&" if all else "?|" + _ftype = "list:string" + return "%s %s %s" % ( self.expand(field, query_env=query_env), _op, - self.expand(data, field_type=_ftype, query_env=query_env) + self.expand(data, field_type=_ftype, query_env=query_env), ) - @register_expression('jhas') + @register_expression("jhas") def _jhas_expr(self, expr, data, all=False): return Query(expr.db, self._jhas, expr, data, all=all) class JSONBPostgreParser(PostgreBooleanAutoJSONParser): - @parse_type('jsonb') + @parse_type("jsonb") def _jsonb(self, value): return value class JSONBPostgreRepresenter(PostgreArraysRepresenter): - @repr_type('jsonb') + @repr_type("jsonb") def _jsonb(self, value): return serializers.json(value) @@ -145,9 +137,9 @@ def _insert(self, table, fields): retval = table._id._rname return self.dialect.insert( table._rname, - ','.join(el[0]._rname for el in fields), - ','.join(self.expand(v, f.type) for f, v in fields), - retval + ",".join(el[0]._rname for el in fields), + ",".join(self.expand(v, f.type) for f, v in fields), + retval, ) return self.dialect.insert_empty(table._rname) @@ -159,16 +151,16 @@ def lastrowid(self, table): return self.cursor.fetchone()[0] -@adapters.register_for('postgres') +@adapters.register_for("postgres") class PostgresAdapter(PostgresAdapterMixin, PostgreBoolean): pass -@adapters.register_for('postgres:psycopg2') +@adapters.register_for("postgres:psycopg2") class PostgresPsycoPG2Adapter(PostgresAdapterMixin, PostgrePsycoBoolean): pass -@adapters.register_for('postgres:pg8000') +@adapters.register_for("postgres:pg8000") class PostgresPG8000Adapter(PostgresAdapterMixin, PostgrePG8000Boolean): pass diff --git a/emmett/orm/engines/sqlite.py b/emmett/orm/engines/sqlite.py index a979ce76..2b6b5d81 100644 --- a/emmett/orm/engines/sqlite.py +++ b/emmett/orm/engines/sqlite.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.engines.sqlite - ------------------------- +emmett.orm.engines.sqlite +------------------------- - Provides ORM SQLite engine specific features. +Provides ORM SQLite engine specific features. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from pydal.adapters.sqlite import SQLite as _SQLite @@ -14,27 +14,21 @@ from . import adapters -@adapters.register_for('sqlite', 'sqlite:memory') +@adapters.register_for("sqlite", "sqlite:memory") class SQLite(_SQLite): def _initialize_(self, do_connect): super()._initialize_(do_connect) - self.driver_args['isolation_level'] = None + self.driver_args["isolation_level"] = None def begin(self, lock_type=None): - statement = 'BEGIN %s;' % lock_type if lock_type else 'BEGIN;' + statement = "BEGIN %s;" % lock_type if lock_type else "BEGIN;" self.execute(statement) def delete(self, table, query): - deleted = ( - [x[table._id.name] for x in self.db(query).select(table._id)] - if table._id else [] - ) + deleted = [x[table._id.name] for x in self.db(query).select(table._id)] if table._id else [] counter = super(_SQLite, self).delete(table, query) if table._id and counter: for field in table._referenced_by: - if ( - field.type == 'reference ' + table._dalname and - field.ondelete == 'CASCADE' - ): + if field.type == "reference " + table._dalname and field.ondelete == "CASCADE": self.db(field.belongs(deleted)).delete() return counter diff --git a/emmett/orm/errors.py b/emmett/orm/errors.py index 691bf876..293e9e35 100644 --- a/emmett/orm/errors.py +++ b/emmett/orm/errors.py @@ -1,39 +1,33 @@ # -*- coding: utf-8 -*- """ - emmett.orm.errors - ----------------- +emmett.orm.errors +----------------- - Provides some error wrappers. +Provides some error wrappers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ class MaxConnectionsExceeded(RuntimeError): def __init__(self): - super().__init__('Exceeded maximum connections') + super().__init__("Exceeded maximum connections") -class MissingFieldsForCompute(RuntimeError): - ... +class MissingFieldsForCompute(RuntimeError): ... -class SaveException(RuntimeError): - ... +class SaveException(RuntimeError): ... -class InsertFailureOnSave(SaveException): - ... +class InsertFailureOnSave(SaveException): ... -class UpdateFailureOnSave(SaveException): - ... +class UpdateFailureOnSave(SaveException): ... -class DestroyException(RuntimeError): - ... +class DestroyException(RuntimeError): ... -class ValidationError(RuntimeError): - ... +class ValidationError(RuntimeError): ... diff --git a/emmett/orm/geo.py b/emmett/orm/geo.py index 8ecb3e6a..df752609 100644 --- a/emmett/orm/geo.py +++ b/emmett/orm/geo.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.geo - -------------- +emmett.orm.geo +-------------- - Provides geographic facilities. +Provides geographic facilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from .helpers import GeoFieldWrapper @@ -17,9 +17,7 @@ def Point(x, y): def Line(*coordinates): - return GeoFieldWrapper( - "LINESTRING(%s)" % ','.join("%f %f" % point for point in coordinates) - ) + return GeoFieldWrapper("LINESTRING(%s)" % ",".join("%f %f" % point for point in coordinates)) def Polygon(*coordinates_groups): @@ -29,46 +27,30 @@ def Polygon(*coordinates_groups): except Exception: pass return GeoFieldWrapper( - "POLYGON(%s)" % ( - ",".join([ - "(%s)" % ",".join("%f %f" % point for point in group) - for group in coordinates_groups - ]) - ) + "POLYGON(%s)" + % (",".join(["(%s)" % ",".join("%f %f" % point for point in group) for group in coordinates_groups])) ) def MultiPoint(*points): - return GeoFieldWrapper( - "MULTIPOINT(%s)" % ( - ",".join([ - "(%f %f)" % point for point in points - ]) - ) - ) + return GeoFieldWrapper("MULTIPOINT(%s)" % (",".join(["(%f %f)" % point for point in points]))) def MultiLine(*lines): return GeoFieldWrapper( - "MULTILINESTRING(%s)" % ( - ",".join([ - "(%s)" % ",".join("%f %f" % point for point in line) - for line in lines - ]) - ) + "MULTILINESTRING(%s)" % (",".join(["(%s)" % ",".join("%f %f" % point for point in line) for line in lines])) ) def MultiPolygon(*polygons): return GeoFieldWrapper( - "MULTIPOLYGON(%s)" % ( - ",".join([ - "(%s)" % ( - ",".join([ - "(%s)" % ",".join("%f %f" % point for point in group) - for group in polygon - ]) - ) for polygon in polygons - ]) + "MULTIPOLYGON(%s)" + % ( + ",".join( + [ + "(%s)" % (",".join(["(%s)" % ",".join("%f %f" % point for point in group) for group in polygon])) + for polygon in polygons + ] + ) ) ) diff --git a/emmett/orm/helpers.py b/emmett/orm/helpers.py index 161183da..58713725 100644 --- a/emmett/orm/helpers.py +++ b/emmett/orm/helpers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.helpers - ------------------ +emmett.orm.helpers +------------------ - Provides ORM helpers. +Provides ORM helpers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -15,23 +15,23 @@ import operator import re import time - from functools import reduce, wraps from typing import TYPE_CHECKING, Any, Callable +from emmett_core.utils import cachedprop from pydal._globals import THREAD_LOCAL from pydal.helpers.classes import ExecutionHandler from pydal.objects import Field as _Field from ..datastructures import sdict -from ..utils import cachedprop + if TYPE_CHECKING: from .objects import Table class RowReferenceMeta: - __slots__ = ['table', 'pk', 'caster'] + __slots__ = ["table", "pk", "caster"] def __init__(self, table: Table, caster: Callable[[Any], Any]): self.table = table @@ -39,15 +39,14 @@ def __init__(self, table: Table, caster: Callable[[Any], Any]): self.caster = caster def fetch(self, val): - return self.table._db(self.table._id == self.caster(val)).select( - limitby=(0, 1), - orderby_on_limitby=False - ).first() + return ( + self.table._db(self.table._id == self.caster(val)).select(limitby=(0, 1), orderby_on_limitby=False).first() + ) class RowReferenceMultiMeta: - __slots__ = ['table', 'pks', 'pks_idx', 'caster', 'casters'] - _casters = {'integer': int, 'string': str} + __slots__ = ["table", "pks", "pks_idx", "caster", "casters"] + _casters = {"integer": int, "string": str} def __init__(self, table: Table) -> None: self.table = table @@ -58,15 +57,10 @@ def __init__(self, table: Table) -> None: def fetch(self, val): query = reduce( - operator.and_, [ - self.table[pk] == self.casters[pk](self.caster.__getitem__(val, idx)) - for pk, idx in self.pks_idx.items() - ] + operator.and_, + [self.table[pk] == self.casters[pk](self.caster.__getitem__(val, idx)) for pk, idx in self.pks_idx.items()], ) - return self.table._db(query).select( - limitby=(0, 1), - orderby_on_limitby=False - ).first() + return self.table._db(query).select(limitby=(0, 1), orderby_on_limitby=False).first() class RowReferenceMixin: @@ -75,8 +69,7 @@ def _allocate_(self): self._refrecord = self._refmeta.fetch(self) if not self._refrecord: raise RuntimeError( - "Using a recursive select but encountered a broken " + - "reference: %s %r" % (self._table, self) + "Using a recursive select but encountered a broken " + "reference: %s %r" % (self._table, self) ) def __getattr__(self, key: str) -> Any: @@ -92,7 +85,7 @@ def get(self, key: str, default: Any = None) -> Any: return self.__getattr__(key, default) def __setattr__(self, key: str, value: Any): - if key.startswith('_'): + if key.startswith("_"): self._refmeta.caster.__setattr__(self, key, value) return self._allocate_() @@ -118,16 +111,16 @@ def __repr__(self) -> str: class RowReferenceInt(RowReferenceMixin, int): def __new__(cls, id, table: Table, *args: Any, **kwargs: Any): rv = super().__new__(cls, id, *args, **kwargs) - int.__setattr__(rv, '_refmeta', RowReferenceMeta(table, int)) - int.__setattr__(rv, '_refrecord', None) + int.__setattr__(rv, "_refmeta", RowReferenceMeta(table, int)) + int.__setattr__(rv, "_refrecord", None) return rv class RowReferenceStr(RowReferenceMixin, str): def __new__(cls, id, table: Table, *args: Any, **kwargs: Any): rv = super().__new__(cls, id, *args, **kwargs) - str.__setattr__(rv, '_refmeta', RowReferenceMeta(table, str)) - str.__setattr__(rv, '_refrecord', None) + str.__setattr__(rv, "_refmeta", RowReferenceMeta(table, str)) + str.__setattr__(rv, "_refrecord", None) return rv @@ -135,15 +128,13 @@ class RowReferenceMulti(RowReferenceMixin, tuple): def __new__(cls, id, table: Table, *args: Any, **kwargs: Any): tupid = tuple(id[key] for key in table._primarykey) rv = super().__new__(cls, tupid, *args, **kwargs) - tuple.__setattr__(rv, '_refmeta', RowReferenceMultiMeta(table)) - tuple.__setattr__(rv, '_refrecord', None) + tuple.__setattr__(rv, "_refmeta", RowReferenceMultiMeta(table)) + tuple.__setattr__(rv, "_refrecord", None) return rv def __getattr__(self, key: str) -> Any: if key in self._refmeta.pks: - return self._refmeta.casters[key]( - tuple.__getitem__(self, self._refmeta.pks_idx[key]) - ) + return self._refmeta.casters[key](tuple.__getitem__(self, self._refmeta.pks_idx[key])) if key in self._refmeta.table: self._allocate_() if self._refrecord: @@ -152,9 +143,7 @@ def __getattr__(self, key: str) -> Any: def __getitem__(self, key): if key in self._refmeta.pks: - return self._refmeta.casters[key]( - tuple.__getitem__(self, self._refmeta.pks_idx[key]) - ) + return self._refmeta.casters[key](tuple.__getitem__(self, self._refmeta.pks_idx[key])) self._allocate_() return self._refrecord.get(key, None) @@ -167,22 +156,22 @@ class GeoFieldWrapper(str): "POLYGON": "Polygon", "MULTIPOINT": "MultiPoint", "MULTILINESTRING": "MultiLineString", - "MULTIPOLYGON": "MultiPolygon" + "MULTIPOLYGON": "MultiPolygon", } def __new__(cls, value, *args: Any, **kwargs: Any): geometry, raw_coords = value.strip()[:-1].split("(", 1) rv = super().__new__(cls, value, *args, **kwargs) coords = cls._parse_coords_block(raw_coords) - str.__setattr__(rv, '_geometry', geometry.strip()) - str.__setattr__(rv, '_coordinates', coords) + str.__setattr__(rv, "_geometry", geometry.strip()) + str.__setattr__(rv, "_coordinates", coords) return rv @classmethod def _parse_coords_block(cls, v): groups = [] parens_match = cls._rule_parens.match(v) - parens = parens_match.group(1) if parens_match else '' + parens = parens_match.group(1) if parens_match else "" if parens: for element in v.split(parens): if not element: @@ -192,9 +181,7 @@ def _parse_coords_block(cls, v): groups.append(f"{parens}{element}"[1:shift]) if not groups: return cls._parse_coords_group(v) - return tuple( - cls._parse_coords_block(group) for group in groups - ) + return tuple(cls._parse_coords_block(group) for group in groups) @staticmethod def _parse_coords_group(v): @@ -225,17 +212,13 @@ def coordinates(self): @property def groups(self): if not self._geometry.startswith("MULTI"): - return tuple() + return () return tuple( - self.__class__(f"{self._geometry[5:]}({self._repr_coords(coords)[0]})") - for coords in self._coordinates + self.__class__(f"{self._geometry[5:]}({self._repr_coords(coords)[0]})") for coords in self._coordinates ) def __json__(self): - return { - "type": self._json_geom_map[self._geometry], - "coordinates": self._coordinates - } + return {"type": self._json_geom_map[self._geometry], "coordinates": self._coordinates} class PasswordFieldWrapper(str): @@ -244,27 +227,24 @@ class PasswordFieldWrapper(str): class Reference(object): def __init__(self, *args, **params): - self.reference = [arg for arg in args] + self.reference = list(args) self.params = params self.refobj[id(self)] = self def __call__(self, func): - if self.__class__.__name__ not in ['has_one', 'has_many']: - raise SyntaxError( - '%s cannot be used as a decorator' % self.__class__.__name__) + if self.__class__.__name__ not in ["has_one", "has_many"]: + raise SyntaxError("%s cannot be used as a decorator" % self.__class__.__name__) if not callable(func): - raise SyntaxError('Argument must be callable') + raise SyntaxError("Argument must be callable") if self.reference: - raise SyntaxError( - "When using %s as decorator, you must use the 'field' option" % - self.__class__.__name__) - new_reference = {func.__name__: {'method': func}} - field = self.params.get('field') + raise SyntaxError("When using %s as decorator, you must use the 'field' option" % self.__class__.__name__) + new_reference = {func.__name__: {"method": func}} + field = self.params.get("field") if field: - new_reference[func.__name__]['field'] = field - cast = self.params.get('cast') + new_reference[func.__name__]["field"] = field + cast = self.params.get("cast") if cast: - new_reference[func.__name__]['cast'] = cast + new_reference[func.__name__]["cast"] = cast self.reference = [new_reference] return self @@ -345,10 +325,11 @@ def _get_belongs(self, modelname, value): def belongs_query(self): return reduce( - operator.and_, [ + operator.and_, + [ self.model.table[local] == self.model.db[self.ref.model][foreign] for local, foreign in self.ref.coupled_fields - ] + ], ) @staticmethod @@ -361,17 +342,10 @@ def many_query(ref, rid): components.append(element.cast(ref.cast)) else: components.append(element) - return reduce( - operator.and_, [ - field == components[idx] - for idx, field in enumerate(ref.fields_instances) - ] - ) + return reduce(operator.and_, [field == components[idx] for idx, field in enumerate(ref.fields_instances)]) def _many(self, ref, rid): - return ref.dbset.where( - self._patch_query_with_scopes(ref, self.many_query(ref, rid)) - ).query + return ref.dbset.where(self._patch_query_with_scopes(ref, self.many_query(ref, rid))).query def many(self, row=None): return self._many(self.ref, self._make_refid(row)) @@ -381,17 +355,11 @@ def via(self, row=None): rid = self._make_refid(row) sname = self.model.__class__.__name__ stack = [] - midrel = self.model._hasmany_ref_.get( - self.ref.via, - self.model._hasone_ref_.get(self.ref.via) - ) + midrel = self.model._hasmany_ref_.get(self.ref.via, self.model._hasone_ref_.get(self.ref.via)) stack.append(self.ref) while midrel.via is not None: stack.insert(0, midrel) - midrel = self.model._hasmany_ref_.get( - midrel.via, - self.model._hasone_ref_.get(midrel.via) - ) + midrel = self.model._hasmany_ref_.get(midrel.via, self.model._hasone_ref_.get(midrel.via)) query = self._many(midrel, rid) step_model = midrel.table_name sel_field = db[step_model].ALL @@ -405,12 +373,11 @@ def via(self, row=None): last_belongs = step_model last_via = via _query = reduce( - operator.and_, [ - ( - db[belongs_model.model][foreign] == - db[step_model][local] - ) for local, foreign in belongs_model.coupled_fields - ] + operator.and_, + [ + (db[belongs_model.model][foreign] == db[step_model][local]) + for local, foreign in belongs_model.coupled_fields + ], ) sel_field = db[belongs_model.model].ALL step_model = belongs_model.model @@ -418,10 +385,7 @@ def via(self, row=None): #: shortcut way last_belongs = None rname = via.field or via.name - midrel = db[step_model]._model_._hasmany_ref_.get( - rname, - db[step_model]._model_._hasone_ref_.get(rname) - ) + midrel = db[step_model]._model_._hasmany_ref_.get(rname, db[step_model]._model_._hasone_ref_.get(rname)) if midrel.via: nested = RelationBuilder(midrel, midrel.model_class) nested_data = nested.via() @@ -429,19 +393,13 @@ def via(self, row=None): step_model = midrel.model_class.tablename else: _query = self._many( - midrel, [ - db[step_model][step_field] - for step_field in ( - db[step_model]._model_.primary_keys or ["id"] - ) - ] + midrel, + [db[step_model][step_field] for step_field in (db[step_model]._model_.primary_keys or ["id"])], ) step_model = midrel.table_name sel_field = db[step_model].ALL query = query & _query - query = via.dbset.where( - self._patch_query_with_scopes_on(via, query, step_model) - ).query + query = via.dbset.where(self._patch_query_with_scopes_on(via, query, step_model)).query return query, sel_field, sname, rid, last_belongs, last_via @@ -464,8 +422,7 @@ def __call__(self): class TimingHandler(ExecutionHandler): def _timings(self): - THREAD_LOCAL._emtdal_timings_ = getattr( - THREAD_LOCAL, '_emtdal_timings_', []) + THREAD_LOCAL._emtdal_timings_ = getattr(THREAD_LOCAL, "_emtdal_timings_", []) return THREAD_LOCAL._emtdal_timings_ @cachedprop @@ -481,7 +438,7 @@ def after_execute(self, command): class ConnectionContext: - __slots__ = ['db', 'conn', 'with_transaction', 'reuse_if_open'] + __slots__ = ["db", "conn", "with_transaction", "reuse_if_open"] def __init__(self, db, with_transaction=True, reuse_if_open=True): self.db = db @@ -490,10 +447,7 @@ def __init__(self, db, with_transaction=True, reuse_if_open=True): self.reuse_if_open = reuse_if_open def __enter__(self): - self.conn = self.db.connection_open( - with_transaction=self.with_transaction, - reuse_if_open=self.reuse_if_open - ) + self.conn = self.db.connection_open(with_transaction=self.with_transaction, reuse_if_open=self.reuse_if_open) return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -503,8 +457,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): async def __aenter__(self): self.conn = await self.db.connection_open_loop( - with_transaction=self.with_transaction, - reuse_if_open=self.reuse_if_open + with_transaction=self.with_transaction, reuse_if_open=self.reuse_if_open ) return self @@ -515,7 +468,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): def decamelize(name): - return "_".join(re.findall('[A-Z][^A-Z]*', name)).lower() + return "_".join(re.findall("[A-Z][^A-Z]*", name)).lower() def camelize(name): @@ -529,17 +482,16 @@ def make_tablename(classname): def wrap_scope_on_set(dbset, model_instance, scope): @wraps(scope) def wrapped(*args, **kwargs): - return dbset.where( - scope(model_instance, *args, **kwargs), - model=model_instance.__class__) + return dbset.where(scope(model_instance, *args, **kwargs), model=model_instance.__class__) + return wrapped def wrap_scope_on_model(scope): @wraps(scope) def wrapped(cls, *args, **kwargs): - return cls.db.where( - scope(cls._instance_(), *args, **kwargs), model=cls) + return cls.db.where(scope(cls._instance_(), *args, **kwargs), model=cls) + return wrapped @@ -547,27 +499,22 @@ def wrap_virtual_on_model(model, virtual): @wraps(virtual) def wrapped(row, *args, **kwargs): return virtual(model, row, *args, **kwargs) + return wrapped def typed_row_reference(id: Any, table: Table): field_type = table._id.type if table._id else None - return { - 'id': RowReferenceInt, - 'integer': RowReferenceInt, - 'string': RowReferenceStr, - None: RowReferenceMulti - }[field_type](id, table) + return {"id": RowReferenceInt, "integer": RowReferenceInt, "string": RowReferenceStr, None: RowReferenceMulti}[ + field_type + ](id, table) def typed_row_reference_from_record(record: Any, model: Any): field_type = model.table._id.type if model.table._id else None - refcls = { - 'id': RowReferenceInt, - 'integer': RowReferenceInt, - 'string': RowReferenceStr, - None: RowReferenceMulti - }[field_type] + refcls = {"id": RowReferenceInt, "integer": RowReferenceInt, "string": RowReferenceStr, None: RowReferenceMulti}[ + field_type + ] if len(model._fieldset_pk) > 1: id = {pk: record[pk] for pk in model._fieldset_pk} else: @@ -578,7 +525,7 @@ def typed_row_reference_from_record(record: Any, model: Any): def _rowref_pickler(obj): - return obj._refmeta.caster, (obj.__pure__(), ) + return obj._refmeta.caster, (obj.__pure__(),) copyreg.pickle(RowReferenceInt, _rowref_pickler) diff --git a/emmett/orm/migrations/__init__.py b/emmett/orm/migrations/__init__.py index 5b9db84a..87321f19 100644 --- a/emmett/orm/migrations/__init__.py +++ b/emmett/orm/migrations/__init__.py @@ -1,2 +1,2 @@ -from .base import Migration, Column +from .base import Column, Migration from .operations import * diff --git a/emmett/orm/migrations/base.py b/emmett/orm/migrations/base.py index 7c0d6e15..7178d4a3 100644 --- a/emmett/orm/migrations/base.py +++ b/emmett/orm/migrations/base.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.base - -------------------------- +emmett.orm.migrations.base +-------------------------- - Provides base migrations objects. +Provides base migrations objects. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -14,10 +14,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Type from ...datastructures import sdict -from .. import Database, Model, Field -from .engine import MetaEngine, Engine +from .. import Database, Field, Model +from .engine import Engine, MetaEngine from .helpers import WrappedOperation, _feasible_as_dbms_default + if TYPE_CHECKING: from .operations import Operation @@ -32,13 +33,11 @@ class Migration: skip_on_compare: bool = False @classmethod - def register_operation( - cls, - name: str - ) -> Callable[[Type[Operation]], Type[Operation]]: + def register_operation(cls, name: str) -> Callable[[Type[Operation]], Type[Operation]]: def wrap(op_cls: Type[Operation]) -> Type[Operation]: cls._registered_ops_[name] = op_cls return op_cls + return wrap def __init__(self, app: Any, db: Database, is_meta: bool = False): @@ -56,14 +55,7 @@ def __getattr__(self, name: str) -> WrappedOperation: class Column(sdict): - def __init__( - self, - name: str, - type: str = 'string', - unique: bool = False, - notnull: bool = False, - **kwargs: Any - ): + def __init__(self, name: str, type: str = "string", unique: bool = False, notnull: bool = False, **kwargs: Any): self.name = name self.type = type self.unique = unique @@ -76,7 +68,7 @@ def _fk_type(self, db: Database, tablename: str): if self.name not in db[tablename]._model_._belongs_ref_: return ref = db[tablename]._model_._belongs_ref_[self.name] - if ref.ftype != 'id': + if ref.ftype != "id": self.type = ref.ftype self.length = db[ref.model][ref.fk].length self.on_delete = None @@ -90,7 +82,7 @@ def from_field(cls, field: Field) -> Column: field.notnull, length=field.length, ondelete=field.ondelete, - **field._ormkw + **field._ormkw, ) if _feasible_as_dbms_default(field.default): rv.default = field.default @@ -98,7 +90,4 @@ def from_field(cls, field: Field) -> Column: return rv def __repr__(self) -> str: - return "%s(%s)" % ( - self.__class__.__name__, - ", ".join(["%s=%r" % (k, v) for k, v in self.items()]) - ) + return "%s(%s)" % (self.__class__.__name__, ", ".join(["%s=%r" % (k, v) for k, v in self.items()])) diff --git a/emmett/orm/migrations/commands.py b/emmett/orm/migrations/commands.py index 6d055ecb..64019389 100644 --- a/emmett/orm/migrations/commands.py +++ b/emmett/orm/migrations/commands.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.commands - ------------------------------ +emmett.orm.migrations.commands +------------------------------ - Provides command interfaces for migrations. +Provides command interfaces for migrations. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -16,9 +16,9 @@ import click from ...datastructures import sdict -from .base import Database, Schema, Column +from .base import Column, Database, Schema from .helpers import DryRunDatabase, make_migration_id, to_tuple -from .operations import MigrationOp, UpgradeOps, DowngradeOps +from .operations import DowngradeOps, MigrationOp, UpgradeOps from .scripts import ScriptDir @@ -30,14 +30,7 @@ def __init__(self, app: Any, dals: List[Database]): def _load_envs(self, dals): for dal in dals: - self.envs.append( - sdict( - db=dal, - scriptdir=ScriptDir( - self.app, dal.config.migrations_folder - ) - ) - ) + self.envs.append(sdict(db=dal, scriptdir=ScriptDir(self.app, dal.config.migrations_folder))) def load_schema(self, ctx): ctx.db.define_models(Schema) @@ -53,6 +46,7 @@ def _ensure_schema_table_(self, ctx): ctx.db.rollback() from .engine import Engine from .operations import CreateTableOp + op = CreateTableOp.from_table(self._build_schema_metatable_(ctx)) op.engine = Engine(ctx.db) op.run() @@ -61,13 +55,11 @@ def _ensure_schema_table_(self, ctx): @staticmethod def _build_schema_metatable_(ctx): from .generation import MetaTable + columns = [] for field in list(ctx.db.Schema): columns.append(Column.from_field(field)) - return MetaTable( - ctx.db.Schema._tablename, - columns - ) + return MetaTable(ctx.db.Schema._tablename, columns) @staticmethod def _load_current_revision_(ctx): @@ -81,226 +73,152 @@ def _load_current_revision_(ctx): @staticmethod def _log_store_new(revid): - click.echo( - " ".join([ - "> Adding revision", - click.style(revid, fg="cyan", bold=True), - "to schema" - ]) - ) + click.echo(" ".join(["> Adding revision", click.style(revid, fg="cyan", bold=True), "to schema"])) @staticmethod def _log_store_del(revid): - click.echo( - " ".join([ - "> Removing revision", - click.style(revid, fg="cyan", bold=True), - "from schema" - ]) - ) + click.echo(" ".join(["> Removing revision", click.style(revid, fg="cyan", bold=True), "from schema"])) @staticmethod def _log_store_upd(revid_src, revid_dst): click.echo( - " ".join([ - "> Updating schema revision from", - click.style(revid_src, fg="cyan", bold=True), - "to", - click.style(revid_dst, fg="cyan", bold=True), - ]) + " ".join( + [ + "> Updating schema revision from", + click.style(revid_src, fg="cyan", bold=True), + "to", + click.style(revid_dst, fg="cyan", bold=True), + ] + ) ) @staticmethod def _log_dry_run(msg): - click.secho(msg, fg='yellow') + click.secho(msg, fg="yellow") def _store_current_revision_(self, ctx, source, dest): - _store_logs = { - 'new': self._log_store_new, - 'del': self._log_store_del, - 'upd': self._log_store_upd - } + _store_logs = {"new": self._log_store_new, "del": self._log_store_del, "upd": self._log_store_upd} source = to_tuple(source) dest = to_tuple(dest) if not source and dest: - _store_logs['new'](dest[0]) + _store_logs["new"](dest[0]) ctx.db.Schema.insert(version=dest[0]) ctx.db.commit() ctx._current_revision_ = [dest[0]] return if not dest and source: - _store_logs['del'](source[0]) + _store_logs["del"](source[0]) ctx.db(ctx.db.Schema.version == source[0]).delete() ctx.db.commit() ctx._current_revision_ = [] return if len(source) > 1: if len(source) > 2: - ctx.db( - ctx.db.Schema.version.belongs( - source[1:])).delete() - _store_logs['del'](source[1:]) + ctx.db(ctx.db.Schema.version.belongs(source[1:])).delete() + _store_logs["del"](source[1:]) else: - ctx.db( - ctx.db.Schema.version == source[1]).delete() - _store_logs['del'](source[1]) - ctx.db(ctx.db.Schema.version == source[0]).update( - version=dest[0] - ) - _store_logs['upd'](source[0], dest[0]) + ctx.db(ctx.db.Schema.version == source[1]).delete() + _store_logs["del"](source[1]) + ctx.db(ctx.db.Schema.version == source[0]).update(version=dest[0]) + _store_logs["upd"](source[0], dest[0]) ctx._current_revision_ = [dest[0]] else: if list(source) != ctx._current_revision_: ctx.db.Schema.insert(version=dest[0]) - _store_logs['new'](dest[0]) + _store_logs["new"](dest[0]) ctx._current_revision_.append(dest[0]) else: - ctx.db( - ctx.db.Schema.version == source[0] - ).update( - version=dest[0] - ) - _store_logs['upd'](source[0], dest[0]) + ctx.db(ctx.db.Schema.version == source[0]).update(version=dest[0]) + _store_logs["upd"](source[0], dest[0]) ctx._current_revision_ = [dest[0]] ctx.db.commit() @staticmethod def _generate_migration_script(ctx, migration, head): from .generation import Renderer + upgrades, downgrades = Renderer.render_migration(migration) ctx.scriptdir.generate_revision( - migration.rev_id, migration.message, head, upgrades=upgrades, - downgrades=downgrades + migration.rev_id, migration.message, head, upgrades=upgrades, downgrades=downgrades ) def generate(self, message, head): from .generation import Generator + for ctx in self.envs: upgrade_ops = Generator.generate_from(ctx.db, ctx.scriptdir, head) revid = make_migration_id() - migration = MigrationOp( - revid, upgrade_ops, upgrade_ops.reverse(), message - ) + migration = MigrationOp(revid, upgrade_ops, upgrade_ops.reverse(), message) self._generate_migration_script(ctx, migration, head) - click.echo( - " ".join([ - "> Generated migration for revision", - click.style(revid, fg="cyan", bold=True) - ]) - ) + click.echo(" ".join(["> Generated migration for revision", click.style(revid, fg="cyan", bold=True)])) def new(self, message, head): for ctx in self.envs: source_rev = ctx.scriptdir.get_revision(head) revid = make_migration_id() - migration = MigrationOp( - revid, UpgradeOps(), DowngradeOps(), message - ) - self._generate_migration_script( - ctx, migration, source_rev.revision - ) - click.echo( - " ".join([ - "> Created new migration for revision", - click.style(revid, fg="cyan", bold=True) - ]) - ) + migration = MigrationOp(revid, UpgradeOps(), DowngradeOps(), message) + self._generate_migration_script(ctx, migration, source_rev.revision) + click.echo(" ".join(["> Created new migration for revision", click.style(revid, fg="cyan", bold=True)])) def history(self, base, head, verbose): for ctx in self.envs: click.echo("> Migrations history:") lines = [] - for sc in ctx.scriptdir.walk_revisions( - base=base or "base", - head=head or "heads" - ): - lines.append( - sc.cmd_format( - verbose=verbose, include_doc=True, include_parents=True - ) - ) + for sc in ctx.scriptdir.walk_revisions(base=base or "base", head=head or "heads"): + lines.append(sc.cmd_format(verbose=verbose, include_doc=True, include_parents=True)) for line in lines: click.echo(line) if not lines: - click.secho( - "No migrations for the selected application.", fg="yellow" - ) + click.secho("No migrations for the selected application.", fg="yellow") def status(self, verbose): for ctx in self.envs: self.load_schema(ctx) - click.echo( - " ".join([ - "> Current revision(s) for", - click.style(ctx.db._uri, bold=True) - ]) - ) + click.echo(" ".join(["> Current revision(s) for", click.style(ctx.db._uri, bold=True)])) lines = [] for rev in ctx.scriptdir.get_revisions(ctx._current_revision_): lines.append(rev.cmd_format(verbose)) for line in lines: click.echo(line) if not lines: - click.secho( - "No revision state found on the schema.", fg="yellow" - ) + click.secho("No revision state found on the schema.", fg="yellow") def up(self, rev_id, dry_run=False): log_verb = "Previewing" if dry_run else "Performing" for ctx in self.envs: self.load_schema(ctx) start_point = ctx._current_revision_ - revisions = ctx.scriptdir.get_upgrade_revs( - rev_id, start_point - ) - click.echo( - " ".join([ - f"> {log_verb} upgrades against", - click.style(ctx.db._uri, bold=True) - ]) - ) - db = ( - DryRunDatabase(ctx.db, self._log_dry_run) if dry_run else - ctx.db - ) + revisions = ctx.scriptdir.get_upgrade_revs(rev_id, start_point) + click.echo(" ".join([f"> {log_verb} upgrades against", click.style(ctx.db._uri, bold=True)])) + db = DryRunDatabase(ctx.db, self._log_dry_run) if dry_run else ctx.db with db.connection(): for revision in revisions: - click.echo( - " ".join([ - f"> {log_verb} upgrade:", - click.style(str(revision), fg="cyan", bold=True) - ]) - ) + click.echo(" ".join([f"> {log_verb} upgrade:", click.style(str(revision), fg="cyan", bold=True)])) migration = revision.migration_class(self.app, db) try: migration.up() db.commit() if dry_run: continue - self._store_current_revision_( - ctx, migration.revises, migration.revision - ) + self._store_current_revision_(ctx, migration.revises, migration.revision) click.echo( - "".join([ - click.style( - "> Succesfully upgraded to revision ", - fg="green" - ), - click.style( - revision.revision, fg="cyan", bold=True - ), - click.style(f": {revision.doc}", fg="green") - ]) + "".join( + [ + click.style("> Succesfully upgraded to revision ", fg="green"), + click.style(revision.revision, fg="cyan", bold=True), + click.style(f": {revision.doc}", fg="green"), + ] + ) ) except Exception: db.rollback() click.echo( - " ".join([ - click.style("> Failed upgrading to", fg="red"), - click.style( - revision.revision, fg="red", bold=True - ), - ]) + " ".join( + [ + click.style("> Failed upgrading to", fg="red"), + click.style(revision.revision, fg="red", bold=True), + ] + ) ) raise @@ -309,58 +227,37 @@ def down(self, rev_id, dry_run=False): for ctx in self.envs: self.load_schema(ctx) start_point = ctx._current_revision_ - revisions = ctx.scriptdir.get_downgrade_revs( - rev_id, start_point) - click.echo( - " ".join([ - f"> {log_verb} downgrades against", - click.style(ctx.db._uri, bold=True) - ]) - ) - db = ( - DryRunDatabase(ctx.db, self._log_dry_run) if dry_run else - ctx.db - ) + revisions = ctx.scriptdir.get_downgrade_revs(rev_id, start_point) + click.echo(" ".join([f"> {log_verb} downgrades against", click.style(ctx.db._uri, bold=True)])) + db = DryRunDatabase(ctx.db, self._log_dry_run) if dry_run else ctx.db with db.connection(): for revision in revisions: - click.echo( - " ".join([ - f"> {log_verb} downgrade:", - click.style(str(revision), fg="cyan", bold=True) - ]) - ) + click.echo(" ".join([f"> {log_verb} downgrade:", click.style(str(revision), fg="cyan", bold=True)])) migration = revision.migration_class(self.app, db) try: migration.down() db.commit() if dry_run: continue - self._store_current_revision_( - ctx, migration.revision, migration.revises - ) + self._store_current_revision_(ctx, migration.revision, migration.revises) click.echo( - "".join([ - click.style( - "> Succesfully downgraded from revision ", - fg="green" - ), - click.style( - revision.revision, fg="cyan", bold=True - ), - click.style(f": {revision.doc}", fg="green") - ]) + "".join( + [ + click.style("> Succesfully downgraded from revision ", fg="green"), + click.style(revision.revision, fg="cyan", bold=True), + click.style(f": {revision.doc}", fg="green"), + ] + ) ) except Exception: db.rollback() click.echo( - " ".join([ - click.style( - "> Failed downgrading from", fg="red" - ), - click.style( - revision.revision, fg="red", bold=True - ), - ]) + " ".join( + [ + click.style("> Failed downgrading from", fg="red"), + click.style(revision.revision, fg="red", bold=True), + ] + ) ) raise @@ -373,32 +270,29 @@ def set(self, rev_id, auto_confirm=False): click.secho("> No matching revision found", fg="red") return click.echo( - " ".join([ - click.style("> Setting revision to", fg="yellow"), - click.style(target_revision.revision, bold=True, fg="yellow"), - click.style("against", fg="yellow"), - click.style(ctx.db._uri, bold=True, fg="yellow") - ]) + " ".join( + [ + click.style("> Setting revision to", fg="yellow"), + click.style(target_revision.revision, bold=True, fg="yellow"), + click.style("against", fg="yellow"), + click.style(ctx.db._uri, bold=True, fg="yellow"), + ] + ) ) if not auto_confirm: if not click.confirm("Do you want to continue?"): click.echo("Aborting") return with ctx.db.connection(): - self._store_current_revision_( - ctx, current_revision, target_revision.revision - ) + self._store_current_revision_(ctx, current_revision, target_revision.revision) click.echo( - "".join([ - click.style( - "> Succesfully set revision to ", - fg="green" - ), - click.style( - target_revision.revision, fg="cyan", bold=True - ), - click.style(f": {target_revision.doc}", fg="green") - ]) + "".join( + [ + click.style("> Succesfully set revision to ", fg="green"), + click.style(target_revision.revision, fg="cyan", bold=True), + click.style(f": {target_revision.doc}", fg="green"), + ] + ) ) @@ -413,9 +307,7 @@ def new(app, dals, message, head): def history(app, dals, rev_range, verbose): if rev_range is not None: if ":" not in rev_range: - raise Exception( - "History range requires [start]:[end], " - "[start]:, or :[end]") + raise Exception("History range requires [start]:[end], " "[start]:, or :[end]") base, head = rev_range.strip().split(":") else: base = head = None diff --git a/emmett/orm/migrations/engine.py b/emmett/orm/migrations/engine.py index c52f2df4..565e0704 100644 --- a/emmett/orm/migrations/engine.py +++ b/emmett/orm/migrations/engine.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.engine - ---------------------------- +emmett.orm.migrations.engine +---------------------------- - Provides migration engine for pyDAL. +Provides migration engine for pyDAL. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -15,6 +15,7 @@ from ...datastructures import sdict + if TYPE_CHECKING: from .base import Column, Database from .generation import MetaData @@ -48,40 +49,27 @@ def drop_index(self, name, table_name): self.db.drop_index(table_name, name) def create_foreign_key_constraint( - self, - name, - table_name, - column_names, - foreign_table_name, - foreign_keys, - on_delete + self, name, table_name, column_names, foreign_table_name, foreign_keys, on_delete ): self.db.create_foreign_key_constraint( - table_name, - name, - column_names, - foreign_table_name, - foreign_keys, - on_delete + table_name, name, column_names, foreign_table_name, foreign_keys, on_delete ) def drop_foreign_key_constraint(self, name, table_name): self.db.drop_foreign_key_constraint(table_name, name) @staticmethod - def _parse_column_changes( - changes: List[Tuple[str, str, str, Dict[str, Any], Any, Any]] - ) -> Dict[str, List[Any]]: + def _parse_column_changes(changes: List[Tuple[str, str, str, Dict[str, Any], Any, Any]]) -> Dict[str, List[Any]]: rv = {} for change in changes: if change[0] == "modify_type": - rv['type'] = [change[4], change[5], change[3]['existing_length']] + rv["type"] = [change[4], change[5], change[3]["existing_length"]] elif change[0] == "modify_length": - rv['length'] = [change[4], change[5], change[3]['existing_type']] + rv["length"] = [change[4], change[5], change[3]["existing_type"]] elif change[0] == "modify_notnull": - rv['notnull'] = [change[4], change[5]] + rv["notnull"] = [change[4], change[5]] elif change[0] == "modify_default": - rv['default'] = [change[4], change[5], change[3]['existing_type']] + rv["default"] = [change[4], change[5], change[3]["existing_type"]] else: rv[change[0].split("modify_")[-1]] = [change[4], change[5], change[3]] return rv @@ -110,7 +98,7 @@ def create_table(self, name, columns, primary_keys, **kwargs): def drop_table(self, name): adapt_v = sdict(_rname=self.dialect.quote(name)) - sql_list = self.dialect.drop_table(adapt_v, 'cascade') + sql_list = self.dialect.drop_table(adapt_v, "cascade") for sql in sql_list: self._log_and_exec(sql) @@ -123,11 +111,7 @@ def drop_column(self, tablename, colname): self._log_and_exec(sql) def alter_column(self, table_name, column_name, changes): - sql = self._alter_column_sql( - table_name, - column_name, - self._parse_column_changes(changes) - ) + sql = self._alter_column_sql(table_name, column_name, self._parse_column_changes(changes)) if sql is not None: self._log_and_exec(sql) @@ -150,7 +134,7 @@ def create_foreign_key_constraint( column_names: List[str], foreign_table_name: str, foreign_keys: List[str], - on_delete: str + on_delete: str, ): sql = self.dialect.add_foreign_key_constraint( name, table_name, foreign_table_name, column_names, foreign_keys, on_delete @@ -165,104 +149,74 @@ def _gen_reference(self, tablename, column): referenced = column.type[10:].strip() constraint_name = self.dialect.constraint_name(tablename, column.name) try: - rtablename, rfieldname = referenced.split('.') + rtablename, rfieldname = referenced.split(".") except Exception: rtablename = referenced - rfieldname = 'id' + rfieldname = "id" if not rtablename: rtablename = tablename - csql_info = dict( - index_name=self.dialect.quote(column.name + '__idx'), - field_name=self.dialect.quote(column.name), - constraint_name=self.dialect.quote(constraint_name), - foreign_key='%s (%s)' % ( - self.dialect.quote(rtablename), - self.dialect.quote(rfieldname) - ), - on_delete_action=column.ondelete) - csql_info['null'] = ( - ' NOT NULL' if column.notnull else - self.dialect.allow_null - ) - csql_info['unique'] = ' UNIQUE' if column.unique else '' - csql = self.adapter.types['reference'] % csql_info + csql_info = { + "index_name": self.dialect.quote(column.name + "__idx"), + "field_name": self.dialect.quote(column.name), + "constraint_name": self.dialect.quote(constraint_name), + "foreign_key": "%s (%s)" % (self.dialect.quote(rtablename), self.dialect.quote(rfieldname)), + "on_delete_action": column.ondelete, + } + csql_info["null"] = " NOT NULL" if column.notnull else self.dialect.allow_null + csql_info["unique"] = " UNIQUE" if column.unique else "" + csql = self.adapter.types["reference"] % csql_info return csql def _gen_primary_key(self, fields, primary_keys=[]): if primary_keys: - fields.append( - self.dialect.primary_key( - ', '.join([self.dialect.quote(pk) for pk in primary_keys]) - ) - ) + fields.append(self.dialect.primary_key(", ".join([self.dialect.quote(pk) for pk in primary_keys]))) def _gen_geo(self, column_type, geometry_type, srid, dimension): - if not hasattr(self.adapter, 'srid'): - raise RuntimeError('Adapter does not support geometry') + if not hasattr(self.adapter, "srid"): + raise RuntimeError("Adapter does not support geometry") if column_type not in self.adapter.types: - raise SyntaxError( - f'Field: unknown field type: {column_type}' - ) + raise SyntaxError(f"Field: unknown field type: {column_type}") return "{ctype}({gtype},{srid},{dimension})".format( ctype=self.adapter.types[column_type], gtype=geometry_type, srid=srid or self.adapter.srid, - dimension=dimension or 2 + dimension=dimension or 2, ) - def _new_column_sql( - self, - tablename: str, - column: Column, - primary_key: bool = False - ) -> str: - if column.type.startswith('reference'): + def _new_column_sql(self, tablename: str, column: Column, primary_key: bool = False) -> str: + if column.type.startswith("reference"): csql = self._gen_reference(tablename, column) - elif column.type.startswith('list:reference'): + elif column.type.startswith("list:reference"): csql = self.adapter.types[column.type[:14]] - elif column.type.startswith('decimal'): - precision, scale = map(int, column.type[8:-1].split(',')) - csql = self.adapter.types[column.type[:7]] % dict( - precision=precision, scale=scale - ) - elif column.type.startswith('geo'): - csql = self._gen_geo( - column.type, - column.geometry_type, - column.srid, - column.dimension - ) + elif column.type.startswith("decimal"): + precision, scale = map(int, column.type[8:-1].split(",")) + csql = self.adapter.types[column.type[:7]] % {"precision": precision, "scale": scale} + elif column.type.startswith("geo"): + csql = self._gen_geo(column.type, column.geometry_type, column.srid, column.dimension) elif column.type not in self.adapter.types: - raise SyntaxError( - f'Field: unknown field type: {column.type} for {column.nmae}' - ) + raise SyntaxError(f"Field: unknown field type: {column.type} for {column.nmae}") else: - csql = self.adapter.types[column.type] % {'length': column.length} - if self.adapter.dbengine not in ('firebird', 'informix', 'oracle'): + csql = self.adapter.types[column.type] % {"length": column.length} + if self.adapter.dbengine not in ("firebird", "informix", "oracle"): cprops = "%(notnull)s%(default)s%(unique)s%(pk)s%(qualifier)s" else: cprops = "%(default)s%(notnull)s%(unique)s%(pk)s%(qualifier)s" - if not column.type.startswith(('id', 'reference')): + if not column.type.startswith(("id", "reference")): csql += cprops % { - 'notnull': ' NOT NULL' if column.notnull else self.dialect.allow_null, - 'default': ( - ' DEFAULT %s' % self.adapter.represent(column.default, column.type) - if column.default is not None else '' + "notnull": " NOT NULL" if column.notnull else self.dialect.allow_null, + "default": ( + " DEFAULT %s" % self.adapter.represent(column.default, column.type) + if column.default is not None + else "" ), - 'unique': ' UNIQUE' if column.unique else '', - 'pk': ' PRIMARY KEY' if primary_key else '', - 'qualifier': ( - ' %s' % column.custom_qualifier if column.custom_qualifier else '' - ) + "unique": " UNIQUE" if column.unique else "", + "pk": " PRIMARY KEY" if primary_key else "", + "qualifier": (" %s" % column.custom_qualifier if column.custom_qualifier else ""), } return csql def _new_table_sql( - self, - tablename: str, - columns: List[Column], - primary_keys: List[str] = [], - id_col: str ='id' + self, tablename: str, columns: List[Column], primary_keys: List[str] = [], id_col: str = "id" ) -> str: # TODO: # - SQLCustomType @@ -270,39 +224,34 @@ def _new_table_sql( fields = [] for column in columns: csql = self._new_column_sql( - tablename, - column, - primary_key=( - column.name in primary_keys if not composed_primary_key else False - ) + tablename, column, primary_key=(column.name in primary_keys if not composed_primary_key else False) ) - fields.append('%s %s' % (self.dialect.quote(column.name), csql)) + fields.append("%s %s" % (self.dialect.quote(column.name), csql)) # backend-specific extensions to fields - if self.adapter.dbengine == 'mysql': + if self.adapter.dbengine == "mysql": if not primary_keys: primary_keys.append(id_col) elif not composed_primary_key: primary_keys.clear() self._gen_primary_key(fields, primary_keys) - fields = ',\n '.join(fields) + fields = ",\n ".join(fields) return self.dialect.create_table(tablename, fields) def _add_column_sql(self, tablename, column): csql = self._new_column_sql(tablename, column) - return 'ALTER TABLE %(tname)s ADD %(cname)s %(sql)s;' % { - 'tname': self.dialect.quote(tablename), - 'cname': self.dialect.quote(column.name), - 'sql': csql + return "ALTER TABLE %(tname)s ADD %(cname)s %(sql)s;" % { + "tname": self.dialect.quote(tablename), + "cname": self.dialect.quote(column.name), + "sql": csql, } def _drop_column_sql(self, table_name, column_name): if self.adapter.dbengine == "firebird": - sql = 'ALTER TABLE %s DROP %s;' + sql = "ALTER TABLE %s DROP %s;" else: - sql = 'ALTER TABLE %s DROP COLUMN %s;' - return sql % ( - self.dialect.quote(table_name), self.dialect.quote(column_name)) + sql = "ALTER TABLE %s DROP COLUMN %s;" + return sql % (self.dialect.quote(table_name), self.dialect.quote(column_name)) def _represent_changes(self, changes, field): geo_attrs = ("geometry_type", "srid", "dimension") @@ -311,62 +260,46 @@ def _represent_changes(self, changes, field): geo_changes[key] = changes.pop(key) geo_data.update(geo_changes[key][3]) - if 'default' in changes and changes['default'][1] is not None: - ftype = changes['default'][2] or field.type - if 'type' in changes: - ftype = changes['type'][1] - changes['default'][1] = self.adapter.represent( - changes['default'][1], ftype) - if 'type' in changes: - changes.pop('length', None) - coltype = changes['type'][1] - if coltype.startswith('reference'): - raise NotImplementedError( - 'Type change on reference fields is not supported.' - ) - elif coltype.startswith('decimal'): - precision, scale = map(int, coltype[8:-1].split(',')) - csql = self.adapter.types[coltype[:7]] % \ - dict(precision=precision, scale=scale) - elif coltype.startswith('geo'): + if "default" in changes and changes["default"][1] is not None: + ftype = changes["default"][2] or field.type + if "type" in changes: + ftype = changes["type"][1] + changes["default"][1] = self.adapter.represent(changes["default"][1], ftype) + if "type" in changes: + changes.pop("length", None) + coltype = changes["type"][1] + if coltype.startswith("reference"): + raise NotImplementedError("Type change on reference fields is not supported.") + elif coltype.startswith("decimal"): + precision, scale = map(int, coltype[8:-1].split(",")) + csql = self.adapter.types[coltype[:7]] % {"precision": precision, "scale": scale} + elif coltype.startswith("geo"): gen_attrs = [] for key in geo_attrs: - val = ( - geo_changes.get(f"{key}", (None, None))[1] or - geo_data[f"existing_{key}"] - ) + val = geo_changes.get(f"{key}", (None, None))[1] or geo_data[f"existing_{key}"] gen_attrs.append(val) csql = self._gen_geo(coltype, gen_attrs) else: - csql = self.adapter.types[coltype] % { - 'length': changes['type'][2] or field.length - } - changes['type'][1] = csql - elif 'length' in changes: - change = changes.pop('length') + csql = self.adapter.types[coltype] % {"length": changes["type"][2] or field.length} + changes["type"][1] = csql + elif "length" in changes: + change = changes.pop("length") ftype = change[2] or field.type - changes['type'] = [None, self.adapter.types[ftype] % {'length': change[1]}] + changes["type"] = [None, self.adapter.types[ftype] % {"length": change[1]}] elif geo_changes: coltype = geo_data["existing_type"] or field.type gen_attrs = [] for key in geo_attrs: - val = ( - geo_changes.get(f"{key}", (None, None))[1] or - geo_data[f"existing_{key}"] - ) + val = geo_changes.get(f"{key}", (None, None))[1] or geo_data[f"existing_{key}"] gen_attrs.append(val) - changes['type'] = [None, self._gen_geo(coltype, *gen_attrs)] - + changes["type"] = [None, self._gen_geo(coltype, *gen_attrs)] def _alter_column_sql(self, table_name, column_name, changes): - sql = 'ALTER TABLE %(tname)s ALTER COLUMN %(cname)s %(changes)s;' + sql = "ALTER TABLE %(tname)s ALTER COLUMN %(cname)s %(changes)s;" sql_changes_map = { - 'type': "%s" if self.adapter.dbengine in ["mysql", "mssql"] else "TYPE %s", - 'notnull': { - True: "SET NOT NULL", - False: "DROP NOT NULL" - }, - 'default': ["SET DEFAULT %s", "DROP DEFAULT"] + "type": "%s" if self.adapter.dbengine in ["mysql", "mssql"] else "TYPE %s", + "notnull": {True: "SET NOT NULL", False: "DROP NOT NULL"}, + "default": ["SET DEFAULT %s", "DROP DEFAULT"], } field = self.db[table_name][column_name] self._represent_changes(changes, field) @@ -376,16 +309,13 @@ def _alter_column_sql(self, table_name, column_name, changes): if isinstance(change_sql, dict): sql_changes.append(change_sql[change_val[1]]) elif isinstance(change_sql, list): - sql_changes.append( - change_sql[0] % change_val[1] if change_val[1] is not None else - change_sql[1] - ) + sql_changes.append(change_sql[0] % change_val[1] if change_val[1] is not None else change_sql[1]) else: sql_changes.append(change_sql % change_val[1]) if not sql_changes: return None return sql % { - 'tname': self.dialect.quote(table_name), - 'cname': self.dialect.quote(column_name), - 'changes': " ".join(sql_changes) + "tname": self.dialect.quote(table_name), + "cname": self.dialect.quote(column_name), + "changes": " ".join(sql_changes), } diff --git a/emmett/orm/migrations/exceptions.py b/emmett/orm/migrations/exceptions.py index fc1d3198..eb86e00c 100644 --- a/emmett/orm/migrations/exceptions.py +++ b/emmett/orm/migrations/exceptions.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.exceptions - -------------------------------- +emmett.orm.migrations.exceptions +-------------------------------- - Provides exceptions for migration operations. +Provides exceptions for migration operations. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ @@ -19,8 +19,7 @@ def __init__(self, lower, upper): self.lower = lower self.upper = upper super(RangeNotAncestorError, self).__init__( - "Revision %s is not an ancestor of revision %s" % - (lower or "base", upper or "base") + "Revision %s is not an ancestor of revision %s" % (lower or "base", upper or "base") ) @@ -29,8 +28,7 @@ def __init__(self, heads, argument): self.heads = heads self.argument = argument super(MultipleHeads, self).__init__( - "Multiple heads are present for given argument '%s'; " - "%s" % (argument, ", ".join(heads)) + "Multiple heads are present for given argument '%s'; " "%s" % (argument, ", ".join(heads)) ) diff --git a/emmett/orm/migrations/generation.py b/emmett/orm/migrations/generation.py index 59ce6467..2ce14a9a 100644 --- a/emmett/orm/migrations/generation.py +++ b/emmett/orm/migrations/generation.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.generation - -------------------------------- +emmett.orm.migrations.generation +-------------------------------- - Provides generation utils for migrations. +Provides generation utils for migrations. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) - :copyright: (c) 2009-2015 by Michael Bayer +Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) +:copyright: (c) 2009-2015 by Michael Bayer - :license: BSD-3-Clause +:license: BSD-3-Clause """ from __future__ import annotations @@ -22,7 +22,7 @@ from ...datastructures import OrderedSet from ..objects import Rows, Table from .base import Column, Database -from .helpers import Dispatcher, DEFAULT_VALUE +from .helpers import DEFAULT_VALUE, Dispatcher from .operations import ( AddColumnOp, AlterColumnOp, @@ -36,19 +36,13 @@ MigrationOp, OpContainer, Operation, - UpgradeOps + UpgradeOps, ) from .scripts import ScriptDir class MetaTable: - def __init__( - self, - name: str, - columns: List[Column] = [], - primary_keys: List[str] = [], - **kw: Any - ): + def __init__(self, name: str, columns: List[Column] = [], primary_keys: List[str] = [], **kw: Any): self.name = name self.columns = OrderedDict() for column in columns: @@ -72,25 +66,14 @@ def __delitem__(self, name: str): del self.columns[name] def __repr__(self) -> str: - return "Table(%r, %s)" % ( - self.name, - ", ".join(["%s" % column for column in self.columns.values()]) - ) + return "Table(%r, %s)" % (self.name, ", ".join(["%s" % column for column in self.columns.values()])) def insert(self, *args, **kwargs) -> Any: return None class MetaIndex: - def __init__( - self, - table_name: str, - name: str, - fields: List[str], - expressions: List[str], - unique: bool, - **kw: Any - ): + def __init__(self, table_name: str, name: str, fields: List[str], expressions: List[str], unique: bool, **kw: Any): self.table_name = table_name self.name = name self.fields = fields @@ -100,16 +83,13 @@ def __init__( @property def where(self) -> Optional[str]: - return self.kw.get('where') + return self.kw.get("where") def __repr__(self) -> str: - opts = [('expressions', self.expressions), ('unique', self.unique)] + opts = [("expressions", self.expressions), ("unique", self.unique)] for key, val in self.kw.items(): opts.append((key, val)) - return "Index(%r, %r, %s)" % ( - self.name, self.fields, - ", ".join(["%s=%r" % (opt[0], opt[1]) for opt in opts]) - ) + return "Index(%r, %r, %s)" % (self.name, self.fields, ", ".join(["%s=%r" % (opt[0], opt[1]) for opt in opts])) class MetaForeignKey: @@ -121,7 +101,7 @@ def __init__( foreign_table_name: str, foreign_keys: List[str], on_delete: str, - **kw + **kw, ): self.table_name = table_name self.name = name @@ -150,7 +130,7 @@ def __repr__(self) -> str: self.foreign_table_name, self.column_names, self.foreign_keys, - self.on_delete + self.on_delete, ) @@ -189,13 +169,7 @@ def where(self, *args, **kwargs) -> MetaDataSet: def __call__(self, *args, **kwargs): return self.where(*args, **kwargs) - def create_table( - self, - name: str, - columns: List[Column], - primary_keys: List[str], - **kw: Any - ): + def create_table(self, name: str, columns: List[Column], primary_keys: List[str], **kw: Any): self.tables[name] = MetaTable(name, columns, primary_keys, **kw) def drop_table(self, name: str): @@ -211,13 +185,7 @@ def change_column(self, table_name: str, column_name: str, changes: Dict[str, An self.tables[table_name][column_name].update(**changes) def create_index( - self, - table_name: str, - index_name: str, - fields: List[str], - expressions: List[str], - unique: bool, - **kw: Any + self, table_name: str, index_name: str, fields: List[str], expressions: List[str], unique: bool, **kw: Any ): self.tables[table_name].indexes[index_name] = MetaIndex( table_name, index_name, fields, expressions, unique, **kw @@ -233,15 +201,10 @@ def create_foreign_key_constraint( column_names: List[str], foreign_table_name: str, foreign_keys: List[str], - on_delete: str + on_delete: str, ): self.tables[table_name].foreign_keys[constraint_name] = MetaForeignKey( - table_name, - constraint_name, - column_names, - foreign_table_name, - foreign_keys, - on_delete + table_name, constraint_name, column_names, foreign_table_name, foreign_keys, on_delete ) def drop_foreign_key_constraint(self, table_name: str, constraint_name: str): @@ -261,10 +224,8 @@ def make_ops(self) -> List[Operation]: def _build_metatable(self, dbtable: Table): return MetaTable( dbtable._tablename, - [ - Column.from_field(field) for field in list(dbtable) - ], - primary_keys=list(dbtable._primary_keys) + [Column.from_field(field) for field in list(dbtable)], + primary_keys=list(dbtable._primary_keys), ) def _build_metaindex(self, dbtable: Table, index_name: str) -> MetaIndex: @@ -272,15 +233,15 @@ def _build_metaindex(self, dbtable: Table, index_name: str) -> MetaIndex: dbindex = model._indexes_[index_name] kw = {} with self.db._adapter.index_expander(): - if 'where' in dbindex: - kw['where'] = str(dbindex['where']) + if "where" in dbindex: + kw["where"] = str(dbindex["where"]) rv = MetaIndex( model.tablename, index_name, - [field for field in dbindex['fields']], - [str(expr) for expr in dbindex['expressions']], - dbindex['unique'], - **kw + list(dbindex["fields"]), + [str(expr) for expr in dbindex["expressions"]], + dbindex["unique"], + **kw, ) return rv @@ -288,12 +249,7 @@ def _build_metafk(self, dbtable: Table, fk_name: str) -> MetaForeignKey: model = dbtable._model_ dbfk = model._foreign_keys_[fk_name] return MetaForeignKey( - model.tablename, - fk_name, - dbfk['fields_local'], - dbfk['table'], - dbfk['fields_foreign'], - dbfk['on_delete'] + model.tablename, fk_name, dbfk["fields_local"], dbfk["table"], dbfk["fields_foreign"], dbfk["on_delete"] ) def tables(self): @@ -322,64 +278,37 @@ def table(self, dbtable: Table, metatable: MetaTable): self.indexes_and_uniques(dbtable, metatable) self.foreign_keys(dbtable, metatable) - def indexes_and_uniques( - self, - dbtable: Table, - metatable: MetaTable, - ops_stack: Optional[List[Operation]] = None - ): + def indexes_and_uniques(self, dbtable: Table, metatable: MetaTable, ops_stack: Optional[List[Operation]] = None): ops = ops_stack if ops_stack is not None else self.ops - db_index_names = OrderedSet( - [idxname for idxname in dbtable._model_._indexes_.keys()] - ) + db_index_names = OrderedSet(list(dbtable._model_._indexes_.keys())) meta_index_names = OrderedSet(list(metatable.indexes)) #: removed indexes for index_name in meta_index_names.difference(db_index_names): ops.append(DropIndexOp.from_index(metatable.indexes[index_name])) #: new indexs for index_name in db_index_names.difference(meta_index_names): - ops.append( - CreateIndexOp.from_index( - self._build_metaindex(dbtable, index_name) - ) - ) + ops.append(CreateIndexOp.from_index(self._build_metaindex(dbtable, index_name))) #: existing indexes for index_name in meta_index_names.intersection(db_index_names): metaindex = metatable.indexes[index_name] dbindex = self._build_metaindex(dbtable, index_name) if any( - getattr(metaindex, key) != getattr(dbindex, key) - for key in ['fields', 'expressions', 'unique', 'kw'] + getattr(metaindex, key) != getattr(dbindex, key) for key in ["fields", "expressions", "unique", "kw"] ): ops.append(DropIndexOp.from_index(metaindex)) ops.append(CreateIndexOp.from_index(dbindex)) # TODO: uniques - def foreign_keys( - self, - dbtable: Table, - metatable: MetaTable, - ops_stack: Optional[List[Operation]] = None - ): + def foreign_keys(self, dbtable: Table, metatable: MetaTable, ops_stack: Optional[List[Operation]] = None): ops = ops_stack if ops_stack is not None else self.ops - db_fk_names = OrderedSet( - [fkname for fkname in dbtable._model_._foreign_keys_.keys()] - ) + db_fk_names = OrderedSet(list(dbtable._model_._foreign_keys_.keys())) meta_fk_names = OrderedSet(list(metatable.foreign_keys)) #: removed fks for fk_name in meta_fk_names.difference(db_fk_names): - ops.append( - DropForeignKeyConstraintOp.from_foreign_key( - metatable.foreign_keys[fk_name] - ) - ) + ops.append(DropForeignKeyConstraintOp.from_foreign_key(metatable.foreign_keys[fk_name])) #: new fks for fk_name in db_fk_names.difference(meta_fk_names): - ops.append( - CreateForeignKeyConstraintOp.from_foreign_key( - self._build_metafk(dbtable, fk_name) - ) - ) + ops.append(CreateForeignKeyConstraintOp.from_foreign_key(self._build_metafk(dbtable, fk_name))) #: existing fks for fk_name in meta_fk_names.intersection(db_fk_names): metafk = metatable.foreign_keys[fk_name] @@ -389,29 +318,22 @@ def foreign_keys( ops.append(CreateForeignKeyConstraintOp.from_foreign_key(dbfk)) def columns(self, dbtable: Table, metatable: MetaTable): - db_column_names = OrderedSet([fname for fname in dbtable.fields]) + db_column_names = OrderedSet(list(dbtable.fields)) meta_column_names = OrderedSet(metatable.fields) #: new columns for column_name in db_column_names.difference(meta_column_names): - self.ops.append(AddColumnOp.from_column_and_tablename( - dbtable._tablename, Column.from_field(dbtable[column_name]) - )) + self.ops.append( + AddColumnOp.from_column_and_tablename(dbtable._tablename, Column.from_field(dbtable[column_name])) + ) #: existing columns for column_name in meta_column_names.intersection(db_column_names): self.ops.append(AlterColumnOp(dbtable._tablename, column_name)) - self.column( - Column.from_field(dbtable[column_name]), - metatable.columns[column_name] - ) + self.column(Column.from_field(dbtable[column_name]), metatable.columns[column_name]) if not self.ops[-1].has_changes(): self.ops.pop() #: removed columns for column_name in meta_column_names.difference(db_column_names): - self.ops.append( - DropColumnOp.from_column_and_tablename( - dbtable._tablename, metatable.columns[column_name] - ) - ) + self.ops.append(DropColumnOp.from_column_and_tablename(dbtable._tablename, metatable.columns[column_name])) def column(self, dbcolumn: Column, metacolumn: Column): self.notnulls(dbcolumn, metacolumn) @@ -431,9 +353,7 @@ def types(self, dbcolumn: Column, metacolumn: Column): def lengths(self, dbcolumn: Column, metacolumn: Column): self.ops[-1].existing_length = metacolumn.length - if any( - field.type == "string" for field in [dbcolumn, metacolumn] - ) and dbcolumn.length != metacolumn.length: + if any(field.type == "string" for field in [dbcolumn, metacolumn]) and dbcolumn.length != metacolumn.length: self.ops[-1].modify_length = dbcolumn.length def notnulls(self, dbcolumn: Column, metacolumn: Column): @@ -461,12 +381,8 @@ def __init__(self, db: Database, scriptdir: ScriptDir, head: str): self._load_head_to_meta() def _load_head_to_meta(self): - for revision in reversed( - list(self.scriptdir.walk_revisions("base", self.head)) - ): - migration = revision.migration_class( - None, self.meta, is_meta=True - ) + for revision in reversed(list(self.scriptdir.walk_revisions("base", self.head))): + migration = revision.migration_class(None, self.meta, is_meta=True) if migration.skip_on_compare: continue migration.up() @@ -475,12 +391,7 @@ def generate(self) -> UpgradeOps: return Comparator.compare(self.db, self.meta) @classmethod - def generate_from( - cls, - dal: Database, - scriptdir: ScriptDir, - head: str - ) -> UpgradeOps: + def generate_from(cls, dal: Database, scriptdir: ScriptDir, head: str) -> UpgradeOps: return cls(dal, scriptdir, head).generate() @@ -501,10 +412,7 @@ def render_opcontainer(self, op_container: OpContainer) -> List[str]: @classmethod def render_migration(cls, migration_op: MigrationOp): r = cls() - return ( - r.render_opcontainer(migration_op.upgrade_ops), - r.render_opcontainer(migration_op.downgrade_ops) - ) + return (r.render_opcontainer(migration_op.upgrade_ops), r.render_opcontainer(migration_op.downgrade_ops)) renderers = Dispatcher() @@ -514,10 +422,7 @@ def render_migration(cls, migration_op: MigrationOp): def _add_table(op: CreateTableOp) -> str: table = op.to_table() - args = [ - col for col in [_render_column(col) for col in table.columns.values()] - if col - ] + args = [col for col in [_render_column(col) for col in table.columns.values()] if col] # + sorted([ # rcons for rcons in [ # _render_constraint(cons) for cons in table.constraints] @@ -527,18 +432,19 @@ def _add_table(op: CreateTableOp) -> str: indent = " " * 12 if len(args) > 255: - args = '*[' + (',\n' + indent).join(args) + ']' + args = "*[" + (",\n" + indent).join(args) + "]" else: - args = (',\n' + indent).join(args) + args = (",\n" + indent).join(args) text = ( - "self.create_table(\n" + indent + "%(tablename)r,\n" + indent + "%(args)s,\n" + - indent + "primary_keys=%(primary_keys)r" - ) % { - 'tablename': op.table_name, - 'args': args, - 'primary_keys': table.primary_keys - } + "self.create_table(\n" + + indent + + "%(tablename)r,\n" + + indent + + "%(args)s,\n" + + indent + + "primary_keys=%(primary_keys)r" + ) % {"tablename": op.table_name, "args": args, "primary_keys": table.primary_keys} for k in sorted(op.kw): text += ",\n" + indent + "%s=%r" % (k.replace(" ", "_"), op.kw[k]) text += ")" @@ -547,9 +453,7 @@ def _add_table(op: CreateTableOp) -> str: @renderers.dispatch_for(DropTableOp) def _drop_table(op: DropTableOp) -> str: - text = "self.drop_table(%(tname)r" % { - "tname": op.table_name - } + text = "self.drop_table(%(tname)r" % {"tname": op.table_name} text += ")" return text @@ -568,7 +472,7 @@ def _render_column(column: Column) -> str: if column.type in ("string", "password", "upload"): opts.append(("length", column.length)) - elif column.type.startswith('reference'): + elif column.type.startswith("reference"): opts.append(("ondelete", column.ondelete)) elif column.type.startswith("geo"): for key in ("geometry_type", "srid", "dimension"): @@ -577,37 +481,24 @@ def _render_column(column: Column) -> str: kw_str = "" if opts: - kw_str = ", %s" % \ - ", ".join(["%s=%r" % (key, val) for key, val in opts]) - return "migrations.Column(%(name)r, %(type)r%(kw)s)" % { - 'name': column.name, - 'type': column.type, - 'kw': kw_str - } + kw_str = ", %s" % ", ".join(["%s=%r" % (key, val) for key, val in opts]) + return "migrations.Column(%(name)r, %(type)r%(kw)s)" % {"name": column.name, "type": column.type, "kw": kw_str} @renderers.dispatch_for(AddColumnOp) def _add_column(op: AddColumnOp) -> str: - return "self.add_column(%(tname)r, %(column)s)" % { - "tname": op.table_name, - "column": _render_column(op.column) - } + return "self.add_column(%(tname)r, %(column)s)" % {"tname": op.table_name, "column": _render_column(op.column)} @renderers.dispatch_for(DropColumnOp) def _drop_column(op: DropTableOp) -> str: - return "self.drop_column(%(tname)r, %(cname)r)" % { - "tname": op.table_name, - "cname": op.column_name - } + return "self.drop_column(%(tname)r, %(cname)r)" % {"tname": op.table_name, "cname": op.column_name} @renderers.dispatch_for(AlterColumnOp) def _alter_column(op: AlterColumnOp) -> str: indent = " " * 12 - text = "self.alter_column(%(tname)r, %(cname)r" % { - 'tname': op.table_name, - 'cname': op.column_name} + text = "self.alter_column(%(tname)r, %(cname)r" % {"tname": op.table_name, "cname": op.column_name} if op.existing_type is not None: text += ",\n%sexisting_type=%r" % (indent, op.existing_type) @@ -637,47 +528,31 @@ def _alter_column(op: AlterColumnOp) -> str: def _add_index(op: CreateIndexOp) -> str: kw_str = "" if op.kw: - kw_str = ", %s" % ", ".join( - ["%s=%r" % (key, val) for key, val in op.kw.items()]) + kw_str = ", %s" % ", ".join(["%s=%r" % (key, val) for key, val in op.kw.items()]) return "self.create_index(%(iname)r, %(tname)r, %(idata)s)" % { "tname": op.table_name, "iname": op.index_name, - "idata": "%r, expressions=%r, unique=%s%s" % ( - op.fields, op.expressions, op.unique, kw_str) + "idata": "%r, expressions=%r, unique=%s%s" % (op.fields, op.expressions, op.unique, kw_str), } @renderers.dispatch_for(DropIndexOp) def _drop_index(op: DropIndexOp) -> str: - return "self.drop_index(%(iname)r, %(tname)r)" % { - "tname": op.table_name, - "iname": op.index_name - } + return "self.drop_index(%(iname)r, %(tname)r)" % {"tname": op.table_name, "iname": op.index_name} @renderers.dispatch_for(CreateForeignKeyConstraintOp) def _add_fk_constraint(op: CreateForeignKeyConstraintOp) -> str: kw_str = "" if op.kw: - kw_str = ", %s" % ", ".join( - ["%s=%r" % (key, val) for key, val in op.kw.items()] - ) + kw_str = ", %s" % ", ".join(["%s=%r" % (key, val) for key, val in op.kw.items()]) return "self.create_foreign_key(%s%s)" % ( - "%r, %r, %r, %r, %r, on_delete=%r" % ( - op.constraint_name, - op.table_name, - op.foreign_table_name, - op.column_names, - op.foreign_keys, - op.on_delete - ), - kw_str + "%r, %r, %r, %r, %r, on_delete=%r" + % (op.constraint_name, op.table_name, op.foreign_table_name, op.column_names, op.foreign_keys, op.on_delete), + kw_str, ) @renderers.dispatch_for(DropForeignKeyConstraintOp) def _drop_fk_constraint(op: DropForeignKeyConstraintOp) -> str: - return "self.drop_foreign_key(%(cname)r, %(tname)r)" % { - "tname": op.table_name, - "cname": op.constraint_name - } + return "self.drop_foreign_key(%(cname)r, %(tname)r)" % {"tname": op.table_name, "cname": op.constraint_name} diff --git a/emmett/orm/migrations/helpers.py b/emmett/orm/migrations/helpers.py index 0f804b7b..fd9ed6a6 100644 --- a/emmett/orm/migrations/helpers.py +++ b/emmett/orm/migrations/helpers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.helpers - ----------------------------- +emmett.orm.migrations.helpers +----------------------------- - Provides helpers for migrations. +Provides helpers for migrations. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -21,6 +21,7 @@ from ...datastructures import _unique_list from .base import Database + if TYPE_CHECKING: from .engine import MetaEngine from .operations import Operation @@ -50,12 +51,12 @@ def __init__(self): self._registry: Dict[Type[Operation], Callable[[Operation], str]] = {} def dispatch_for( - self, - target: Type[Operation] + self, target: Type[Operation] ) -> Callable[[Callable[[Operation], str]], Callable[[Operation], str]]: def wrap(fn: Callable[[Operation], str]) -> Callable[[Operation], str]: self._registry[target] = fn return fn + return wrap def dispatch(self, obj: Operation): @@ -98,11 +99,11 @@ def to_tuple(x, default=None): if x is None: return default elif isinstance(x, str): - return (x, ) + return (x,) elif isinstance(x, Iterable): return tuple(x) else: - return (x, ) + return (x,) def tuple_or_value(val): diff --git a/emmett/orm/migrations/operations.py b/emmett/orm/migrations/operations.py index b2558c6b..965e60b3 100644 --- a/emmett/orm/migrations/operations.py +++ b/emmett/orm/migrations/operations.py @@ -1,30 +1,30 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.operations - -------------------------------- +emmett.orm.migrations.operations +-------------------------------- - Provides operations handlers for migrations. +Provides operations handlers for migrations. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) - :copyright: (c) 2009-2015 by Michael Bayer +Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) +:copyright: (c) 2009-2015 by Michael Bayer - :license: BSD-3-Clause +:license: BSD-3-Clause """ from __future__ import annotations import re - from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from .base import Migration, Column +from .base import Column, Migration from .helpers import DEFAULT_VALUE + if TYPE_CHECKING: from .engine import MetaEngine - from .generation import MetaTable, MetaIndex, MetaForeignKey + from .generation import MetaForeignKey, MetaIndex, MetaTable class Operation: @@ -52,7 +52,7 @@ def as_diffs(self): @classmethod def _ops_as_diffs(cls, migrations): for op in migrations.ops: - if hasattr(op, 'ops'): + if hasattr(op, "ops"): for sub_op in cls._ops_as_diffs(op): yield sub_op else: @@ -66,10 +66,7 @@ def __init__(self, table_name: str, ops: List[Operation]): self.table_name = table_name def reverse(self) -> ModifyTableOps: - return ModifyTableOps( - self.table_name, - ops=list(reversed([op.reverse() for op in self.ops])) - ) + return ModifyTableOps(self.table_name, ops=list(reversed([op.reverse() for op in self.ops]))) class UpgradeOps(OpContainer): @@ -79,9 +76,7 @@ def __init__(self, ops: List[Operation] = [], upgrade_token: str = "upgrades"): self.upgrade_token = upgrade_token def reverse(self) -> DowngradeOps: - return DowngradeOps( - ops=list(reversed([op.reverse() for op in self.ops])) - ) + return DowngradeOps(ops=list(reversed([op.reverse() for op in self.ops]))) class DowngradeOps(OpContainer): @@ -91,9 +86,7 @@ def __init__(self, ops: List[Operation] = [], downgrade_token: str = "downgrades self.downgrade_token = downgrade_token def reverse(self): - return UpgradeOps( - ops=list(reversed([op.reverse() for op in self.ops])) - ) + return UpgradeOps(ops=list(reversed([op.reverse() for op in self.ops]))) class MigrationOp(Operation): @@ -104,7 +97,7 @@ def __init__( downgrade_ops: DowngradeOps, message: Optional[str] = None, head: Optional[str] = None, - splice: Any = None + splice: Any = None, ): self.rev_id = rev_id self.message = message @@ -122,7 +115,7 @@ def __init__( columns: List[Column], primary_keys: List[str] = [], _orig_table: Optional[MetaTable] = None, - **kw: Any + **kw: Any, ): self.table_name = table_name self.columns = columns @@ -139,45 +132,28 @@ def to_diff_tuple(self) -> Tuple[str, MetaTable]: @classmethod def from_table(cls, table: MetaTable) -> CreateTableOp: return cls( - table.name, - [table[colname] for colname in table.fields], - list(table.primary_keys), - _orig_table=table + table.name, [table[colname] for colname in table.fields], list(table.primary_keys), _orig_table=table ) def to_table(self, migration_context: Any = None) -> MetaTable: if self._orig_table is not None: return self._orig_table from .generation import MetaTable - return MetaTable( - self.table_name, - self.columns, - self.primary_keys, - **self.kw - ) + + return MetaTable(self.table_name, self.columns, self.primary_keys, **self.kw) @classmethod - def create_table( - cls, - table_name: str, - *columns: Column, - **kw: Any - ) -> CreateTableOp: + def create_table(cls, table_name: str, *columns: Column, **kw: Any) -> CreateTableOp: return cls(table_name, columns, **kw) def run(self): - self.engine.create_table( - self.table_name, self.columns, self.primary_keys, **self.kw - ) + self.engine.create_table(self.table_name, self.columns, self.primary_keys, **self.kw) @Migration.register_operation("drop_table") class DropTableOp(Operation): def __init__( - self, - table_name: str, - table_kw: Optional[Dict[str, Any]] = None, - _orig_table: Optional[MetaTable] = None + self, table_name: str, table_kw: Optional[Dict[str, Any]] = None, _orig_table: Optional[MetaTable] = None ): self.table_name = table_name self.table_kw = table_kw or {} @@ -188,9 +164,7 @@ def to_diff_tuple(self) -> Tuple[str, MetaTable]: def reverse(self) -> CreateTableOp: if self._orig_table is None: - raise ValueError( - "operation is not reversible; original table is not present" - ) + raise ValueError("operation is not reversible; original table is not present") return CreateTableOp.from_table(self._orig_table) @classmethod @@ -201,10 +175,8 @@ def to_table(self) -> MetaTable: if self._orig_table is not None: return self._orig_table from .generation import MetaTable - return MetaTable( - self.table_name, - **self.table_kw - ) + + return MetaTable(self.table_name, **self.table_kw) @classmethod def drop_table(cls, table_name: str, **kw: Any) -> DropTableOp: @@ -230,7 +202,7 @@ def rename_table(cls, old_table_name: str, new_table_name: str) -> RenameTableOp return cls(old_table_name, new_table_name) def run(self): - raise NotImplementedError('Table renaming is currently not supported.') + raise NotImplementedError("Table renaming is currently not supported.") @Migration.register_operation("add_column") @@ -262,13 +234,7 @@ def run(self): @Migration.register_operation("drop_column") class DropColumnOp(AlterTableOp): - def __init__( - self, - table_name: str, - column_name: str, - _orig_column: Optional[Column] = None, - **kw: Any - ): + def __init__(self, table_name: str, column_name: str, _orig_column: Optional[Column] = None, **kw: Any): super().__init__(table_name) self.column_name = column_name self.kw = kw @@ -279,13 +245,9 @@ def to_diff_tuple(self) -> Tuple[str, str, Column]: def reverse(self) -> AddColumnOp: if self._orig_column is None: - raise ValueError( - "operation is not reversible; original column is not present" - ) + raise ValueError("operation is not reversible; original column is not present") - return AddColumnOp.from_column_and_tablename( - self.table_name, self._orig_column - ) + return AddColumnOp.from_column_and_tablename(self.table_name, self._orig_column) @classmethod def from_column_and_tablename(cls, tname: str, col: Column) -> DropColumnOp: @@ -319,7 +281,7 @@ def __init__( modify_name: Optional[str] = None, modify_type: Optional[str] = None, modify_length: Optional[int] = None, - **kw: Any + **kw: Any, ): super().__init__(table_name) self.column_name = column_name @@ -348,13 +310,10 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: "existing_length": self.existing_length, "existing_notnull": self.existing_notnull, "existing_default": self.existing_default, - **{ - nkey: nval for nkey, nval in self.kw.items() - if nkey.startswith('existing_') - } + **{nkey: nval for nkey, nval in self.kw.items() if nkey.startswith("existing_")}, }, self.existing_type, - self.modify_type + self.modify_type, ) ) @@ -367,10 +326,10 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: { "existing_type": self.existing_type, "existing_notnull": self.existing_notnull, - "existing_default": self.existing_default + "existing_default": self.existing_default, }, self.existing_length, - self.modify_length + self.modify_length, ) ) @@ -380,12 +339,9 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: "modify_notnull", tname, cname, - { - "existing_type": self.existing_type, - "existing_default": self.existing_default - }, + {"existing_type": self.existing_type, "existing_default": self.existing_default}, self.existing_notnull, - self.modify_notnull + self.modify_notnull, ) ) @@ -395,12 +351,9 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: "modify_default", tname, cname, - { - "existing_notnull": self.existing_notnull, - "existing_type": self.existing_type - }, + {"existing_notnull": self.existing_notnull, "existing_type": self.existing_type}, self.existing_default, - self.modify_default + self.modify_default, ) ) @@ -414,13 +367,10 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: cname, { "existing_type": self.existing_type, - **{ - nkey: nval for nkey, nval in self.kw.items() - if nkey.startswith('existing_') - } + **{nkey: nval for nkey, nval in self.kw.items() if nkey.startswith("existing_")}, }, self.kw.get(f"existing_{attr}"), - val + val, ) ) @@ -428,47 +378,42 @@ def to_diff_tuple(self) -> List[Tuple[str, str, str, Dict[str, Any], Any, Any]]: def has_changes(self) -> bool: hc = ( - self.modify_notnull is not None or - self.modify_default is not DEFAULT_VALUE or - self.modify_type is not None or - self.modify_length is not None + self.modify_notnull is not None + or self.modify_default is not DEFAULT_VALUE + or self.modify_type is not None + or self.modify_length is not None ) if hc: return True for kw in self.kw: - if kw.startswith('modify_'): + if kw.startswith("modify_"): return True return False def reverse(self) -> AlterColumnOp: kw = self.kw.copy() - kw['existing_type'] = self.existing_type - kw['existing_length'] = self.existing_length - kw['existing_notnull'] = self.existing_notnull - kw['existing_default'] = self.existing_default + kw["existing_type"] = self.existing_type + kw["existing_length"] = self.existing_length + kw["existing_notnull"] = self.existing_notnull + kw["existing_default"] = self.existing_default if self.modify_type is not None: - kw['modify_type'] = self.modify_type + kw["modify_type"] = self.modify_type if self.modify_length is not None: - kw['modify_length'] = self.modify_length + kw["modify_length"] = self.modify_length if self.modify_notnull is not None: - kw['modify_notnull'] = self.modify_notnull + kw["modify_notnull"] = self.modify_notnull if self.modify_default is not DEFAULT_VALUE: - kw['modify_default'] = self.modify_default + kw["modify_default"] = self.modify_default - all_keys = set(m.group(1) for m in [ - re.match(r'^(?:existing_|modify_)(.+)$', k) - for k in kw - ] if m) + all_keys = {m.group(1) for m in [re.match(r"^(?:existing_|modify_)(.+)$", k) for k in kw] if m} for k in all_keys: - if 'modify_%s' % k in kw: - swap = kw['existing_%s' % k] - kw['existing_%s' % k] = kw['modify_%s' % k] - kw['modify_%s' % k] = swap + if "modify_%s" % k in kw: + swap = kw["existing_%s" % k] + kw["existing_%s" % k] = kw["modify_%s" % k] + kw["modify_%s" % k] = swap - return self.__class__( - self.table_name, self.column_name, **kw - ) + return self.__class__(self.table_name, self.column_name, **kw) @classmethod def alter_column( @@ -484,7 +429,7 @@ def alter_column( existing_length: Optional[int] = None, existing_default: Any = None, existing_notnull: Optional[bool] = None, - **kw: Any + **kw: Any, ) -> AlterColumnOp: return cls( table_name, @@ -498,13 +443,11 @@ def alter_column( modify_length=length, modify_default=default, modify_notnull=notnull, - **kw + **kw, ) def run(self): - self.engine.alter_column( - self.table_name, self.column_name, self.to_diff_tuple() - ) + self.engine.alter_column(self.table_name, self.column_name, self.to_diff_tuple()) @Migration.register_operation("create_index") @@ -517,7 +460,7 @@ def __init__( expressions: List[str] = [], unique: bool = False, _orig_index: Optional[MetaIndex] = None, - **kw: Any + **kw: Any, ): self.index_name = index_name self.table_name = table_name @@ -536,22 +479,15 @@ def to_diff_tuple(self) -> Tuple[str, MetaIndex]: @classmethod def from_index(cls, index: MetaIndex) -> CreateIndexOp: return cls( - index.name, index.table_name, index.fields, index.expressions, - index.unique, _orig_index=index, **index.kw + index.name, index.table_name, index.fields, index.expressions, index.unique, _orig_index=index, **index.kw ) def to_index(self) -> MetaIndex: if self._orig_index is not None: return self._orig_index from .generation import MetaIndex - return MetaIndex( - self.table_name, - self.index_name, - self.fields, - self.expressions, - self.unique, - **self.kw - ) + + return MetaIndex(self.table_name, self.index_name, self.fields, self.expressions, self.unique, **self.kw) @classmethod def create_index( @@ -561,29 +497,19 @@ def create_index( fields: List[str] = [], expressions: List[str] = [], unique: bool = False, - **kw: Any + **kw: Any, ) -> CreateIndexOp: return cls(index_name, table_name, fields, expressions, unique, **kw) def run(self): self.engine.create_index( - self.index_name, - self.table_name, - self.fields, - self.expressions, - self.unique, - **self.kw + self.index_name, self.table_name, self.fields, self.expressions, self.unique, **self.kw ) @Migration.register_operation("drop_index") class DropIndexOp(Operation): - def __init__( - self, - index_name: str, - table_name: Optional[str] = None, - _orig_index: Optional[MetaIndex] = None - ): + def __init__(self, index_name: str, table_name: Optional[str] = None, _orig_index: Optional[MetaIndex] = None): self.index_name = index_name self.table_name = table_name self._orig_index = _orig_index @@ -593,9 +519,7 @@ def to_diff_tuple(self) -> Tuple[str, MetaIndex]: def reverse(self) -> CreateIndexOp: if self._orig_index is None: - raise ValueError( - "operation is not reversible; original index is not present" - ) + raise ValueError("operation is not reversible; original index is not present") return CreateIndexOp.from_index(self._orig_index) @classmethod @@ -606,6 +530,7 @@ def to_index(self) -> MetaIndex: if self._orig_index is not None: return self._orig_index from .generation import MetaIndex + return MetaIndex(self.table_name, self.index_name, [], [], False) @classmethod @@ -627,7 +552,7 @@ def __init__( foreign_keys: List[str], on_delete: str, _orig_fk: Optional[MetaForeignKey] = None, - **kw: Any + **kw: Any, ): super().__init__(table_name) self.constraint_name = name @@ -647,10 +572,7 @@ def to_diff_tuple(self) -> Tuple[str, MetaForeignKey]: return ("create_fk_constraint", self.to_foreign_key()) @classmethod - def from_foreign_key( - cls, - foreign_key: MetaForeignKey - ) -> CreateForeignKeyConstraintOp: + def from_foreign_key(cls, foreign_key: MetaForeignKey) -> CreateForeignKeyConstraintOp: return cls( foreign_key.name, foreign_key.table_name, @@ -658,7 +580,7 @@ def from_foreign_key( foreign_key.column_names, foreign_key.foreign_keys, foreign_key.on_delete, - _orig_fk=foreign_key + _orig_fk=foreign_key, ) def to_foreign_key(self) -> MetaForeignKey: @@ -666,13 +588,14 @@ def to_foreign_key(self) -> MetaForeignKey: return self._orig_fk from .generation import MetaForeignKey + return MetaForeignKey( self.table_name, self.constraint_name, self.column_names, self.foreign_table_name, self.foreign_keys, - self.on_delete + self.on_delete, ) @classmethod @@ -683,7 +606,7 @@ def create_foreign_key( foreign_table_name: str, column_names: List[str], foreign_keys: List[str], - on_delete: str + on_delete: str, ) -> CreateForeignKeyConstraintOp: return cls( name=name, @@ -691,7 +614,7 @@ def create_foreign_key( foreign_table_name=foreign_table_name, column_names=column_names, foreign_keys=foreign_keys, - on_delete=on_delete + on_delete=on_delete, ) def run(self): @@ -701,19 +624,13 @@ def run(self): self.column_names, self.foreign_table_name, self.foreign_keys, - self.on_delete + self.on_delete, ) @Migration.register_operation("drop_foreign_key") class DropForeignKeyConstraintOp(AlterTableOp): - def __init__( - self, - name: str, - table_name: str, - _orig_fk: Optional[MetaForeignKey] = None, - **kw: Any - ): + def __init__(self, name: str, table_name: str, _orig_fk: Optional[MetaForeignKey] = None, **kw: Any): super().__init__(table_name) self.constraint_name = name self.kw = kw @@ -721,38 +638,26 @@ def __init__( def reverse(self) -> CreateForeignKeyConstraintOp: if self._orig_fk is None: - raise ValueError( - "operation is not reversible; original constraint is not present" - ) + raise ValueError("operation is not reversible; original constraint is not present") return CreateForeignKeyConstraintOp.from_foreign_key(self._orig_fk) def to_diff_tuple(self) -> Tuple[str, MetaForeignKey]: return ("drop_fk_constraint", self.to_foreign_key()) @classmethod - def from_foreign_key( - cls, - foreign_key: MetaForeignKey - ) -> DropForeignKeyConstraintOp: - return cls( - foreign_key.name, - foreign_key.table_name, - _orig_fk=foreign_key - ) + def from_foreign_key(cls, foreign_key: MetaForeignKey) -> DropForeignKeyConstraintOp: + return cls(foreign_key.name, foreign_key.table_name, _orig_fk=foreign_key) def to_foreign_key(self): if self._orig_fk is not None: return self._orig_fk from .generation import MetaForeignKey - return MetaForeignKey(self.table_name, self.constraint_name, [], '', [], '') + + return MetaForeignKey(self.table_name, self.constraint_name, [], "", [], "") @classmethod - def drop_foreign_key( - cls, - name: str, - table_name: str - ) -> DropForeignKeyConstraintOp: + def drop_foreign_key(cls, name: str, table_name: str) -> DropForeignKeyConstraintOp: return DropForeignKeyConstraintOp(name, table_name) def run(self): diff --git a/emmett/orm/migrations/revisions.py b/emmett/orm/migrations/revisions.py index a88b9839..56fd4edd 100644 --- a/emmett/orm/migrations/revisions.py +++ b/emmett/orm/migrations/revisions.py @@ -1,25 +1,25 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.revisions - ------------------------------- +emmett.orm.migrations.revisions +------------------------------- - Provides revisions logic for migrations. +Provides revisions logic for migrations. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) - :copyright: (c) 2009-2015 by Michael Bayer +Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) +:copyright: (c) 2009-2015 by Michael Bayer - :license: BSD-3-Clause +:license: BSD-3-Clause """ from collections import deque + +from emmett_core.utils import cachedprop + from ...datastructures import OrderedSet -from ...utils import cachedprop -from .helpers import to_tuple, tuple_or_value, dedupe_tuple -from .exceptions import ( - RevisionError, RangeNotAncestorError, ResolutionError, MultipleHeads -) +from .exceptions import MultipleHeads, RangeNotAncestorError, ResolutionError, RevisionError +from .helpers import dedupe_tuple, to_tuple, tuple_or_value class Revision(object): @@ -27,31 +27,24 @@ class Revision(object): _all_nextrev = frozenset() revision = None down_revision = None - #dependencies = None - #branch_labels = None + # dependencies = None + # branch_labels = None - def __init__(self, revision, down_revision, - dependencies=None, branch_labels=None): + def __init__(self, revision, down_revision, dependencies=None, branch_labels=None): self.revision = revision self.down_revision = tuple_or_value(down_revision) - #self.dependencies = tuple_or_value(dependencies) - #self._resolved_dependencies = () - #self._orig_branch_labels = to_tuple(branch_labels, default=()) - #self.branch_labels = set(self._orig_branch_labels) + # self.dependencies = tuple_or_value(dependencies) + # self._resolved_dependencies = () + # self._orig_branch_labels = to_tuple(branch_labels, default=()) + # self.branch_labels = set(self._orig_branch_labels) def __repr__(self): - args = [ - repr(self.revision), - repr(self.down_revision) - ] + args = [repr(self.revision), repr(self.down_revision)] # if self.dependencies: # args.append("dependencies=%r" % self.dependencies) # if self.branch_labels: # args.append("branch_labels=%r" % self.branch_labels) - return "%s(%s)" % ( - self.__class__.__name__, - ", ".join(args) - ) + return "%s(%s)" % (self.__class__.__name__, ", ".join(args)) def add_nextrev(self, revision): self._all_nextrev = self._all_nextrev.union([revision.revision]) @@ -61,7 +54,7 @@ def add_nextrev(self, revision): @property def _all_down_revisions(self): return to_tuple(self.down_revision, default=()) - #+ self._resolved_dependencies + # + self._resolved_dependencies @property def _versioned_down_revisions(self): @@ -110,14 +103,11 @@ def _revision_map(self): self.bases = () self._real_bases = () - #has_branch_labels = set() - #has_depends_on = set() + # has_branch_labels = set() + # has_depends_on = set() for revision in self._generator(): - if revision.revision in rmap: - self.app.log.warn( - "Revision %s is present more than once" % revision.revision - ) + self.app.log.warn("Revision %s is present more than once" % revision.revision) rmap[revision.revision] = revision # if revision.branch_labels: # has_branch_labels.add(revision) @@ -126,9 +116,9 @@ def _revision_map(self): heads.add(revision.revision) _real_heads.add(revision.revision) if revision.is_base: - self.bases += (revision.revision, ) - self._real_bases += (revision.revision, ) - #if revision._is_real_base: + self.bases += (revision.revision,) + self._real_bases += (revision.revision,) + # if revision._is_real_base: # self._real_bases += (revision.revision, ) # for revision in has_branch_labels: @@ -140,9 +130,7 @@ def _revision_map(self): for rev in rmap.values(): for downrev in rev._all_down_revisions: if downrev not in rmap: - self.app.log.warn( - "Revision %s referenced from %s is not present" % - (downrev, rev)) + self.app.log.warn("Revision %s referenced from %s is not present" % (downrev, rev)) down_revision = rmap[downrev] down_revision.add_nextrev(rev) if downrev in rev._versioned_down_revisions: @@ -188,14 +176,14 @@ def get_current_head(self): def _resolve_revision_number(self, rid): self._revision_map - if rid == 'heads': + if rid == "heads": return self._real_heads - elif rid == 'head': + elif rid == "head": current_head = self.get_current_head() if current_head: - return (current_head, ) + return (current_head,) return () - elif rid == 'base' or rid is None: + elif rid == "base" or rid is None: return () else: return to_tuple(rid, default=None) @@ -205,19 +193,15 @@ def _revision_for_ident(self, resolved_id): revision = self._revision_map[resolved_id] except KeyError: # do a partial lookup - revs = [x for x in self._revision_map - if x and x.startswith(resolved_id)] + revs = [x for x in self._revision_map if x and x.startswith(resolved_id)] if not revs: - raise ResolutionError( - "No such revision or branch '%s'" % resolved_id, - resolved_id) + raise ResolutionError("No such revision or branch '%s'" % resolved_id, resolved_id) elif len(revs) > 1: raise ResolutionError( "Multiple revisions start " - "with '%s': %s..." % ( - resolved_id, - ", ".join("'%s'" % r for r in revs[0:3]) - ), resolved_id) + "with '%s': %s..." % (resolved_id, ", ".join("'%s'" % r for r in revs[0:3])), + resolved_id, + ) else: revision = self._revision_map[revs[0]] return revision @@ -236,65 +220,53 @@ def get_revisions(self, rid): return sum([self.get_revisions(id_elem) for id_elem in rid], ()) else: resolved_id = self._resolve_revision_number(rid) - return tuple( - self._revision_for_ident(rev_id) for rev_id in resolved_id) + return tuple(self._revision_for_ident(rev_id) for rev_id in resolved_id) def add_revision(self, revision, _replace=False): map_ = self._revision_map if not _replace and revision.revision in map_: - self.app.log.warn( - "Revision %s is present more than once" % revision.revision) + self.app.log.warn("Revision %s is present more than once" % revision.revision) elif _replace and revision.revision not in map_: raise Exception("revision %s not in map" % revision.revision) map_[revision.revision] = revision if revision.is_base: - self.bases += (revision.revision, ) - self._real_bases += (revision.revision, ) + self.bases += (revision.revision,) + self._real_bases += (revision.revision,) # if revision._is_real_base: # self._real_bases += (revision.revision, ) for downrev in revision._all_down_revisions: if downrev not in map_: - self.app.log.warn( - "Revision %s referenced from %s is not present" - % (downrev, revision) - ) + self.app.log.warn("Revision %s referenced from %s is not present" % (downrev, revision)) map_[downrev].add_nextrev(revision) if revision._is_real_head: self._real_heads = tuple( - head for head in self._real_heads - if head not in - set(revision._all_down_revisions).union([revision.revision]) + head + for head in self._real_heads + if head not in set(revision._all_down_revisions).union([revision.revision]) ) + (revision.revision,) if revision.is_head: self.heads = tuple( - head for head in self.heads - if head not in - set(revision._versioned_down_revisions).union( - [revision.revision]) + head + for head in self.heads + if head not in set(revision._versioned_down_revisions).union([revision.revision]) ) + (revision.revision,) - def iterate_revisions(self, upper, lower, implicit_base=False, - inclusive=False): + def iterate_revisions(self, upper, lower, implicit_base=False, inclusive=False): #: iterate through script revisions, starting at the given upper # revision identifier and ending at the lower. - return self._iterate_revisions( - upper, lower, inclusive=inclusive, implicit_base=implicit_base) + return self._iterate_revisions(upper, lower, inclusive=inclusive, implicit_base=implicit_base) def _get_ancestor_nodes(self, targets, map_=None, check=False): fn = lambda rev: rev._versioned_down_revisions - return self._iterate_related_revisions( - fn, targets, map_=map_, check=check - ) + return self._iterate_related_revisions(fn, targets, map_=map_, check=check) def _get_descendant_nodes(self, targets, map_=None, check=False): fn = lambda rev: rev.nextrev - return self._iterate_related_revisions( - fn, targets, map_=map_, check=check - ) + return self._iterate_related_revisions(fn, targets, map_=map_, check=check) def _iterate_related_revisions(self, fn, targets, map_, check=False): if map_ is None: @@ -303,7 +275,6 @@ def _iterate_related_revisions(self, fn, targets, map_, check=False): seen = set() todo = deque() for target in targets: - todo.append(target) if check: per_target = set() @@ -316,16 +287,14 @@ def _iterate_related_revisions(self, fn, targets, map_, check=False): if rev in seen: continue seen.add(rev) - todo.extend( - map_[rev_id] for rev_id in fn(rev)) + todo.extend(map_[rev_id] for rev_id in fn(rev)) yield rev if check and per_target.intersection(targets).difference([target]): raise RevisionError( - "Requested revision %s overlaps with " - "other requested revisions" % target.revision) + "Requested revision %s overlaps with " "other requested revisions" % target.revision + ) - def _iterate_revisions(self, upper, lower, inclusive=True, - implicit_base=False): + def _iterate_revisions(self, upper, lower, inclusive=True, implicit_base=False): #: iterate revisions from upper to lower. requested_lowers = self.get_revisions(lower) uppers = dedupe_tuple(self.get_revisions(upper)) @@ -336,16 +305,10 @@ def _iterate_revisions(self, upper, lower, inclusive=True, upper_ancestors = set(self._get_ancestor_nodes(uppers, check=True)) if implicit_base and requested_lowers: - lower_ancestors = set( - self._get_ancestor_nodes(requested_lowers) - ) - lower_descendants = set( - self._get_descendant_nodes(requested_lowers) - ) + lower_ancestors = set(self._get_ancestor_nodes(requested_lowers)) + lower_descendants = set(self._get_descendant_nodes(requested_lowers)) base_lowers = set() - candidate_lowers = upper_ancestors.\ - difference(lower_ancestors).\ - difference(lower_descendants) + candidate_lowers = upper_ancestors.difference(lower_ancestors).difference(lower_descendants) for rev in candidate_lowers: for downrev in rev._all_down_revisions: if self._revision_map[downrev] in candidate_lowers: @@ -362,10 +325,8 @@ def _iterate_revisions(self, upper, lower, inclusive=True, lowers = requested_lowers # represents all nodes we will produce - total_space = set( - rev.revision for rev in upper_ancestors).intersection( - rev.revision for rev - in self._get_descendant_nodes(lowers, check=True) + total_space = {rev.revision for rev in upper_ancestors}.intersection( + rev.revision for rev in self._get_descendant_nodes(lowers, check=True) ) if not total_space: @@ -373,27 +334,23 @@ def _iterate_revisions(self, upper, lower, inclusive=True, # organize branch points to be consumed separately from # member nodes - branch_todo = set( - rev for rev in - (self._revision_map[rev] for rev in total_space) - if rev._is_real_branch_point and - len(total_space.intersection(rev._all_nextrev)) > 1 - ) + branch_todo = { + rev + for rev in (self._revision_map[rev] for rev in total_space) + if rev._is_real_branch_point and len(total_space.intersection(rev._all_nextrev)) > 1 + } # it's not possible for any "uppers" to be in branch_todo, # because the ._all_nextrev of those nodes is not in total_space - #assert not branch_todo.intersection(uppers) + # assert not branch_todo.intersection(uppers) - todo = deque( - r for r in uppers if r.revision in total_space) + todo = deque(r for r in uppers if r.revision in total_space) # iterate for total_space being emptied out total_space_modified = True while total_space: - if not total_space_modified: - raise RevisionError( - "Dependency resolution failed; iteration can't proceed") + raise RevisionError("Dependency resolution failed; iteration can't proceed") total_space_modified = False # when everything non-branch pending is consumed, # add to the todo any branch nodes that have no @@ -401,13 +358,10 @@ def _iterate_revisions(self, upper, lower, inclusive=True, if not todo: todo.extendleft( sorted( - ( - rev for rev in branch_todo - if not rev._all_nextrev.intersection(total_space) - ), + (rev for rev in branch_todo if not rev._all_nextrev.intersection(total_space)), # favor "revisioned" branch points before # dependent ones - key=lambda rev: 0 if rev.is_branch_point else 1 + key=lambda rev: 0 if rev.is_branch_point else 1, ) ) branch_todo.difference_update(todo) @@ -419,11 +373,13 @@ def _iterate_revisions(self, upper, lower, inclusive=True, # do depth first for elements within branches, # don't consume any actual branch nodes - todo.extendleft([ - self._revision_map[downrev] - for downrev in reversed(rev._all_down_revisions) - if self._revision_map[downrev] not in branch_todo and - downrev in total_space]) + todo.extendleft( + [ + self._revision_map[downrev] + for downrev in reversed(rev._all_down_revisions) + if self._revision_map[downrev] not in branch_todo and downrev in total_space + ] + ) if not inclusive and rev in requested_lowers: continue diff --git a/emmett/orm/migrations/scripts.py b/emmett/orm/migrations/scripts.py index eee7b6fb..452cf90c 100644 --- a/emmett/orm/migrations/scripts.py +++ b/emmett/orm/migrations/scripts.py @@ -1,22 +1,21 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.scripts - ----------------------------- +emmett.orm.migrations.scripts +----------------------------- - Provides scripts interface for migrations. +Provides scripts interface for migrations. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) - :copyright: (c) 2009-2015 by Michael Bayer +Based on the code of Alembic (https://bitbucket.org/zzzeek/alembic) +:copyright: (c) 2009-2015 by Michael Bayer - :license: BSD-3-Clause +:license: BSD-3-Clause """ import os import re import sys - from contextlib import contextmanager from datetime import datetime from importlib import resources @@ -26,30 +25,25 @@ from ...html import asis from . import __name__ as __pkg__ from .base import Migration -from .exceptions import ( - RangeNotAncestorError, MultipleHeads, ResolutionError, RevisionError -) -from .helpers import tuple_rev_as_scalar, format_with_comma +from .exceptions import MultipleHeads, RangeNotAncestorError, ResolutionError, RevisionError +from .helpers import format_with_comma, tuple_rev_as_scalar from .revisions import Revision, RevisionsMap class ScriptDir(object): - _slug_re = re.compile(r'\w+') + _slug_re = re.compile(r"\w+") _default_file_template = "%(rev)s_%(slug)s" def __init__(self, app, migrations_folder=None): self.app = app - self.path = os.path.join( - app.root_path, migrations_folder or 'migrations') + self.path = os.path.join(app.root_path, migrations_folder or "migrations") if not os.path.exists(self.path): os.mkdir(self.path) self.cwd = os.path.dirname(__file__) - self.file_template = self.app.config.migrations.file_template or \ - self._default_file_template - self.truncate_slug_length = \ - self.app.config.migrations.filename_len or 40 + self.file_template = self.app.config.migrations.file_template or self._default_file_template + self.truncate_slug_length = self.app.config.migrations.filename_len or 40 self.revision_map = RevisionsMap(self.app, self._load_revisions) - self.templater = Renoir(path=self.cwd, mode='plain') + self.templater = Renoir(path=self.cwd, mode="plain") def _load_revisions(self): sys.path.insert(0, self.path) @@ -60,10 +54,7 @@ def _load_revisions(self): yield script @contextmanager - def _catch_revision_errors( - self, - ancestor=None, multiple_heads=None, start=None, end=None, - resolution=None): + def _catch_revision_errors(self, ancestor=None, multiple_heads=None, start=None, end=None, resolution=None): try: yield except RangeNotAncestorError as rna: @@ -85,25 +76,20 @@ def _catch_revision_errors( "argument '%(head_arg)s'; please " "specify a specific target revision, " "'@%(head_arg)s' to " - "narrow to a specific head, or 'heads' for all heads") - multiple_heads = multiple_heads % { - "head_arg": end or mh.argument, - "heads": str(mh.heads) - } + "narrow to a specific head, or 'heads' for all heads" + ) + multiple_heads = multiple_heads % {"head_arg": end or mh.argument, "heads": str(mh.heads)} raise Exception(multiple_heads) except ResolutionError as re: if resolution is None: - resolution = "Can't locate revision identified by '%s'" % ( - re.argument - ) + resolution = "Can't locate revision identified by '%s'" % (re.argument) raise Exception(resolution) except RevisionError as err: raise Exception(err.args[0]) def walk_revisions(self, base="base", head="heads"): with self._catch_revision_errors(start=base, end=head): - for rev in self.revision_map.iterate_revisions( - head, base, inclusive=True): + for rev in self.revision_map.iterate_revisions(head, base, inclusive=True): yield rev def get_revision(self, revid): @@ -116,42 +102,41 @@ def get_revisions(self, revid): def get_upgrade_revs(self, destination, current_rev): with self._catch_revision_errors( - ancestor="Destination %(end)s is not a valid upgrade " - "target from current head(s)", end=destination): - revs = self.revision_map.iterate_revisions( - destination, current_rev, implicit_base=True) + ancestor="Destination %(end)s is not a valid upgrade " "target from current head(s)", end=destination + ): + revs = self.revision_map.iterate_revisions(destination, current_rev, implicit_base=True) return reversed(list(revs)) def get_downgrade_revs(self, destination, current_rev): with self._catch_revision_errors( - ancestor="Destination %(end)s is not a valid downgrade " - "target from current head(s)", end=destination): - revs = self.revision_map.iterate_revisions( - current_rev, destination) + ancestor="Destination %(end)s is not a valid downgrade " "target from current head(s)", end=destination + ): + revs = self.revision_map.iterate_revisions(current_rev, destination) return list(revs) def _rev_filename(self, revid, message, creation_date): slug = "_".join(self._slug_re.findall(message or "")).lower() if len(slug) > self.truncate_slug_length: - slug = slug[:self.truncate_slug_length].rsplit('_', 1)[0] + '_' + slug = slug[: self.truncate_slug_length].rsplit("_", 1)[0] + "_" filename = "%s.py" % ( - self.file_template % { - 'rev': revid, - 'slug': slug, - 'year': creation_date.year, - 'month': creation_date.month, - 'day': creation_date.day, - 'hour': creation_date.hour, - 'minute': creation_date.minute, - 'second': creation_date.second + self.file_template + % { + "rev": revid, + "slug": slug, + "year": creation_date.year, + "month": creation_date.month, + "day": creation_date.day, + "hour": creation_date.hour, + "minute": creation_date.minute, + "second": creation_date.second, } ) return filename def _generate_template(self, filename, ctx): - tmpl_source = resources.read_text(__pkg__, 'migration.tmpl') + tmpl_source = resources.read_text(__pkg__, "migration.tmpl") rendered = self.templater._render(source=tmpl_source, context=ctx) - with open(os.path.join(self.path, filename), 'w') as f: + with open(os.path.join(self.path, filename), "w") as f: f.write(rendered) def generate_revision(self, revid, message, head=None, splice=False, **kw): @@ -171,11 +156,13 @@ def generate_revision(self, revid, message, head=None, splice=False, **kw): if head is None: head = "head" - with self._catch_revision_errors(multiple_heads=( - "Multiple heads are present; please specify the head " - "revision on which the new revision should be based, " - "or perform a merge." - )): + with self._catch_revision_errors( + multiple_heads=( + "Multiple heads are present; please specify the head " + "revision on which the new revision should be based, " + "or perform a merge." + ) + ): heads = self.revision_map.get_revisions(head) if len(set(heads)) != len(heads): @@ -190,11 +177,10 @@ def generate_revision(self, revid, message, head=None, splice=False, **kw): if head is not None and not head.is_head: raise Exception( "Revision %s is not a head revision; please specify " - "--splice to create a new branch from this revision" - % head.revision) + "--splice to create a new branch from this revision" % head.revision + ) - down_migration = tuple( - h.revision if h is not None else None for h in heads) + down_migration = tuple(h.revision if h is not None else None for h in heads) down_migration_var = tuple_rev_as_scalar(down_migration) if isinstance(down_migration_var, str): @@ -202,16 +188,16 @@ def generate_revision(self, revid, message, head=None, splice=False, **kw): else: down_migration_var = str(down_migration_var) - template_ctx = dict( - asis=asis, - up_migration=revid, - down_migration=down_migration_var, - creation_date=creation_date, - down_migration_str=", ".join(r for r in down_migration), - message=message if message is not None else ("empty message"), - upgrades=kw.get('upgrades', ['pass']), - downgrades=kw.get('downgrades', ['pass']) - ) + template_ctx = { + "asis": asis, + "up_migration": revid, + "down_migration": down_migration_var, + "creation_date": creation_date, + "down_migration_str": ", ".join(r for r in down_migration), + "message": message if message is not None else ("empty message"), + "upgrades": kw.get("upgrades", ["pass"]), + "downgrades": kw.get("downgrades", ["pass"]), + } self._generate_template(rev_filename, template_ctx) script = Script._from_filename(self, rev_filename) @@ -220,7 +206,7 @@ def generate_revision(self, revid, message, head=None, splice=False, **kw): class Script(Revision): - _only_source_rev_file = re.compile(r'(?!__init__)(.*\.py)$') + _only_source_rev_file = re.compile(r"(?!__init__)(.*\.py)$") migration_class = None path = None @@ -228,10 +214,7 @@ def __init__(self, module, migration_class, path): self.module = module self.migration_class = migration_class self.path = path - super(Script, self).__init__( - self.migration_class.revision, - self.migration_class.revises - ) + super(Script, self).__init__(self.migration_class.revision, self.migration_class.revises) @property def doc(self): @@ -251,22 +234,16 @@ def log_entry(self): " (mergepoint)" if self.is_merge_point else "", ) if self.is_merge_point: - entry += "Merges: %s\n" % (self._format_down_revision(), ) + entry += "Merges: %s\n" % (self._format_down_revision(),) else: - entry += "Parent: %s\n" % (self._format_down_revision(), ) + entry += "Parent: %s\n" % (self._format_down_revision(),) if self.is_branch_point: - entry += "Branches into: %s\n" % ( - format_with_comma(self.nextrev)) + entry += "Branches into: %s\n" % (format_with_comma(self.nextrev)) entry += "Path: %s\n" % (self.path,) - entry += "\n%s\n" % ( - "\n".join( - " %s" % para - for para in self.longdoc.splitlines() - ) - ) + entry += "\n%s\n" % ("\n".join(" %s" % para for para in self.longdoc.splitlines())) return entry def __str__(self): @@ -276,21 +253,17 @@ def __str__(self): " (head)" if self.is_head else "", " (branchpoint)" if self.is_branch_point else "", " (mergepoint)" if self.is_merge_point else "", - self.doc) + self.doc, + ) - def _head_only( - self, include_doc=False, - include_parents=False, tree_indicators=True, - head_indicators=True): + def _head_only(self, include_doc=False, include_parents=False, tree_indicators=True, head_indicators=True): text = self.revision if include_parents: - text = "%s -> %s" % ( - self._format_down_revision(), text) + text = "%s -> %s" % (self._format_down_revision(), text) if head_indicators or tree_indicators: text += "%s%s" % ( " (head)" if self._is_real_head else "", - " (effective head)" if self.is_head and - not self._is_real_head else "" + " (effective head)" if self.is_head and not self._is_real_head else "", ) if tree_indicators: text += "%s%s" % ( @@ -301,17 +274,11 @@ def _head_only( text += ", %s" % self.doc return text - def cmd_format( - self, - verbose, - include_doc=False, - include_parents=False, tree_indicators=True): + def cmd_format(self, verbose, include_doc=False, include_parents=False, tree_indicators=True): if verbose: return self.log_entry else: - return self._head_only( - include_doc, - include_parents, tree_indicators) + return self._head_only(include_doc, include_parents, tree_indicators) def _format_down_revision(self): if not self.down_revision: @@ -325,14 +292,13 @@ def _from_filename(cls, scriptdir, filename): if not py_match: return None py_filename = py_match.group(1) - py_module = py_filename.split('.py')[0] + py_module = py_filename.split(".py")[0] __import__(py_module) module = sys.modules[py_module] - migration_class = getattr(module, 'Migration', None) + migration_class = getattr(module, "Migration", None) if migration_class is None: for v in module.__dict__.values(): if isinstance(v, Migration): migration_class = v break - return Script( - module, migration_class, os.path.join(scriptdir.path, filename)) + return Script(module, migration_class, os.path.join(scriptdir.path, filename)) diff --git a/emmett/orm/migrations/utils.py b/emmett/orm/migrations/utils.py index 6ae72560..38e04599 100644 --- a/emmett/orm/migrations/utils.py +++ b/emmett/orm/migrations/utils.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.migrations.utilities - ------------------------------- +emmett.orm.migrations.utilities +------------------------------- - Provides some migration utilities. +Provides some migration utilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from ..base import Database @@ -22,7 +22,7 @@ def _load_head_to_meta(self): class RuntimeMigration(MigrationOp): def __init__(self, engine: Engine, ops: UpgradeOps): - super().__init__('runtime', ops, ops.reverse(), 'runtime') + super().__init__("runtime", ops, ops.reverse(), "runtime") self.engine = engine for op in self.upgrade_ops.ops: op.engine = self.engine diff --git a/emmett/orm/models.py b/emmett/orm/models.py index 186612e1..b76def18 100644 --- a/emmett/orm/models.py +++ b/emmett/orm/models.py @@ -1,39 +1,23 @@ # -*- coding: utf-8 -*- """ - emmett.orm.models - ----------------- +emmett.orm.models +----------------- - Provides model layer for Emmet's ORM. +Provides model layer for Emmet's ORM. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import operator import types - from collections import OrderedDict from functools import reduce from typing import Any, Callable from ..datastructures import sdict -from .apis import ( - compute, - rowattr, - rowmethod, - scope, - belongs_to, - refers_to, - has_one, - has_many -) -from .errors import ( - InsertFailureOnSave, - SaveException, - UpdateFailureOnSave, - ValidationError, - DestroyException -) +from .apis import belongs_to, compute, has_many, has_one, refers_to, rowattr, rowmethod, scope +from .errors import DestroyException, InsertFailureOnSave, SaveException, UpdateFailureOnSave, ValidationError from .helpers import ( Callback, ReferenceData, @@ -45,23 +29,30 @@ typed_row_reference, typed_row_reference_from_record, wrap_scope_on_model, - wrap_virtual_on_model + wrap_virtual_on_model, ) from .objects import Field, StructuredRow -from .wrappers import HasOneWrap, HasOneViaWrap, HasManyWrap, HasManyViaWrap +from .wrappers import HasManyViaWrap, HasManyWrap, HasOneViaWrap, HasOneWrap class MetaModel(type): _inheritable_dict_attrs_ = [ - 'indexes', 'validation', ('fields_rw', {'id': False}), 'foreign_keys', - 'default_values', 'update_values', 'repr_values', - 'form_labels', 'form_info', 'form_widgets' + "indexes", + "validation", + ("fields_rw", {"id": False}), + "foreign_keys", + "default_values", + "update_values", + "repr_values", + "form_labels", + "form_info", + "form_widgets", ] def __new__(cls, name, bases, attrs): new_class = type.__new__(cls, name, bases, attrs) #: collect declared attributes - tablename = attrs.get('tablename') + tablename = attrs.get("tablename") fields = [] vfields = [] computations = [] @@ -83,8 +74,7 @@ def __new__(cls, name, bases, attrs): elif isinstance(value, scope): declared_scopes[key] = value declared_relations = sdict( - belongs=OrderedDict(), refers=OrderedDict(), - hasone=OrderedDict(), hasmany=OrderedDict() + belongs=OrderedDict(), refers=OrderedDict(), hasone=OrderedDict(), hasmany=OrderedDict() ) for ref in belongs_to._references_.values(): for item in ref.reference: @@ -132,25 +122,22 @@ def __new__(cls, name, bases, attrs): all_computations = OrderedDict() all_callbacks = OrderedDict() all_scopes = {} - all_relations = sdict( - belongs=OrderedDict(), refers=OrderedDict(), - hasone=OrderedDict(), hasmany=OrderedDict() - ) + all_relations = sdict(belongs=OrderedDict(), refers=OrderedDict(), hasone=OrderedDict(), hasmany=OrderedDict()) super_vfields = OrderedDict() for base in reversed(new_class.__mro__[1:]): - if hasattr(base, '_declared_fields_'): + if hasattr(base, "_declared_fields_"): all_fields.update(base._declared_fields_) - if hasattr(base, '_declared_virtuals_'): + if hasattr(base, "_declared_virtuals_"): all_vfields.update(base._declared_virtuals_) super_vfields.update(base._declared_virtuals_) - if hasattr(base, '_declared_computations_'): + if hasattr(base, "_declared_computations_"): all_computations.update(base._declared_computations_) - if hasattr(base, '_declared_callbacks_'): + if hasattr(base, "_declared_callbacks_"): all_callbacks.update(base._declared_callbacks_) - if hasattr(base, '_declared_scopes_'): + if hasattr(base, "_declared_scopes_"): all_scopes.update(base._declared_scopes_) for key in list(all_relations): - attrkey = '_declared_' + key + '_ref_' + attrkey = "_declared_" + key + "_ref_" if hasattr(base, attrkey): all_relations[key].update(getattr(base, attrkey)) #: compose 'all' attributes @@ -237,13 +224,13 @@ def __new__(cls): return super(Model, cls).__new__(cls) def __init__(self): - if not hasattr(self, 'migrate'): - self.migrate = self.config.get('migrate', self.db._migrate) - if not hasattr(self, 'format'): + if not hasattr(self, "migrate"): + self.migrate = self.config.get("migrate", self.db._migrate) + if not hasattr(self, "format"): self.format = None - if not hasattr(self, 'primary_keys'): + if not hasattr(self, "primary_keys"): self.primary_keys = [] - self._fieldset_pk = set(self.primary_keys or ['id']) + self._fieldset_pk = set(self.primary_keys or ["id"]) @property def config(self): @@ -253,7 +240,7 @@ def __parse_relation_via(self, via): if via is None: return via rv = sdict() - splitted = via.split('.') + splitted = via.split(".") rv.via = splitted[0] if len(splitted) > 1: rv.field = splitted[1] @@ -290,30 +277,27 @@ def __build_relation_modelname(self, name, relation, singularize): relation.model = relation.model[:-1] def __build_relation_fieldnames(self, relation): - splitted = relation.model.split('.') + splitted = relation.model.split(".") relation.model = splitted[0] if len(splitted) > 1: relation.fields = [splitted[1]] else: if len(self.primary_keys) > 1: - relation.fields = [ - f"{decamelize(self.__class__.__name__)}_{pk}" - for pk in self.primary_keys - ] + relation.fields = [f"{decamelize(self.__class__.__name__)}_{pk}" for pk in self.primary_keys] else: relation.fields = [decamelize(self.__class__.__name__)] def __parse_relation_dict(self, rel, singularize): - if 'scope' in rel.model: - rel.scope = rel.model['scope'] - if 'where' in rel.model: - rel.where = rel.model['where'] - if 'via' in rel.model: - rel.update(self.__parse_relation_via(rel.model['via'])) + if "scope" in rel.model: + rel.scope = rel.model["scope"] + if "where" in rel.model: + rel.where = rel.model["where"] + if "via" in rel.model: + rel.update(self.__parse_relation_via(rel.model["via"])) del rel.model else: - if 'target' in rel.model: - rel.model = rel.model['target'] + if "target" in rel.model: + rel.model = rel.model["target"] if not isinstance(rel.model, str): self.__build_relation_modelname(rel.name, rel, singularize) @@ -323,19 +307,16 @@ def __parse_many_relation(self, item, singularize=True): rv.name = list(item)[0] rv.model = item[rv.name] if isinstance(rv.model, dict): - if 'method' in rv.model: - if 'field' in rv.model: - rv.fields = [rv.model['field']] + if "method" in rv.model: + if "field" in rv.model: + rv.fields = [rv.model["field"]] else: if len(self.primary_keys) > 1: - rv.fields = [ - f"{decamelize(self.__class__.__name__)}_{pk}" - for pk in self.primary_keys - ] + rv.fields = [f"{decamelize(self.__class__.__name__)}_{pk}" for pk in self.primary_keys] else: rv.fields = [decamelize(self.__class__.__name__)] - rv.cast = rv.model.get('cast') - rv.method = rv.model['method'] + rv.cast = rv.model.get("cast") + rv.method = rv.model["method"] del rv.model else: self.__parse_relation_dict(rv, singularize) @@ -348,18 +329,15 @@ def __parse_many_relation(self, item, singularize=True): if rv.model == "self": rv.model = self.__class__.__name__ if not rv.via: - rv.reverse = ( - rv.fields[0] if len(rv.fields) == 1 else - decamelize(self.__class__.__name__) - ) + rv.reverse = rv.fields[0] if len(rv.fields) == 1 else decamelize(self.__class__.__name__) return rv def _define_props_(self): #: create pydal's Field elements self.fields = [] - if not self.primary_keys and 'id' not in self._all_fields_: - idfield = Field('id')._make_field('id', model=self) - setattr(self.__class__, 'id', idfield) + if not self.primary_keys and "id" not in self._all_fields_: + idfield = Field("id")._make_field("id", model=self) + self.__class__.id = idfield self.fields.append(idfield) for name, obj in self._all_fields_.items(): if obj.modelname is not None: @@ -372,10 +350,7 @@ def __find_matching_fk_definition(self, rfields, lfields, rmodel): if not set(rfields).issubset(set(rmodel.primary_keys)): return match for key, val in self.foreign_keys.items(): - if ( - set(val["foreign_fields"]) == set(rmodel.primary_keys) and - set(lfields).issubset(set(val["fields"])) - ): + if set(val["foreign_fields"]) == set(rmodel.primary_keys) and set(lfields).issubset(set(val["fields"])): match = key break return match @@ -383,13 +358,10 @@ def __find_matching_fk_definition(self, rfields, lfields, rmodel): def _define_relations_(self): self._virtual_relations_ = OrderedDict() self._compound_relations_ = {} - bad_args_error = ( - "belongs_to, has_one and has_many " - "only accept strings or dicts as arguments" - ) + bad_args_error = "belongs_to, has_one and has_many " "only accept strings or dicts as arguments" #: belongs_to and refers_to are mapped with 'reference' type Field _references = [] - _reference_keys = ['_all_belongs_ref_', '_all_refers_ref_'] + _reference_keys = ["_all_belongs_ref_", "_all_refers_ref_"] belongs_references = {} belongs_fks = {} for key in _reference_keys: @@ -397,24 +369,18 @@ def _define_relations_(self): _references.append(list(getattr(self, key).values())) else: _references.append([]) - is_belongs, ondelete = True, 'cascade' + is_belongs, ondelete = True, "cascade" for _references_obj in _references: for item in _references_obj: if not isinstance(item, (str, dict)): raise RuntimeError(bad_args_error) reference = self.__parse_belongs_relation(item, ondelete) reference.is_refers = not is_belongs - refmodel = ( - self.db[reference.model]._model_ - if reference.model != self.__class__.__name__ - else self - ) + refmodel = self.db[reference.model]._model_ if reference.model != self.__class__.__name__ else self ref_multi_pk = len(refmodel._fieldset_pk) > 1 fk_def_key, fks_data, multi_fk = None, {}, [] if ref_multi_pk and reference.fk: - fk_def_key = self.__find_matching_fk_definition( - [reference.fk], [reference.name], refmodel - ) + fk_def_key = self.__find_matching_fk_definition([reference.fk], [reference.name], refmodel) if not fk_def_key: raise SyntaxError( f"{self.__class__.__name__}.{reference.name} relation " @@ -445,10 +411,9 @@ def _define_relations_(self): local_fields=fks_data["fields"], foreign_fields=fks_data["foreign_fields"], coupled_fields=[ - (local, fks_data["foreign_fields"][idx]) - for idx, local in enumerate(fks_data["fields"]) + (local, fks_data["foreign_fields"][idx]) for idx, local in enumerate(fks_data["fields"]) ], - is_refers=reference.is_refers + is_refers=reference.is_refers, ) self._compound_relations_[reference.name] = sdict( model=reference.model, @@ -465,40 +430,31 @@ def _define_relations_(self): local_fields=[reference.name], foreign_fields=[reference.fk], coupled_fields=[(reference.name, reference.fk)], - is_refers=reference.is_refers + is_refers=reference.is_refers, ) if not fk_def_key and fks_data: - self.foreign_keys[reference.name] = self.foreign_keys.get( - reference.name - ) or fks_data + self.foreign_keys[reference.name] = self.foreign_keys.get(reference.name) or fks_data for reference in references: if reference.model != self.__class__.__name__: tablename = self.db[reference.model]._tablename else: tablename = self.tablename fieldobj = Field( - ( - f"reference {tablename}" if not ref_multi_pk else - f"reference {tablename}.{reference.fk}" - ), + (f"reference {tablename}" if not ref_multi_pk else f"reference {tablename}.{reference.fk}"), ondelete=reference.on_delete, - _isrefers=not is_belongs + _isrefers=not is_belongs, ) setattr(self.__class__, reference.name, fieldobj) - self.fields.append( - getattr(self, reference.name)._make_field( - reference.name, self - ) - ) + self.fields.append(getattr(self, reference.name)._make_field(reference.name, self)) belongs_references[reference.name] = reference is_belongs = False - ondelete = 'nullify' - setattr(self.__class__, '_belongs_ref_', belongs_references) - setattr(self.__class__, '_belongs_fks_', belongs_fks) + ondelete = "nullify" + self.__class__._belongs_ref_ = belongs_references + self.__class__._belongs_fks_ = belongs_fks #: has_one are mapped with rowattr hasone_references = {} - if hasattr(self, '_all_hasone_ref_'): - for item in getattr(self, '_all_hasone_ref_').values(): + if hasattr(self, "_all_hasone_ref_"): + for item in getattr(self, "_all_hasone_ref_").values(): if not isinstance(item, (str, dict)): raise RuntimeError(bad_args_error) reference = self.__parse_many_relation(item, False) @@ -508,15 +464,13 @@ def _define_relations_(self): else: #: maps has_one('thing'), has_one({'thing': 'othername'}) wrapper = HasOneWrap - self._virtual_relations_[reference.name] = rowattr( - reference.name - )(wrapper(reference)) + self._virtual_relations_[reference.name] = rowattr(reference.name)(wrapper(reference)) hasone_references[reference.name] = reference - setattr(self.__class__, '_hasone_ref_', hasone_references) + self.__class__._hasone_ref_ = hasone_references #: has_many are mapped with rowattr hasmany_references = {} - if hasattr(self, '_all_hasmany_ref_'): - for item in getattr(self, '_all_hasmany_ref_').values(): + if hasattr(self, "_all_hasmany_ref_"): + for item in getattr(self, "_all_hasmany_ref_").values(): if not isinstance(item, (str, dict)): raise RuntimeError(bad_args_error) reference = self.__parse_many_relation(item) @@ -526,11 +480,9 @@ def _define_relations_(self): else: #: maps has_many('things'), has_many({'things': 'othername'}) wrapper = HasManyWrap - self._virtual_relations_[reference.name] = rowattr( - reference.name - )(wrapper(reference)) + self._virtual_relations_[reference.name] = rowattr(reference.name)(wrapper(reference)) hasmany_references[reference.name] = reference - setattr(self.__class__, '_hasmany_ref_', hasmany_references) + self.__class__._hasmany_ref_ = hasmany_references self.__define_fks() def __define_fks(self): @@ -538,15 +490,8 @@ def __define_fks(self): implicit_defs = {} grouped_rels = {} for rname, rel in self._belongs_ref_.items(): - rmodel = ( - self.db[rel.model]._model_ - if rel.model != self.__class__.__name__ - else self - ) - if ( - not rmodel.primary_keys and - getattr(rmodel, list(rmodel._fieldset_pk)[0]).type == 'id' - ): + rmodel = self.db[rel.model]._model_ if rel.model != self.__class__.__name__ else self + if not rmodel.primary_keys and getattr(rmodel, list(rmodel._fieldset_pk)[0]).type == "id": continue if len(rmodel._fieldset_pk) > 1: match = self.__find_matching_fk_definition([rel.fk], [rel.name], rmodel) @@ -557,46 +502,43 @@ def __define_fks(self): "needs to be defined into `foreign_keys`." ) crels = grouped_rels[match] = grouped_rels.get( - match, { - 'rels': {}, - 'table': rmodel.tablename, - 'on_delete': self.foreign_keys[match].get( - "on_delete", "cascade" - ) - } + match, + { + "rels": {}, + "table": rmodel.tablename, + "on_delete": self.foreign_keys[match].get("on_delete", "cascade"), + }, ) - crels['rels'][rname] = rel + crels["rels"][rname] = rel else: # NOTE: we need this since pyDAL doesn't support id/refs types != int implicit_defs[rname] = { - 'table': rmodel.tablename, - 'fields_local': [rname], - 'fields_foreign': [rel.fk], - 'on_delete': Field._internal_delete[rel.on_delete] + "table": rmodel.tablename, + "fields_local": [rname], + "fields_foreign": [rel.fk], + "on_delete": Field._internal_delete[rel.on_delete], } - for rname, rel in implicit_defs.items(): - constraint_name = self.__create_fk_contraint_name( - rel['table'], *rel['fields_local'] - ) + for _, rel in implicit_defs.items(): + constraint_name = self.__create_fk_contraint_name(rel["table"], *rel["fields_local"]) self._foreign_keys_[constraint_name] = {**rel} for crels in grouped_rels.values(): constraint_name = self.__create_fk_contraint_name( - crels['table'], *[rel.name for rel in crels['rels'].values()] + crels["table"], *[rel.name for rel in crels["rels"].values()] ) self._foreign_keys_[constraint_name] = { - 'table': crels['table'], - 'fields_local': [rel.name for rel in crels['rels'].values()], - 'fields_foreign': [rel.fk for rel in crels['rels'].values()], - 'on_delete': Field._internal_delete[crels['on_delete']] + "table": crels["table"], + "fields_local": [rel.name for rel in crels["rels"].values()], + "fields_foreign": [rel.fk for rel in crels["rels"].values()], + "on_delete": Field._internal_delete[crels["on_delete"]], } def _define_virtuals_(self): self._all_rowattrs_ = {} self._all_rowmethods_ = {} self._super_rowmethods_ = {} - err = 'rowattr or rowmethod cannot have the name of an existent field!' + err = "rowattr or rowmethod cannot have the name of an existent field!" field_names = [field.name for field in self.fields] - for attr in ['_virtual_relations_', '_all_virtuals_']: + for attr in ["_virtual_relations_", "_all_virtuals_"]: for obj in getattr(self, attr, {}).values(): if obj.field_name in field_names: raise RuntimeError(err) @@ -616,47 +558,36 @@ def _define_virtuals_(self): def _set_row_persistence_id(self, row, ret): row.id = ret.id - object.__setattr__(row, '_concrete', True) + object.__setattr__(row, "_concrete", True) def _set_row_persistence_pk(self, row, ret): row[self.primary_keys[0]] = ret[self.primary_keys[0]] - object.__setattr__(row, '_concrete', True) + object.__setattr__(row, "_concrete", True) def _set_row_persistence_pks(self, row, ret): for field_name in self.primary_keys: row[field_name] = ret[field_name] - object.__setattr__(row, '_concrete', True) + object.__setattr__(row, "_concrete", True) def _unset_row_persistence(self, row): for field_name in self._fieldset_pk: row[field_name] = None - object.__setattr__(row, '_concrete', False) + object.__setattr__(row, "_concrete", False) def _build_rowclass_(self): #: build helpers for rows save_excluded_fields = ( - set( - field.name for field in self.fields if - getattr(field, "type", None) == "id" - ) | - set(self._all_rowattrs_.keys()) | - set(self._all_rowmethods_.keys()) + {field.name for field in self.fields if getattr(field, "type", None) == "id"} + | set(self._all_rowattrs_.keys()) + | set(self._all_rowmethods_.keys()) ) - self._fieldset_initable = set([ - field.name for field in self.fields - ]) - save_excluded_fields - self._fieldset_editable = set([ - field.name for field in self.fields - ]) - save_excluded_fields - self._fieldset_pk + self._fieldset_initable = {field.name for field in self.fields} - save_excluded_fields + self._fieldset_editable = {field.name for field in self.fields} - save_excluded_fields - self._fieldset_pk self._fieldset_all = self._fieldset_initable | self._fieldset_pk - self._fieldset_update = set([ - field.name for field in self.fields - if getattr(field, "update", None) is not None - ]) & self._fieldset_editable - self._relations_wrapset = ( - set(self._belongs_fks_.keys()) - - set(self._compound_relations_.keys()) - ) + self._fieldset_update = { + field.name for field in self.fields if getattr(field, "update", None) is not None + } & self._fieldset_editable + self._relations_wrapset = set(self._belongs_fks_.keys()) - set(self._compound_relations_.keys()) if not self.primary_keys: self._set_row_persistence = self._set_row_persistence_id elif len(self.primary_keys) == 1: @@ -665,22 +596,12 @@ def _build_rowclass_(self): self._set_row_persistence = self._set_row_persistence_pks #: create dynamic row class clsname = self.__class__.__name__ + "Row" - attrs = {'_model': self} + attrs = {"_model": self} attrs.update({k: RowFieldMapper(k) for k in self._fieldset_all}) - attrs.update({ - k: RowVirtualMapper(k, v) - for k, v in self._all_rowattrs_.items() - }) + attrs.update({k: RowVirtualMapper(k, v) for k, v in self._all_rowattrs_.items()}) attrs.update(self._all_rowmethods_) - attrs.update({ - k: RowRelationMapper(self.db, self._belongs_ref_[k]) - for k in self._relations_wrapset - }) - attrs.update({ - k: RowCompoundRelationMapper( - self.db, data - ) for k, data in self._compound_relations_.items() - }) + attrs.update({k: RowRelationMapper(self.db, self._belongs_ref_[k]) for k in self._relations_wrapset}) + attrs.update({k: RowCompoundRelationMapper(self.db, data) for k, data in self._compound_relations_.items()}) self._rowclass_ = type(clsname, (StructuredRow,), attrs) globals()[clsname] = self._rowclass_ @@ -714,7 +635,7 @@ def __define_validation(self): def __define_access(self): for field, value in self.fields_rw.items(): - if field == 'id' and field not in self.table: + if field == "id" and field not in self.table: continue if isinstance(value, (tuple, list)): readable, writable = value @@ -737,14 +658,12 @@ def __define_representation(self): self.table[field].represent = value def __define_computations(self): - err = 'computations should have the name of an existing field to compute!' + err = "computations should have the name of an existing field to compute!" field_names = [field.name for field in self.fields] for obj in self._all_computations_.values(): if obj.field_name not in field_names: raise RuntimeError(err) - self.table[obj.field_name].compute = ( - lambda row, obj=obj, self=self: obj.compute(self, row) - ) + self.table[obj.field_name].compute = lambda row, obj=obj, self=self: obj.compute(self, row) def __define_callbacks(self): for obj in self._all_callbacks_.values(): @@ -766,72 +685,63 @@ def __define_callbacks(self): "_after_commit_update", "_after_commit_delete", "_after_commit_save", - "_after_commit_destroy" + "_after_commit_destroy", ]: - getattr(self.table, t).append( - lambda a, obj=obj, self=self: obj.f(self, a) - ) + getattr(self.table, t).append(lambda a, obj=obj, self=self: obj.f(self, a)) else: - getattr(self.table, t).append( - lambda a, b, obj=obj, self=self: obj.f(self, a, b) - ) + getattr(self.table, t).append(lambda a, b, obj=obj, self=self: obj.f(self, a, b)) def __define_scopes(self): self._scopes_ = {} for obj in self._all_scopes_.values(): self._scopes_[obj.name] = obj if not hasattr(self.__class__, obj.name): - setattr( - self.__class__, obj.name, - classmethod(wrap_scope_on_model(obj.f)) - ) + setattr(self.__class__, obj.name, classmethod(wrap_scope_on_model(obj.f))) def __prepend_table_name(self, name, ns): - return '%s_%s__%s' % (self.tablename, ns, name) + return "%s_%s__%s" % (self.tablename, ns, name) def __create_index_name(self, *values): components = [] for value in values: - components.append(value.replace('_', '')) - return self.__prepend_table_name("_".join(components), 'widx') + components.append(value.replace("_", "")) + return self.__prepend_table_name("_".join(components), "widx") def __create_fk_contraint_name(self, *values): components = [] for value in values: - components.append(value.replace('_', '')) - return self.__prepend_table_name("fk__" + "_".join(components), 'ecnt') + components.append(value.replace("_", "")) + return self.__prepend_table_name("fk__" + "_".join(components), "ecnt") def __parse_index_dict(self, value): rv = {} - fields = value.get('fields') or [] + fields = value.get("fields") or [] if not isinstance(fields, (list, tuple)): fields = [fields] - rv['fields'] = fields + rv["fields"] = fields where_query = None - where_cond = value.get('where') + where_cond = value.get("where") if callable(where_cond): where_query = where_cond(self.__class__) if where_query: - rv['where'] = where_query + rv["where"] = where_query expressions = [] - expressions_cond = value.get('expressions') + expressions_cond = value.get("expressions") if callable(expressions_cond): expressions = expressions_cond(self.__class__) if not isinstance(expressions, (tuple, list)): expressions = [expressions] - rv['expressions'] = expressions - rv['unique'] = value.get('unique', False) + rv["expressions"] = expressions + rv["unique"] = value.get("unique", False) return rv def __define_indexes(self): self._indexes_ = {} #: auto-define indexes based on fields for field in self.fields: - if getattr(field, 'unique', False): - idx_name = self.__prepend_table_name(f'{field.name}_unique', 'widx') - idx_dict = self.__parse_index_dict( - {'fields': [field.name], 'unique': True} - ) + if getattr(field, "unique", False): + idx_name = self.__prepend_table_name(f"{field.name}_unique", "widx") + idx_dict = self.__parse_index_dict({"fields": [field.name], "unique": True}) self._indexes_[idx_name] = idx_dict #: parse user-defined fields for key, value in self.indexes.items(): @@ -841,14 +751,14 @@ def __define_indexes(self): if not isinstance(key, tuple): key = [key] if any(field not in self.table for field in key): - raise SyntaxError(f'Invalid field specified in indexes: {key}') + raise SyntaxError(f"Invalid field specified in indexes: {key}") idx_name = self.__create_index_name(*key) - idx_dict = {'fields': key, 'expressions': [], 'unique': False} + idx_dict = {"fields": key, "expressions": [], "unique": False} elif isinstance(value, dict): - idx_name = self.__prepend_table_name(key, 'widx') + idx_name = self.__prepend_table_name(key, "widx") idx_dict = self.__parse_index_dict(value) else: - raise SyntaxError('Values in indexes dict should be booleans or dicts') + raise SyntaxError("Values in indexes dict should be booleans or dicts") self._indexes_[idx_name] = idx_dict def _row_record_query_id(self, row): @@ -858,35 +768,30 @@ def _row_record_query_pk(self, row): return self.table[self.primary_keys[0]] == row[self.primary_keys[0]] def _row_record_query_pks(self, row): - return reduce( - operator.and_, [self.table[pk] == row[pk] for pk in self.primary_keys] - ) + return reduce(operator.and_, [self.table[pk] == row[pk] for pk in self.primary_keys]) def __define_query_helpers(self): if not self.primary_keys: - self._query_id = self.table.id != None # noqa + self._query_id = self.table.id != None # noqa self._query_row = self._row_record_query_id self._order_by_id_asc = self.table.id self._order_by_id_desc = ~self.table.id elif len(self.primary_keys) == 1: - self._query_id = self.table[self.primary_keys[0]] != None # noqa + self._query_id = self.table[self.primary_keys[0]] != None # noqa self._query_row = self._row_record_query_pk self._order_by_id_asc = self.table[self.primary_keys[0]] self._order_by_id_desc = ~self.table[self.primary_keys[0]] else: self._query_id = reduce( - operator.and_, [ - self.table[key] != None # noqa + operator.and_, + [ + self.table[key] != None # noqa for key in self.primary_keys - ] + ], ) self._query_row = self._row_record_query_pks - self._order_by_id_asc = reduce( - operator.or_, [self.table[key] for key in self.primary_keys] - ) - self._order_by_id_desc = reduce( - operator.or_, [~self.table[key] for key in self.primary_keys] - ) + self._order_by_id_asc = reduce(operator.or_, [self.table[key] for key in self.primary_keys]) + self._order_by_id_desc = reduce(operator.or_, [~self.table[key] for key in self.primary_keys]) def __define_form_utils(self): #: labels @@ -924,27 +829,26 @@ def new(cls, **attributes): if callable(val): val = val() rowattrs[field] = val - for field in (inst.primary_keys or ["id"]): + for field in inst.primary_keys or ["id"]: if inst.table[field].type == "id": rowattrs[field] = None for field in set(inst._compound_relations_.keys()) & attrset: reldata = inst._compound_relations_[field] for local_field, foreign_field in reldata.coupled_fields: rowattrs[local_field] = attributes[field][foreign_field] - rv = inst._rowclass_( - rowattrs, __concrete=False, - **{k: attributes[k] for k in attrset - set(rowattrs)} - ) - rv._fields.update({ - field: attributes[field] if not attributes[field] else ( - typed_row_reference_from_record( - attributes[field], inst.db[inst._belongs_fks_[field].model]._model_ - ) if isinstance(attributes[field], StructuredRow) else - typed_row_reference( - attributes[field], inst.db[inst._belongs_fks_[field].model] + rv = inst._rowclass_(rowattrs, __concrete=False, **{k: attributes[k] for k in attrset - set(rowattrs)}) + rv._fields.update( + { + field: attributes[field] + if not attributes[field] + else ( + typed_row_reference_from_record(attributes[field], inst.db[inst._belongs_fks_[field].model]._model_) + if isinstance(attributes[field], StructuredRow) + else typed_row_reference(attributes[field], inst.db[inst._belongs_fks_[field].model]) ) - ) for field in inst._relations_wrapset & attrset - }) + for field in inst._relations_wrapset & attrset + } + ) return rv @classmethod @@ -965,7 +869,7 @@ def validate(cls, row, write_values: bool = False): inst, errors = cls._instance_(), sdict() for field_name in inst._fieldset_all: field = inst.table[field_name] - default = getattr(field, 'default') + default = getattr(field, "default") if callable(default): default = default() value = row.get(field_name, default) @@ -988,17 +892,11 @@ def all(cls): @classmethod def first(cls): - return cls.all().select( - orderby=cls._instance_()._order_by_id_asc, - limitby=(0, 1) - ).first() + return cls.all().select(orderby=cls._instance_()._order_by_id_asc, limitby=(0, 1)).first() @classmethod def last(cls): - return cls.all().select( - orderby=cls._instance_()._order_by_id_desc, - limitby=(0, 1) - ).first() + return cls.all().select(orderby=cls._instance_()._order_by_id_desc, limitby=(0, 1)).first() @classmethod def get(cls, *args, **kwargs): @@ -1010,40 +908,32 @@ def get(cls, *args, **kwargs): elif isinstance(args[0], dict) and not kwargs: return cls.table(**args[0]) if len(args) != len(inst._fieldset_pk): - raise SyntaxError( - f"{cls.__name__}.get requires the same number of arguments " - "as its primary keys" - ) + raise SyntaxError(f"{cls.__name__}.get requires the same number of arguments " "as its primary keys") pks = inst.primary_keys or ["id"] - return cls.table( - **{pks[idx]: val for idx, val in enumerate(args)} - ) + return cls.table(**{pks[idx]: val for idx, val in enumerate(args)}) return cls.table(**kwargs) - @rowmethod('update_record') + @rowmethod("update_record") def _update_record(self, row, skip_callbacks=False, **fields): newfields = fields or dict(row) for field_name in set(newfields.keys()) - self._fieldset_editable: del newfields[field_name] - res = self.db( - self._query_row(row), ignore_common_filters=True - ).update(skip_callbacks=skip_callbacks, **newfields) + res = self.db(self._query_row(row), ignore_common_filters=True).update( + skip_callbacks=skip_callbacks, **newfields + ) if res: row.update(self.get(**{key: row[key] for key in self._fieldset_pk})) return row - @rowmethod('delete_record') + @rowmethod("delete_record") def _delete_record(self, row, skip_callbacks=False): return self.db(self._query_row(row)).delete(skip_callbacks=skip_callbacks) - @rowmethod('refresh') + @rowmethod("refresh") def _row_refresh(self, row) -> bool: if not row._concrete: return False - last = self.db(self._query_row(row)).select( - limitby=(0, 1), - orderby_on_limitby=False - ).first() + last = self.db(self._query_row(row)).select(limitby=(0, 1), orderby_on_limitby=False).first() if not last: return False row._fields.update(last._fields) @@ -1051,19 +941,12 @@ def _row_refresh(self, row) -> bool: row._changes.clear() return True - @rowmethod('save') - def _row_save( - self, - row, - raise_on_error: bool = False, - skip_callbacks: bool = False - ) -> bool: + @rowmethod("save") + def _row_save(self, row, raise_on_error: bool = False, skip_callbacks: bool = False) -> bool: if row._concrete: if set(row._changes.keys()) & self._fieldset_pk: if raise_on_error: - raise SaveException( - 'Cannot save a record with altered primary key(s)' - ) + raise SaveException("Cannot save a record with altered primary key(s)") return False for field_name in self._fieldset_update: val = self.table[field_name].update @@ -1076,9 +959,9 @@ def _row_save( raise ValidationError return False if row._concrete: - res = self.db( - self._query_row(row), ignore_common_filters=True - )._update_from_save(self, row, skip_callbacks=skip_callbacks) + res = self.db(self._query_row(row), ignore_common_filters=True)._update_from_save( + self, row, skip_callbacks=skip_callbacks + ) if not res: if raise_on_error: raise UpdateFailureOnSave @@ -1089,26 +972,18 @@ def _row_save( if raise_on_error: raise InsertFailureOnSave return False - extra_changes = { - key: row._changes[key] - for key in set(row._changes.keys()) & set(row.__dict__.keys()) - } + extra_changes = {key: row._changes[key] for key in set(row._changes.keys()) & set(row.__dict__.keys())} row._changes.clear() row._changes.update(extra_changes) return True - @rowmethod('destroy') - def _row_destroy( - self, - row, - raise_on_error: bool = False, - skip_callbacks: bool = False - ) -> bool: + @rowmethod("destroy") + def _row_destroy(self, row, raise_on_error: bool = False, skip_callbacks: bool = False) -> bool: if not row._concrete: return False - res = self.db( - self._query_row(row), ignore_common_filters=True - )._delete_from_destroy(self, row, skip_callbacks=skip_callbacks) + res = self.db(self._query_row(row), ignore_common_filters=True)._delete_from_destroy( + self, row, skip_callbacks=skip_callbacks + ) if not res: if raise_on_error: raise DestroyException @@ -1192,9 +1067,9 @@ def __get__(self, obj, objtype=None): pks = {fk: obj[lk] for lk, fk in self.fields} key = (self.field, *pks.values()) if key not in obj._compound_rels: - obj._compound_rels[key] = RowReferenceMulti(pks, self.table) if all( - v is not None for v in pks.values() - ) else None + obj._compound_rels[key] = ( + RowReferenceMulti(pks, self.table) if all(v is not None for v in pks.values()) else None + ) return obj._compound_rels[key] def __delete__(self, obj): diff --git a/emmett/orm/objects.py b/emmett/orm/objects.py index 5c179f0a..975803c6 100644 --- a/emmett/orm/objects.py +++ b/emmett/orm/objects.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.orm.objects - ------------------ +emmett.orm.objects +------------------ - Provides pyDAL objects implementation for Emmett. +Provides pyDAL objects implementation for Emmett. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import copy @@ -14,47 +14,44 @@ import decimal import operator import types - from collections import OrderedDict, defaultdict from enum import Enum from functools import reduce from typing import Any, Dict, Optional +from emmett_core.utils import cachedprop from pydal.objects import ( - Table as _Table, + Expression, Field as _Field, - Set as _Set, - Row as _Row, - Rows as _Rows, IterRows as _IterRows, Query, - Expression + Row as _Row, + Rows as _Rows, + Set as _Set, + Table as _Table, ) from ..ctx import current from ..datastructures import sdict from ..html import tag from ..serializers import xml_encode -from ..utils import cachedprop from ..validators import ValidateFromDict from .helpers import ( - RelationBuilder, GeoFieldWrapper, + RelationBuilder, RowReferenceMixin, + typed_row_reference_from_record, wrap_scope_on_set, - typed_row_reference_from_record ) + type_int = int class Table(_Table): def __init__(self, db, tablename, *fields, **kwargs): - _primary_keys, _notnulls = list(kwargs.get('primarykey', [])), {} - _notnulls = { - field.name: field.notnull - for field in fields if hasattr(field, 'notnull') - } + _primary_keys, _notnulls = list(kwargs.get("primarykey", [])), {} + _notnulls = {field.name: field.notnull for field in fields if hasattr(field, "notnull")} super(Table, self).__init__(db, tablename, *fields, **kwargs) self._before_save = [] self._after_save = [] @@ -75,57 +72,32 @@ def __init__(self, db, tablename, *fields, **kwargs): self._unique_fields_validation_ = {} self._primary_keys = _primary_keys #: avoid pyDAL mess in ops and migrations - if len(self._primary_keys) == 1 and getattr(self, '_primarykey', None): + if len(self._primary_keys) == 1 and getattr(self, "_primarykey", None): del self._primarykey for key in self._primary_keys: self[key].notnull = _notnulls[key] - if not hasattr(self, '_id'): + if not hasattr(self, "_id"): self._id = None @cachedprop def _has_commit_insert_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_insert, - self._after_commit_insert - ]) + return any([self._before_commit, self._after_commit, self._before_commit_insert, self._after_commit_insert]) @cachedprop def _has_commit_update_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_update, - self._after_commit_update - ]) + return any([self._before_commit, self._after_commit, self._before_commit_update, self._after_commit_update]) @cachedprop def _has_commit_delete_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_delete, - self._after_commit_delete - ]) + return any([self._before_commit, self._after_commit, self._before_commit_delete, self._after_commit_delete]) @cachedprop def _has_commit_save_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_save, - self._after_commit_save - ]) + return any([self._before_commit, self._after_commit, self._before_commit_save, self._after_commit_save]) @cachedprop def _has_commit_destroy_callbacks(self): - return any([ - self._before_commit, - self._after_commit, - self._before_commit_destroy, - self._after_commit_destroy - ]) + return any([self._before_commit, self._after_commit, self._before_commit_destroy, self._after_commit_destroy]) def _create_references(self): self._referenced_by = [] @@ -145,14 +117,7 @@ def insert(self, skip_callbacks=False, **fields): if self._has_commit_insert_callbacks: txn = self._db._adapter.top_transaction() if txn: - txn._add_op(TransactionOp( - TransactionOps.insert, - self, - TransactionOpContext( - values=row, - ret=ret - ) - )) + txn._add_op(TransactionOp(TransactionOps.insert, self, TransactionOpContext(values=row, ret=ret))) if ret: for f in self._after_insert: f(row, ret) @@ -177,16 +142,13 @@ def _insert_from_save(self, row, skip_callbacks=False): if self._has_commit_save_callbacks: txn = self._db._adapter.top_transaction() if txn: - txn._add_op(TransactionOp( - TransactionOps.save, - self, - TransactionOpContext( - values=fields, - ret=ret, - row=row.clone_changed(), - changes=row.changes + txn._add_op( + TransactionOp( + TransactionOps.save, + self, + TransactionOpContext(values=fields, ret=ret, row=row.clone_changed(), changes=row.changes), ) - )) + ) if row._concrete: for f in self._after_save: f(row) @@ -194,58 +156,48 @@ def _insert_from_save(self, row, skip_callbacks=False): class Field(_Field): - _internal_types = { - 'integer': 'int', - 'double': 'float', - 'boolean': 'bool', - 'list:integer': 'list:int' - } + _internal_types = {"integer": "int", "double": "float", "boolean": "bool", "list:integer": "list:int"} _pydal_types = { - 'int': 'integer', - 'bool': 'boolean', - 'list:int': 'list:integer', - } - _internal_delete = { - 'cascade': 'CASCADE', 'nullify': 'SET NULL', 'nothing': 'NO ACTION' + "int": "integer", + "bool": "boolean", + "list:int": "list:integer", } + _internal_delete = {"cascade": "CASCADE", "nullify": "SET NULL", "nothing": "NO ACTION"} _inst_count_ = 0 _obj_created_ = False - def __init__(self, type='string', *args, **kwargs): + def __init__(self, type="string", *args, **kwargs): self.modelname = None #: convert type self._type = self._internal_types.get(type, type) #: convert 'rw' -> 'readable', 'writeable' - if 'rw' in kwargs: - _rw = kwargs.pop('rw') + if "rw" in kwargs: + _rw = kwargs.pop("rw") if isinstance(_rw, (tuple, list)): read, write = _rw else: read = write = _rw - kwargs['readable'] = read - kwargs['writable'] = write + kwargs["readable"] = read + kwargs["writable"] = write #: convert 'info' -> 'comment' - _info = kwargs.pop('info', None) + _info = kwargs.pop("info", None) if _info: - kwargs['comment'] = _info + kwargs["comment"] = _info #: convert ondelete parameter - _ondelete = kwargs.get('ondelete') + _ondelete = kwargs.get("ondelete") if _ondelete: if _ondelete not in list(self._internal_delete): - raise SyntaxError( - 'Field ondelete should be set on %s, %s or %s' % - list(self._internal_delete) - ) - kwargs['ondelete'] = self._internal_delete[_ondelete] + raise SyntaxError("Field ondelete should be set on %s, %s or %s" % list(self._internal_delete)) + kwargs["ondelete"] = self._internal_delete[_ondelete] #: process 'refers_to' fields - self._isrefers = kwargs.pop('_isrefers', None) + self._isrefers = kwargs.pop("_isrefers", None) #: get auto validation preferences - self._auto_validation = kwargs.pop('auto_validation', True) + self._auto_validation = kwargs.pop("auto_validation", True) #: intercept validation (will be processed by `_make_field`) self._requires = {} self._custom_requires = [] - if 'validation' in kwargs: - _validation = kwargs.pop('validation') + if "validation" in kwargs: + _validation = kwargs.pop("validation") if isinstance(_validation, dict): self._requires = _validation else: @@ -255,10 +207,10 @@ def __init__(self, type='string', *args, **kwargs): self._validation = {} self._vparser = ValidateFromDict() #: ensure 'length' is an integer - if 'length' in kwargs: - kwargs['length'] = int(kwargs['length']) + if "length" in kwargs: + kwargs["length"] = int(kwargs["length"]) #: store args and kwargs for `_make_field` - self._ormkw = kwargs.pop('_kw', {}) + self._ormkw = kwargs.pop("_kw", {}) self._args = args self._kwargs = kwargs #: increase creation counter (used to keep order of fields) @@ -267,42 +219,36 @@ def __init__(self, type='string', *args, **kwargs): def _default_validation(self): rv = {} - auto_types = [ - 'int', 'float', 'date', 'time', 'datetime', 'json' - ] + auto_types = ["int", "float", "date", "time", "datetime", "json"] if self._type in auto_types: - rv['is'] = self._type - elif self._type.startswith('decimal'): - rv['is'] = 'decimal' - elif self._type == 'jsonb': - rv['is'] = 'json' - if self._type == 'bigint': - rv['is'] = 'int' - if self._type == 'bool': - rv['in'] = (False, True) - if self._type in ['string', 'text', 'password']: - rv['len'] = {'lte': self.length} - if self._type == 'password': - rv['len']['gte'] = 6 - rv['crypt'] = True - if self._type == 'list:int': - rv['is'] = 'list:int' - if ( - self.notnull or self._type.startswith('reference') or - self._type.startswith('list:reference') - ): - rv['presence'] = True + rv["is"] = self._type + elif self._type.startswith("decimal"): + rv["is"] = "decimal" + elif self._type == "jsonb": + rv["is"] = "json" + if self._type == "bigint": + rv["is"] = "int" + if self._type == "bool": + rv["in"] = (False, True) + if self._type in ["string", "text", "password"]: + rv["len"] = {"lte": self.length} + if self._type == "password": + rv["len"]["gte"] = 6 + rv["crypt"] = True + if self._type == "list:int": + rv["is"] = "list:int" + if self.notnull or self._type.startswith("reference") or self._type.startswith("list:reference"): + rv["presence"] = True if not self.notnull and self._isrefers is True: - rv['allow'] = 'empty' + rv["allow"] = "empty" if self.unique: - rv['unique'] = True + rv["unique"] = True return rv def _parse_validation(self): for key in list(self._requires): self._validation[key] = self._requires[key] - self.requires = self._vparser(self, self._validation) + \ - self._custom_requires + self.requires = self._vparser(self, self._validation) + self._custom_requires #: `_make_field` will be called by `Model` class or `Form` class # it will make internal Field class compatible with the pyDAL's one @@ -339,113 +285,102 @@ def __str__(self): return object.__str__(self) def __repr__(self): - if self.modelname and hasattr(self, 'name'): - return "<%s.%s (%s) field>" % (self.modelname, self.name, - self._type) + if self.modelname and hasattr(self, "name"): + return "<%s.%s (%s) field>" % (self.modelname, self.name, self._type) return super(Field, self).__repr__() @classmethod def string(cls, *args, **kwargs): - return cls('string', *args, **kwargs) + return cls("string", *args, **kwargs) @classmethod def int(cls, *args, **kwargs): - return cls('int', *args, **kwargs) + return cls("int", *args, **kwargs) @classmethod def bigint(cls, *args, **kwargs): - return cls('bigint', *args, **kwargs) + return cls("bigint", *args, **kwargs) @classmethod def float(cls, *args, **kwargs): - return cls('float', *args, **kwargs) + return cls("float", *args, **kwargs) @classmethod def text(cls, *args, **kwargs): - return cls('text', *args, **kwargs) + return cls("text", *args, **kwargs) @classmethod def bool(cls, *args, **kwargs): - return cls('bool', *args, **kwargs) + return cls("bool", *args, **kwargs) @classmethod def blob(cls, *args, **kwargs): - return cls('blob', *args, **kwargs) + return cls("blob", *args, **kwargs) @classmethod def date(cls, *args, **kwargs): - return cls('date', *args, **kwargs) + return cls("date", *args, **kwargs) @classmethod def time(cls, *args, **kwargs): - return cls('time', *args, **kwargs) + return cls("time", *args, **kwargs) @classmethod def datetime(cls, *args, **kwargs): - return cls('datetime', *args, **kwargs) + return cls("datetime", *args, **kwargs) @classmethod def decimal(cls, precision, scale, *args, **kwargs): - return cls('decimal({},{})'.format(precision, scale), *args, **kwargs) + return cls("decimal({},{})".format(precision, scale), *args, **kwargs) @classmethod def json(cls, *args, **kwargs): - return cls('json', *args, **kwargs) + return cls("json", *args, **kwargs) @classmethod def jsonb(cls, *args, **kwargs): - return cls('jsonb', *args, **kwargs) + return cls("jsonb", *args, **kwargs) @classmethod def password(cls, *args, **kwargs): - return cls('password', *args, **kwargs) + return cls("password", *args, **kwargs) @classmethod def upload(cls, *args, **kwargs): - return cls('upload', *args, **kwargs) + return cls("upload", *args, **kwargs) @classmethod def int_list(cls, *args, **kwargs): - return cls('list:int', *args, **kwargs) + return cls("list:int", *args, **kwargs) @classmethod def string_list(cls, *args, **kwargs): - return cls('list:string', *args, **kwargs) + return cls("list:string", *args, **kwargs) @classmethod def geography( cls, - geometry_type: str = 'GEOMETRY', + geometry_type: str = "GEOMETRY", srid: Optional[type_int] = None, dimension: Optional[type_int] = None, - **kwargs + **kwargs, ): - kwargs['_kw'] = { - "geometry_type": geometry_type, - "srid": srid, - "dimension": dimension - } + kwargs["_kw"] = {"geometry_type": geometry_type, "srid": srid, "dimension": dimension} return cls("geography", **kwargs) @classmethod def geometry( cls, - geometry_type: str = 'GEOMETRY', + geometry_type: str = "GEOMETRY", srid: Optional[type_int] = None, dimension: Optional[type_int] = None, - **kwargs + **kwargs, ): - kwargs['_kw'] = { - "geometry_type": geometry_type, - "srid": srid, - "dimension": dimension - } + kwargs["_kw"] = {"geometry_type": geometry_type, "srid": srid, "dimension": dimension} return cls("geometry", **kwargs) def cast(self, value, **kwargs): - return Expression( - self.db, self._dialect.cast, self, - self._dialect.types[value] % kwargs, value) + return Expression(self.db, self._dialect.cast, self, self._dialect.types[value] % kwargs, value) class Set(_Set): @@ -465,9 +400,10 @@ def _load_scopes_(self): def _clone(self, ignore_common_filters=None, model=None, **changes): return self.__class__( - self.db, changes.get('query', self.query), + self.db, + changes.get("query", self.query), ignore_common_filters=ignore_common_filters, - model=model or self._model_ + model=model or self._model_, ) def where(self, query, ignore_common_filters=None, model=None): @@ -478,12 +414,11 @@ def where(self, query, ignore_common_filters=None, model=None): elif isinstance(query, str): query = Expression(self.db, query) elif isinstance(query, Field): - query = query != None + query = query != None # noqa: E711 elif isinstance(query, types.LambdaType): model = model or self._model_ if not model: - raise ValueError( - "Too many models involved in the Set to use a lambda") + raise ValueError("Too many models involved in the Set to use a lambda") query = query(model) q = self.query & query if self.query else query return self._clone(ignore_common_filters, model, query=q) @@ -498,27 +433,21 @@ def _parse_paginate(self, pagination): return ((offset - 1) * limit, offset * limit) def _join_set_builder(self, obj, jdata, auto_select_tables): - return JoinedSet._from_set( - obj, jdata=jdata, auto_select_tables=auto_select_tables - ) + return JoinedSet._from_set(obj, jdata=jdata, auto_select_tables=auto_select_tables) def _left_join_set_builder(self, jdata): - return JoinedSet._from_set( - self, ljdata=jdata, auto_select_tables=[self._model_.table] - ) + return JoinedSet._from_set(self, ljdata=jdata, auto_select_tables=[self._model_.table]) def _run_select_(self, *fields, **options): tablemap = self.db._adapter.tables( self.query, - options.get('join', None), - options.get('left', None), - options.get('orderby', None), - options.get('groupby', None) - ) - fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables( - fields, tablemap + options.get("join", None), + options.get("left", None), + options.get("orderby", None), + options.get("groupby", None), ) - options['_concrete_tables'] = concrete_tables + fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables(fields, tablemap) + options["_concrete_tables"] = concrete_tables return self.db._adapter.select(self.query, fields, options) def _get_table_from_query(self) -> Table: @@ -528,32 +457,27 @@ def _get_table_from_query(self) -> Table: def select(self, *fields, **options): obj = self - pagination, including = ( - options.pop('paginate', None), - options.pop('including', None) - ) + pagination, including = (options.pop("paginate", None), options.pop("including", None)) if pagination: - options['limitby'] = self._parse_paginate(pagination) + options["limitby"] = self._parse_paginate(pagination) if including and self._model_ is not None: - options['left'], jdata = self._parse_left_rjoins(including) + options["left"], jdata = self._parse_left_rjoins(including) obj = self._left_join_set_builder(jdata) return obj._run_select_(*fields, **options) def iterselect(self, *fields, **options): - pagination = options.pop('paginate', None) + pagination = options.pop("paginate", None) if pagination: - options['limitby'] = self._parse_paginate(pagination) + options["limitby"] = self._parse_paginate(pagination) tablemap = self.db._adapter.tables( self.query, - options.get('join', None), - options.get('left', None), - options.get('orderby', None), - options.get('groupby', None) - ) - fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables( - fields, tablemap + options.get("join", None), + options.get("left", None), + options.get("orderby", None), + options.get("groupby", None), ) - options['_concrete_tables'] = concrete_tables + fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables(fields, tablemap) + options["_concrete_tables"] = concrete_tables return self.db._adapter.iterselect(self.query, fields, options) def update(self, skip_callbacks=False, **update_fields): @@ -568,15 +492,11 @@ def update(self, skip_callbacks=False, **update_fields): if table._has_commit_update_callbacks: txn = self._db._adapter.top_transaction() if txn: - txn._add_op(TransactionOp( - TransactionOps.update, - table, - TransactionOpContext( - values=row, - dbset=self, - ret=ret + txn._add_op( + TransactionOp( + TransactionOps.update, table, TransactionOpContext(values=row, dbset=self, ret=ret) ) - )) + ) ret and [f(self, row) for f in table._after_update] return ret @@ -589,14 +509,7 @@ def delete(self, skip_callbacks=False): if table._has_commit_delete_callbacks: txn = self._db._adapter.top_transaction() if txn: - txn._add_op(TransactionOp( - TransactionOps.delete, - table, - TransactionOpContext( - dbset=self, - ret=ret - ) - )) + txn._add_op(TransactionOp(TransactionOps.delete, table, TransactionOpContext(dbset=self, ret=ret))) ret and [f(self) for f in table._after_delete] return ret @@ -604,19 +517,15 @@ def validate_and_update(self, skip_callbacks=False, **update_fields): table = self._get_table_from_query() current._dbvalidation_record_id_ = None if table._unique_fields_validation_ and self.count() == 1: - if any( - table._unique_fields_validation_.get(fieldname) - for fieldname in update_fields.keys() - ): - current._dbvalidation_record_id_ = \ - self.select(table.id).first().id + if any(table._unique_fields_validation_.get(fieldname) for fieldname in update_fields.keys()): + current._dbvalidation_record_id_ = self.select(table.id).first().id response = Row() response.errors = Row() new_fields = copy.copy(update_fields) for key, value in update_fields.items(): value, error = table[key].validate(value) if error: - response.errors[key] = '%s' % error + response.errors[key] = "%s" % error else: new_fields[key] = value del current._dbvalidation_record_id_ @@ -629,9 +538,7 @@ def validate_and_update(self, skip_callbacks=False, **update_fields): if not skip_callbacks and any(f(self, row) for f in table._before_update): ret = 0 else: - ret = self.db._adapter.update( - table, self.query, row.op_values() - ) + ret = self.db._adapter.update(table, self.query, row.op_values()) if not skip_callbacks and ret: for f in table._after_update: f(self, row) @@ -642,9 +549,7 @@ def _update_from_save(self, model, row, skip_callbacks=False): table: Table = model.table if not skip_callbacks and any(f(row) for f in table._before_save): return False - fields = table._fields_and_values_for_save( - row, model._fieldset_editable, table._fields_and_values_for_update - ) + fields = table._fields_and_values_for_save(row, model._fieldset_editable, table._fields_and_values_for_update) if not skip_callbacks and any(f(self, fields) for f in table._before_update): return False ret = self.db._adapter.update(table, self.query, fields.op_values()) @@ -652,27 +557,21 @@ def _update_from_save(self, model, row, skip_callbacks=False): if table._has_commit_update_callbacks or table._has_commit_save_callbacks: txn = self._db._adapter.top_transaction() if txn and table._has_commit_update_callbacks: - txn._add_op(TransactionOp( - TransactionOps.update, - table, - TransactionOpContext( - values=fields, - dbset=self, - ret=ret + txn._add_op( + TransactionOp( + TransactionOps.update, table, TransactionOpContext(values=fields, dbset=self, ret=ret) ) - )) + ) if txn and table._has_commit_save_callbacks: - txn._add_op(TransactionOp( - TransactionOps.save, - table, - TransactionOpContext( - values=fields, - dbset=self, - ret=ret, - row=row.clone_changed(), - changes=row.changes + txn._add_op( + TransactionOp( + TransactionOps.save, + table, + TransactionOpContext( + values=fields, dbset=self, ret=ret, row=row.clone_changed(), changes=row.changes + ), ) - )) + ) ret and [f(self, fields) for f in table._after_update] ret and [f(row) for f in table._after_save] return bool(ret) @@ -687,31 +586,18 @@ def _delete_from_destroy(self, model, row, skip_callbacks=False): if ret: model._unset_row_persistence(row) if not skip_callbacks: - if ( - table._has_commit_delete_callbacks or - table._has_commit_destroy_callbacks - ): + if table._has_commit_delete_callbacks or table._has_commit_destroy_callbacks: txn = self._db._adapter.top_transaction() if txn and table._has_commit_delete_callbacks: - txn._add_op(TransactionOp( - TransactionOps.delete, - table, - TransactionOpContext( - dbset=self, - ret=ret - ) - )) + txn._add_op(TransactionOp(TransactionOps.delete, table, TransactionOpContext(dbset=self, ret=ret))) if txn and table._has_commit_destroy_callbacks: - txn._add_op(TransactionOp( - TransactionOps.destroy, - table, - TransactionOpContext( - dbset=self, - ret=ret, - row=row.clone_changed(), - changes=row.changes + txn._add_op( + TransactionOp( + TransactionOps.destroy, + table, + TransactionOpContext(dbset=self, ret=ret, row=row.clone_changed(), changes=row.changes), ) - )) + ) ret and [f(self) for f in table._after_delete] ret and [f(row) for f in table._after_destroy] return bool(ret) @@ -745,25 +631,23 @@ def _parse_rjoin(self, arg): if rel: if rel.via: r = RelationBuilder(rel, self._model_._instance_()).via() - return r[0], r[1]._table, 'many' + return r[0], r[1]._table, "many" r = RelationBuilder(rel, self._model_._instance_()) - return r.many(), rel.table, 'many' + return r.many(), rel.table, "many" #: match belongs_to and refers_to rel = self._model_._belongs_fks_.get(arg) if rel: r = RelationBuilder(rel, self._model_._instance_()).belongs_query() - return r, self._model_.db[rel.model], 'belongs' + return r, self._model_.db[rel.model], "belongs" #: match has_one rel = self._model_._hasone_ref_.get(arg) if rel: if rel.via: r = RelationBuilder(rel, self._model_._instance_()).via() - return r[0], r[1]._table, 'one' + return r[0], r[1]._table, "one" r = RelationBuilder(rel, self._model_._instance_()) - return r.many(), rel.table, 'one' - raise RuntimeError( - f'Unable to find {arg} relation of {self._model_.__name__} model' - ) + return r.many(), rel.table, "one" + raise RuntimeError(f"Unable to find {arg} relation of {self._model_.__name__} model") def _parse_left_rjoins(self, args): if not isinstance(args, (list, tuple)): @@ -798,7 +682,7 @@ def __getattr__(self, name): class RelationSet(object): - _relation_method_ = 'many' + _relation_method_ = "many" def __init__(self, db, relation_builder, row): self.db = db @@ -821,8 +705,7 @@ def _model_(self): def _fields_(self): pks = self._relation_.model.primary_keys or ["id"] return [ - (relation_field.name, pks[idx]) - for idx, relation_field in enumerate(self._relation_.ref.fields_instances) + (relation_field.name, pks[idx]) for idx, relation_field in enumerate(self._relation_.ref.fields_instances) ] @cachedprop @@ -845,12 +728,12 @@ def __call__(self, query, ignore_common_filters=False): return self._set.where(query, ignore_common_filters) def _last_resultset(self, refresh=False): - if refresh or not hasattr(self, '_cached_resultset'): + if refresh or not hasattr(self, "_cached_resultset"): self._cached_resultset = self._cache_resultset() return self._cached_resultset def _filter_reload(self, kwargs): - return kwargs.pop('reload', False) + return kwargs.pop("reload", False) def new(self, **kwargs): attrs = self._get_fields_from_scopes(self._scopes_, self._model_.tablename) @@ -879,10 +762,7 @@ def _get_fields_from_scopes(scopes, table_name): components.append(component.second) components.append(component.first) else: - if ( - isinstance(component, Field) and - component._tablename == table_name - ): + if isinstance(component, Field) and component._tablename == table_name: current_kv.append(component) else: if current_kv: @@ -903,7 +783,7 @@ def __call__(self, *args, **kwargs): refresh = self._filter_reload(kwargs) if not args and not kwargs: return self._last_resultset(refresh) - kwargs['limitby'] = (0, 1) + kwargs["limitby"] = (0, 1) return self.select(*args, **kwargs).first() @@ -920,24 +800,18 @@ def __call__(self, *args, **kwargs): def add(self, obj, skip_callbacks=False): if not isinstance(obj, (StructuredRow, RowReferenceMixin)): raise RuntimeError(f"Unsupported parameter {obj}") - attrs = self._get_fields_from_scopes( - self._scopes_, self._model_.tablename - ) + attrs = self._get_fields_from_scopes(self._scopes_, self._model_.tablename) rev_attrs = {} for ref, local in self._fields_: attrs[ref] = self._row_[local] rev_attrs[local] = attrs[ref] - rv = self.db(self._model_._query_row(obj)).validate_and_update( - skip_callbacks=skip_callbacks, **attrs - ) + rv = self.db(self._model_._query_row(obj)).validate_and_update(skip_callbacks=skip_callbacks, **attrs) if rv: for key, val in attrs.items(): obj[key] = val if len(rev_attrs) > 1: comprel_key = (self._relation_.ref.reverse, *rev_attrs.values()) - obj._compound_rels[comprel_key] = typed_row_reference_from_record( - self._row_, self._row_._model - ) + obj._compound_rels[comprel_key] = typed_row_reference_from_record(self._row_, self._row_._model) return rv def remove(self, obj, skip_callbacks=False): @@ -945,15 +819,10 @@ def remove(self, obj, skip_callbacks=False): raise RuntimeError(f"Unsupported parameter {obj}") attrs, is_delete = {ref: None for ref, _ in self._fields_}, False if self._model_._belongs_fks_[self._relation_.ref.reverse].is_refers: - rv = self.db(self._model_._query_row(obj)).validate_and_update( - skip_callbacks=skip_callbacks, - **attrs - ) + rv = self.db(self._model_._query_row(obj)).validate_and_update(skip_callbacks=skip_callbacks, **attrs) else: is_delete = True - rv = self.db(self._model_._query_row(obj)).delete( - skip_callbacks=skip_callbacks - ) + rv = self.db(self._model_._query_row(obj)).delete(skip_callbacks=skip_callbacks) if rv: for key, val in attrs.items(): obj[key] = val @@ -965,19 +834,12 @@ def remove(self, obj, skip_callbacks=False): class ViaSet(RelationSet): - _relation_method_ = 'via' + _relation_method_ = "via" @cachedprop def _viadata(self): query, rfield, model_name, rid, via, viadata = super()._get_query_() - return sdict( - query=query, - rfield=rfield, - model_name=model_name, - rid=rid, - via=via, - data=viadata - ) + return sdict(query=query, rfield=rfield, model_name=model_name, rid=rid, via=via, data=viadata) def _get_query_(self): return self._viadata.query @@ -997,11 +859,11 @@ def __call__(self, *args, **kwargs): if not kwargs: return self._last_resultset(refresh) args = [self._viadata.rfield] - kwargs['limitby'] = (0, 1) + kwargs["limitby"] = (0, 1) return self.select(*args, **kwargs).first() def create(self, **kwargs): - raise RuntimeError('Cannot create third objects for one via relations') + raise RuntimeError("Cannot create third objects for one via relations") class HasManyViaSet(ViaSet): @@ -1037,12 +899,12 @@ def _fields_from_scopes(self): return self._get_fields_from_scopes(scopes, rel.table_name) def create(self, **kwargs): - raise RuntimeError('Cannot create third objects for many via relations') + raise RuntimeError("Cannot create third objects for many via relations") def add(self, obj, skip_callbacks=False, **kwargs): # works on join tables only! if self._viadata.via is None: - raise RuntimeError(self._via_error % 'add') + raise RuntimeError(self._via_error % "add") nrow = self._fields_from_scopes() nrow.update(**kwargs) #: get belongs references @@ -1052,25 +914,22 @@ def add(self, obj, skip_callbacks=False, **kwargs): for local_field, foreign_field in rel_fields: nrow[local_field] = obj[foreign_field] #: validate and insert - return self.db[self._viadata.via]._model_.create( - nrow, skip_callbacks=skip_callbacks - ) + return self.db[self._viadata.via]._model_.create(nrow, skip_callbacks=skip_callbacks) def remove(self, obj, skip_callbacks=False): # works on join tables only! if self._viadata.via is None: - raise RuntimeError(self._via_error % 'remove') + raise RuntimeError(self._via_error % "remove") #: get belongs references self_fields, rel_fields = self._get_relation_fields() #: delete query = reduce( - operator.and_, [ - self.db[self._viadata.via][field] == self._viadata.rid[idx] - for idx, field in enumerate(self_fields) - ] + [ + operator.and_, + [self.db[self._viadata.via][field] == self._viadata.rid[idx] for idx, field in enumerate(self_fields)] + + [ self.db[self._viadata.via][local_field] == obj[foreign_field] for local_field, foreign_field in rel_fields - ] + ], ) return self.db(query).delete(skip_callbacks=skip_callbacks) @@ -1097,46 +956,40 @@ def _clone(self, ignore_common_filters=None, model=None, **changes): def _join_set_builder(self, obj, jdata, auto_select_tables): return JoinedSet._from_set( - obj, jdata=self._jdata_ + jdata, ljdata=self._ljdata_, - auto_select_tables=self._auto_select_tables_ + auto_select_tables + obj, + jdata=self._jdata_ + jdata, + ljdata=self._ljdata_, + auto_select_tables=self._auto_select_tables_ + auto_select_tables, ) def _left_join_set_builder(self, jdata): return JoinedSet._from_set( - self, jdata=self._jdata_, ljdata=self._ljdata_ + jdata, - auto_select_tables=self._auto_select_tables_ + self, jdata=self._jdata_, ljdata=self._ljdata_ + jdata, auto_select_tables=self._auto_select_tables_ ) def _iterselect_rows(self, *fields, **options): tablemap = self.db._adapter.tables( self.query, - options.get('join', None), - options.get('left', None), - options.get('orderby', None), - options.get('groupby', None) - ) - fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables( - fields, tablemap + options.get("join", None), + options.get("left", None), + options.get("orderby", None), + options.get("groupby", None), ) + fields, concrete_tables = self.db._adapter._expand_all_with_concrete_tables(fields, tablemap) colnames, sql = self.db._adapter._select_wcols(self.query, fields, **options) return JoinIterRows(self.db, sql, fields, concrete_tables, colnames) def _split_joins(self, joins): - rv = {'belongs': [], 'one': [], 'many': []} + rv = {"belongs": [], "one": [], "many": []} for jname, jtable, rel_type in joins: rv[rel_type].append((jname, jtable)) - return rv['belongs'], rv['one'], rv['many'] + return rv["belongs"], rv["one"], rv["many"] def _build_records_from_joined(self, rowmap, inclusions, colnames): for rid, many_data in inclusions.items(): for jname, included in many_data.items(): - rowmap[rid][jname]._cached_resultset = Rows( - self.db, list(included.values()), [] - ) - return JoinRows( - self.db, list(rowmap.values()), colnames, - _jdata=self._jdata_ + self._ljdata_ - ) + rowmap[rid][jname]._cached_resultset = Rows(self.db, list(included.values()), []) + return JoinRows(self.db, list(rowmap.values()), colnames, _jdata=self._jdata_ + self._ljdata_) def _select_rowpks_extractor(self, row): if not set(row.keys()).issuperset(self._pks_): @@ -1149,27 +1002,20 @@ def _run_select_(self, *fields, **options): #: build parsers belongs_j, one_j, many_j = self._split_joins(self._jdata_) belongs_l, one_l, many_l = self._split_joins(self._ljdata_) - parsers = ( - self._build_jparsers(belongs_j, one_j, many_j) + - self._build_lparsers(belongs_l, one_l, many_l) - ) + parsers = self._build_jparsers(belongs_j, one_j, many_j) + self._build_lparsers(belongs_l, one_l, many_l) #: auto add selection field for left joins if self._ljdata_: fields = list(fields) if not fields: fields = [v.ALL for v in self._auto_select_tables_] - for join in options['left']: + for join in options["left"]: fields.append(join.first.ALL) #: use iterselect for performance rows = self._iterselect_rows(*fields, **options) #: rebuild rowset using nested objects plainrows = [] rowmap = OrderedDict() - inclusions = defaultdict( - lambda: { - jname: OrderedDict() for jname, _ in (many_j + many_l) - } - ) + inclusions = defaultdict(lambda: {jname: OrderedDict() for jname, _ in (many_j + many_l)}) for row in rows: if self._stable_ not in row: plainrows.append(row) @@ -1183,9 +1029,7 @@ def _run_select_(self, *fields, **options): parser(rowmap, inclusions, row, rid) if not rowmap and plainrows: return Rows(self.db, plainrows, rows.colnames) - return self._build_records_from_joined( - rowmap, inclusions, rows.colnames - ) + return self._build_records_from_joined(rowmap, inclusions, rows.colnames) def _build_jparsers(self, belongs, one, many): rv = [] @@ -1212,15 +1056,15 @@ def _jbelong_parser(db, fieldname, tablename): rmodel = db[tablename]._model_ def parser(rowmap, inclusions, row, rid): - rowmap[rid][fieldname] = typed_row_reference_from_record( - row[tablename], rmodel - ) + rowmap[rid][fieldname] = typed_row_reference_from_record(row[tablename], rmodel) + return parser @staticmethod def _jone_parser(db, fieldname, tablename): def parser(rowmap, inclusions, row, rid): rowmap[rid][fieldname]._cached_resultset = row[tablename] + return parser @staticmethod @@ -1230,10 +1074,10 @@ def _jmany_parser(db, fieldname, tablename): ext = lambda row: tuple(row[pk] for pk in pks) if len(pks) > 1 else row[pks[0]] def parser(rowmap, inclusions, row, rid): - inclusions[rid][fieldname][ext(row[tablename])] = \ - inclusions[rid][fieldname].get( - ext(row[tablename]), row[tablename] - ) + inclusions[rid][fieldname][ext(row[tablename])] = inclusions[rid][fieldname].get( + ext(row[tablename]), row[tablename] + ) + return parser @staticmethod @@ -1245,9 +1089,8 @@ def _lbelong_parser(db, fieldname, tablename): def parser(rowmap, inclusions, row, rid): if not check(row[tablename]): return - rowmap[rid][fieldname] = typed_row_reference_from_record( - row[tablename], rmodel - ) + rowmap[rid][fieldname] = typed_row_reference_from_record(row[tablename], rmodel) + return parser @staticmethod @@ -1260,6 +1103,7 @@ def parser(rowmap, inclusions, row, rid): if not check(row[tablename]): return rowmap[rid][fieldname]._cached_resultset = row[tablename] + return parser @staticmethod @@ -1272,17 +1116,16 @@ def _lmany_parser(db, fieldname, tablename): def parser(rowmap, inclusions, row, rid): if not check(row[tablename]): return - inclusions[rid][fieldname][ext(row[tablename])] = \ - inclusions[rid][fieldname].get( - ext(row[tablename]), row[tablename] - ) + inclusions[rid][fieldname][ext(row[tablename])] = inclusions[rid][fieldname].get( + ext(row[tablename]), row[tablename] + ) + return parser class Row(_Row): _as_dict_types_ = tuple( - [type(None)] + [int, float, bool, list, dict, str] + - [datetime.datetime, datetime.date, datetime.time] + [type(None)] + [int, float, bool, list, dict, str] + [datetime.datetime, datetime.date, datetime.time] ) @classmethod @@ -1312,10 +1155,10 @@ def __json__(self): return self.as_dict() def __xml__(self, key=None, quote=True): - return xml_encode(self.as_dict(), key or 'row', quote) + return xml_encode(self.as_dict(), key or "row", quote) def __str__(self): - return ''.format(self.as_dict(geo_coordinates=False)) + return "".format(self.as_dict(geo_coordinates=False)) def __repr__(self): return str(self) @@ -1353,11 +1196,7 @@ def _from_engine(cls, data: Dict[str, Any]): rv._fields.update(data) return rv - def __init__( - self, - fields: Optional[Dict[str, Any]] = None, - **extras: Any - ): + def __init__(self, fields: Optional[Dict[str, Any]] = None, **extras: Any): object.__setattr__(self, "_changes", {}) object.__setattr__(self, "_compound_rels", {}) object.__setattr__(self, "_concrete", extras.pop("__concrete", False)) @@ -1374,10 +1213,7 @@ def __getitem__(self, name): def __setattr__(self, key, value): if key in self.__slots__: return - oldv = ( - self._changes[key][0] if key in self._changes else - getattr(self, key, None) - ) + oldv = self._changes[key][0] if key in self._changes else getattr(self, key, None) object.__setattr__(self, key, value) newv = getattr(self, key, None) if (oldv is None and value is not None) or oldv != newv: @@ -1392,12 +1228,7 @@ def __getstate__(self): return { "__fields": self._fields, "__extras": self.__dict__, - "__struct": { - "_concrete": self._concrete, - "_changes": {}, - "_compound_rels": {}, - "_virtuals": {} - } + "__struct": {"_concrete": self._concrete, "_changes": {}, "_compound_rels": {}, "_virtuals": {}}, } def __setstate__(self, state): @@ -1415,7 +1246,7 @@ def __eq__(self, other): return self._fields == other._fields and self.__dict__ == other.__dict__ def __copy__(self): - return StructuredRow(self._fields, __concrete=self._concrete, **self.__dict__) + return self.__class__(dict(self._fields), __concrete=self._concrete, **self.__dict__) def keys(self): for pool in (self._fields, self.__dict__): @@ -1463,10 +1294,7 @@ def clone(self): fields[key] = self._fields[key] return self.__class__(fields, __concrete=self._concrete, **self.__dict__) - def clone_changed(self): - return self.__class__( - {**self._fields}, __concrete=self._concrete, **self.__dict__ - ) + clone_changed = __copy__ @property def validation_errors(self): @@ -1478,9 +1306,7 @@ def is_valid(self): class Rows(_Rows): - def __init__( - self, db=None, records=[], colnames=[], compact=True, rawrows=None - ): + def __init__(self, db=None, records=[], colnames=[], compact=True, rawrows=None): self.db = db self.records = records self.colnames = colnames @@ -1491,7 +1317,7 @@ def __init__( def compact(self): if not self.records: return False - return len(self._rowkeys_) == 1 and self._rowkeys_[0] != '_extra' + return len(self._rowkeys_) == 1 and self._rowkeys_[0] != "_extra" @cachedprop def compact_tablename(self): @@ -1519,7 +1345,7 @@ def sorted(self, f, reverse=False): keyf = lambda r: f(r[self.compact_tablename]) else: keyf = f - return [r for r in sorted(self.records, key=keyf, reverse=reverse)] + return sorted(self.records, key=keyf, reverse=reverse) def sort(self, f, reverse=False): self.records = self.sorted(f, reverse) @@ -1538,9 +1364,9 @@ def render(self, *args, **kwargs): def as_list(self, datetime_to_str=False, custom_types=None): return [item.as_dict(datetime_to_str, custom_types) for item in self] - def as_dict(self, key='id', datetime_to_str=False, custom_types=None): - if '.' in key: - splitted_key = key.split('.') + def as_dict(self, key="id", datetime_to_str=False, custom_types=None): + if "." in key: + splitted_key = key.split(".") keyf = lambda row: row[splitted_key[0]][splitted_key[1]] else: keyf = lambda row: row[key] @@ -1550,7 +1376,7 @@ def __json__(self): return [item.__json__() for item in self] def __xml__(self, key=None, quote=True): - key = key or 'rows' + key = key or "rows" return tag[key](*[item.__xml__(quote=quote) for item in self]) def __str__(self): @@ -1580,17 +1406,11 @@ def __next__(self): if db_row is None: raise StopIteration row = self.db._adapter._parse( - db_row, - self.fdata, - self.tables, - self.concrete_tables, - self.fields, - self.colnames, - self.blob_decode + db_row, self.fdata, self.tables, self.concrete_tables, self.fields, self.colnames, self.blob_decode ) if self.compact: keys = list(row.keys()) - if len(keys) == 1 and keys[0] != '_extra': + if len(keys) == 1 and keys[0] != "_extra": row = row[keys[0]] return row @@ -1609,13 +1429,10 @@ def __iter__(self): class JoinRows(Rows): def __init__(self, *args, **kwargs): - self._joins_ = kwargs.pop('_jdata') + self._joins_ = kwargs.pop("_jdata") super(JoinRows, self).__init__(*args, **kwargs) - def as_list( - self, compact=True, storage_to_dict=True, datetime_to_str=False, - custom_types=None - ): + def as_list(self, compact=True, storage_to_dict=True, datetime_to_str=False, custom_types=None): if storage_to_dict: items = [] for row in self: @@ -1625,7 +1442,7 @@ def as_list( item[jdata[0]] = row[jdata[0]].as_list() items.append(item) else: - items = [item for item in self] + items = list(self) return items @@ -1654,7 +1471,7 @@ def __init__( dbset: Optional[Set] = None, ret: Any = None, row: Optional[Row] = None, - changes: Optional[sdict] = None + changes: Optional[sdict] = None, ): self.values = values self.dbset = dbset @@ -1666,12 +1483,7 @@ def __init__( class TransactionOp: __slots__ = ["op_type", "table", "context"] - def __init__( - self, - op_type: TransactionOps, - table: Table, - context: TransactionOpContext - ): + def __init__(self, op_type: TransactionOps, table: Table, context: TransactionOpContext): self.op_type = op_type self.table = table self.context = context diff --git a/emmett/orm/transactions.py b/emmett/orm/transactions.py index d7e7160d..7dfed161 100644 --- a/emmett/orm/transactions.py +++ b/emmett/orm/transactions.py @@ -1,20 +1,19 @@ # -*- coding: utf-8 -*- """ - emmett.orm.transactions - ----------------------- +emmett.orm.transactions +----------------------- - Provides pyDAL advanced transactions implementation for Emmett. +Provides pyDAL advanced transactions implementation for Emmett. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Parts of this code are inspired to peewee - :copyright: (c) 2010 by Charles Leifer +Parts of this code are inspired to peewee +:copyright: (c) 2010 by Charles Leifer - :license: BSD-3-Clause +:license: BSD-3-Clause """ import uuid - from functools import wraps @@ -24,6 +23,7 @@ def __call__(self, fn): def inner(*args, **kwargs): with self: return fn(*args, **kwargs) + return inner @@ -105,7 +105,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): class _savepoint(callable_context_manager): def __init__(self, adapter, sid=None): self.adapter = adapter - self.sid = sid or 's' + uuid.uuid4().hex + self.sid = sid or "s" + uuid.uuid4().hex self.quoted_sid = self.adapter.dialect.quote(self.sid) self._ops = [] self._parent = None @@ -117,16 +117,16 @@ def _add_ops(self, ops): self._ops.extend(ops) def _begin(self): - self.adapter.execute('SAVEPOINT %s;' % self.quoted_sid) + self.adapter.execute("SAVEPOINT %s;" % self.quoted_sid) def commit(self, begin=True): - self.adapter.execute('RELEASE SAVEPOINT %s;' % self.quoted_sid) + self.adapter.execute("RELEASE SAVEPOINT %s;" % self.quoted_sid) if begin: self._begin() def rollback(self): self._ops.clear() - self.adapter.execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid) + self.adapter.execute("ROLLBACK TO SAVEPOINT %s;" % self.quoted_sid) def __enter__(self): self._parent = self.adapter.top_transaction() diff --git a/emmett/orm/wrappers.py b/emmett/orm/wrappers.py index a588c2fa..658fe1cf 100644 --- a/emmett/orm/wrappers.py +++ b/emmett/orm/wrappers.py @@ -1,16 +1,16 @@ # -*- coding: utf-8 -*- """ - emmett.orm.wrappers - ------------------- +emmett.orm.wrappers +------------------- - Provides ORM wrappers utilities. +Provides ORM wrappers utilities. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from .helpers import RelationBuilder -from .objects import HasOneSet, HasOneViaSet, HasManySet, HasManyViaSet +from .objects import HasManySet, HasManyViaSet, HasOneSet, HasOneViaSet class Wrapper(object): diff --git a/emmett/parsers.py b/emmett/parsers.py index 829798e9..6f0ec3bc 100644 --- a/emmett/parsers.py +++ b/emmett/parsers.py @@ -1,46 +1,10 @@ # -*- coding: utf-8 -*- """ - emmett.parsers - -------------- +emmett.parsers +-------------- - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from functools import partial -from typing import Any, Callable, Dict - -try: - import orjson - _json_impl = orjson.loads - _json_opts = {} -except ImportError: - import rapidjson - _json_impl = rapidjson.loads - _json_opts = { - "datetime_mode": rapidjson.DM_ISO8601 | rapidjson.DM_NAIVE_IS_UTC, - "number_mode": rapidjson.NM_NATIVE - } - - -class Parsers(object): - _registry_: Dict[str, Callable[[str], Dict[str, Any]]] = {} - - @classmethod - def register_for(cls, target): - def wrap(f): - cls._registry_[target] = f - return f - return wrap - - @classmethod - def get_for(cls, target): - return cls._registry_[target] - - -json = partial( - _json_impl, - **_json_opts -) - -Parsers.register_for('json')(json) +from emmett_core.parsers import Parsers as Parsers diff --git a/emmett/pipeline.py b/emmett/pipeline.py index 70ae4608..16613c97 100644 --- a/emmett/pipeline.py +++ b/emmett/pipeline.py @@ -1,265 +1,31 @@ # -*- coding: utf-8 -*- """ - emmett.pipeline - --------------- +emmett.pipeline +--------------- - Provides the pipeline classes. +Provides the pipeline classes. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -import asyncio import types -from functools import wraps -from typing import Optional +from emmett_core.http.helpers import redirect +from emmett_core.pipeline.extras import RequirePipe as _RequirePipe +from emmett_core.pipeline.pipe import Pipe as Pipe +from .ctx import current from .helpers import flash -from .http import HTTPResponse, redirect -class Pipeline: - __slots__ = ['_method_open', '_method_close', 'pipes'] - _type_suffix = '' +class RequirePipe(_RequirePipe): + __slots__ = ["flash"] + _current = current - def __init__(self, pipes=[]): - self._method_open = f'open_{self._type_suffix}' - self._method_close = f'close_{self._type_suffix}' - self.pipes = pipes - - @staticmethod - def _awaitable_wrap(f): - @wraps(f) - async def awaitable(*args, **kwargs): - return f(*args, **kwargs) - return awaitable - - def __call__(self, f): - raise NotImplementedError - - def _flow_open(self): - rv = [] - for pipe in self.pipes: - if pipe._pipeline_all_methods_.issuperset( - {'open', self._method_open} - ): - raise RuntimeError( - f'{pipe.__class__.__name__} pipe has double open methods.' - f' Use `open` or `{self._method_open}`, not both.' - ) - if 'open' in pipe._pipeline_all_methods_: - rv.append(pipe.open) - if self._method_open in pipe._pipeline_all_methods_: - rv.append(getattr(pipe, self._method_open)) - return rv - - def _flow_close(self): - rv = [] - for pipe in reversed(self.pipes): - if pipe._pipeline_all_methods_.issuperset( - {'close', self._method_close} - ): - raise RuntimeError( - f'{pipe.__class__.__name__} pipe has double close methods.' - f' Use `close` or `{self._method_close}`, not both.' - ) - if 'close' in pipe._pipeline_all_methods_: - rv.append(pipe.close) - if self._method_close in pipe._pipeline_all_methods_: - rv.append(getattr(pipe, self._method_close)) - return rv - - -class RequestPipeline(Pipeline): - __slots__ = [] - _type_suffix = 'request' - - def _get_proper_wrapper(self, pipe): - if pipe._pipeline_all_methods_.issuperset( - {'on_pipe_success', 'on_pipe_failure'} - ): - rv = _wrap_flow_request_complete - elif 'on_pipe_success' in pipe._pipeline_all_methods_: - rv = _wrap_flow_request_success - elif 'on_pipe_failure' in pipe._pipeline_all_methods_: - rv = _wrap_flow_request_failure - else: - rv = _wrap_flow_request_basic - return rv - - def __call__(self, f): - if not asyncio.iscoroutinefunction(f): - f = self._awaitable_wrap(f) - for pipe in reversed(self.pipes): - if not isinstance(pipe, Pipe): - continue - if not pipe._is_flow_request_responsible: - continue - wrapper = self._get_proper_wrapper(pipe) - pipe_method = ( - pipe.pipe_request - if 'pipe_request' in pipe._pipeline_all_methods_ - else pipe.pipe) - f = wrapper( - pipe_method, pipe.on_pipe_success, pipe.on_pipe_failure, f) - return f - - def _output_type(self): - rv = None - for pipe in reversed(self.pipes): - if not pipe._is_flow_request_responsible or pipe.output is None: - continue - rv = pipe.output - return rv - - -class WebsocketPipeline(Pipeline): - __slots__ = [] - _type_suffix = 'ws' - - def _get_proper_wrapper(self, pipe): - if pipe._pipeline_all_methods_.issuperset( - {'on_pipe_success', 'on_pipe_failure'} - ): - rv = _wrap_flow_ws_complete - elif 'on_pipe_success' in pipe._pipeline_all_methods_: - rv = _wrap_flow_ws_success - elif 'on_pipe_failure' in pipe._pipeline_all_methods_: - rv = _wrap_flow_ws_failure - else: - rv = _wrap_flow_ws_basic - return rv - - def __call__(self, f): - if not asyncio.iscoroutinefunction(f): - f = self._awaitable_wrap(f) - for pipe in reversed(self.pipes): - if not isinstance(pipe, Pipe): - continue - if not pipe._is_flow_ws_responsible: - continue - wrapper = self._get_proper_wrapper(pipe) - pipe_method = ( - pipe.pipe_ws - if 'pipe_ws' in pipe._pipeline_all_methods_ - else pipe.pipe) - f = wrapper( - pipe_method, pipe.on_pipe_success, pipe.on_pipe_failure, f) - return f - - def _flow_receive(self): - rv = [] - for pipe in self.pipes: - if 'on_receive' not in pipe._pipeline_all_methods_: - continue - rv.append(pipe.on_receive) - return rv - - def _flow_send(self): - rv = [] - for pipe in reversed(self.pipes): - if 'on_send' not in pipe._pipeline_all_methods_: - continue - rv.append(pipe.on_send) - return rv - - -class MetaPipe(type): - _pipeline_methods_ = { - 'open', 'open_request', 'open_ws', - 'close', 'close_request', 'close_ws', - 'pipe', 'pipe_request', 'pipe_ws', - 'on_pipe_success', 'on_pipe_failure', - 'on_receive', 'on_send' - } - - def __new__(cls, name, bases, attrs): - new_class = type.__new__(cls, name, bases, attrs) - if not bases: - return new_class - declared_methods = cls._pipeline_methods_ & set(attrs.keys()) - new_class._pipeline_declared_methods_ = declared_methods - all_methods = set() - for base in reversed(new_class.__mro__[:-2]): - if hasattr(base, '_pipeline_declared_methods_'): - all_methods = all_methods | base._pipeline_declared_methods_ - all_methods = all_methods | declared_methods - new_class._pipeline_all_methods_ = all_methods - new_class._is_flow_request_responsible = bool( - all_methods & { - 'pipe', 'pipe_request', 'on_pipe_success', 'on_pipe_failure' - } - ) - new_class._is_flow_ws_responsible = bool( - all_methods & { - 'pipe', 'pipe_ws', 'on_pipe_success', 'on_pipe_failure' - } - ) - if all_methods.issuperset({'pipe', 'pipe_request'}): - raise RuntimeError( - f'{name} has double pipe methods. ' - 'Use `pipe` or `pipe_request`, not both.' - ) - if all_methods.issuperset({'pipe', 'pipe_ws'}): - raise RuntimeError( - f'{name} has double pipe methods. ' - 'Use `pipe` or `pipe_ws`, not both.' - ) - return new_class - - -class Pipe(metaclass=MetaPipe): - output: Optional[str] = None - - async def open(self): - pass - - async def open_request(self): - pass - - async def open_ws(self): - pass - - async def close(self): - pass - - async def close_request(self): - pass - - async def close_ws(self): - pass - - async def pipe(self, next_pipe, **kwargs): - return await next_pipe(**kwargs) - - async def pipe_request(self, next_pipe, **kwargs): - return await next_pipe(**kwargs) - - async def pipe_ws(self, next_pipe, **kwargs): - return await next_pipe(**kwargs) - - async def on_pipe_success(self): - pass - - async def on_pipe_failure(self): - pass - - def on_receive(self, data): - return data - - def on_send(self, data): - return data - - -class RequirePipe(Pipe): - def __init__(self, condition=None, otherwise=None): - if condition is None or otherwise is None: - raise SyntaxError('usage: @requires(condition, otherwise)') - if not callable(otherwise) and not isinstance(otherwise, str): - raise SyntaxError("'otherwise' param must be string or callable") - self.condition = condition - self.otherwise = otherwise + def __init__(self, condition=None, otherwise=None, flash=True): + super().__init__(condition=condition, otherwise=otherwise) + self.flash = flash async def pipe_request(self, next_pipe, **kwargs): flag = self.condition() @@ -267,34 +33,25 @@ async def pipe_request(self, next_pipe, **kwargs): if self.otherwise is not None: if callable(self.otherwise): return self.otherwise() - redirect(self.otherwise) + redirect(self.__class__._current, self.otherwise) else: - flash('Insufficient privileges') - redirect('/') + if self.flash: + flash("Insufficient privileges") + redirect(self.__class__._current, "/") return await next_pipe(**kwargs) - async def pipe_ws(self, next_pipe, **kwargs): - flag = self.condition() - if not flag: - return - await next_pipe(**kwargs) - class Injector(Pipe): - namespace: str = '__global__' + namespace: str = "__global__" def __init__(self): self._injections_ = {} - if self.namespace != '__global__': + if self.namespace != "__global__": self._inject = self._inject_local return self._inject = self._inject_global - for attr_name in ( - set(dir(self)) - - self.__class__._pipeline_methods_ - - {'output', 'namespace'} - ): - if attr_name.startswith('_'): + for attr_name in set(dir(self)) - self.__class__._pipeline_methods_ - {"output", "namespace"}: + if attr_name.startswith("_"): continue attr = getattr(self, attr_name) if isinstance(attr, types.MethodType): @@ -306,6 +63,7 @@ def __init__(self): def _wrapped_method(method): def wrap(*args, **kwargs): return method(*args, **kwargs) + return wrap def _inject_local(self, ctx): @@ -319,90 +77,3 @@ async def pipe_request(self, next_pipe, **kwargs): if isinstance(ctx, dict): self._inject(ctx) return ctx - - -def _wrap_flow_request_complete(pipe_method, on_success, on_failure, f): - @wraps(f) - async def flow(**kwargs): - try: - output = await pipe_method(f, **kwargs) - await on_success() - return output - except HTTPResponse: - await on_success() - raise - except Exception: - await on_failure() - raise - return flow - - -def _wrap_flow_request_success(pipe_method, on_success, on_failure, f): - @wraps(f) - async def flow(**kwargs): - try: - output = await pipe_method(f, **kwargs) - await on_success() - return output - except HTTPResponse: - await on_success() - raise - return flow - - -def _wrap_flow_request_failure(pipe_method, on_success, on_failure, f): - @wraps(f) - async def flow(**kwargs): - try: - return await pipe_method(f, **kwargs) - except HTTPResponse: - raise - except Exception: - await on_failure() - raise - return flow - - -def _wrap_flow_request_basic(pipe_method, on_success, on_failure, f): - @wraps(f) - async def flow(**kwargs): - return await pipe_method(f, **kwargs) - return flow - - -def _wrap_flow_ws_complete(pipe_method, on_success, on_failure, f): - @wraps(f) - async def flow(**kwargs): - try: - await pipe_method(f, **kwargs) - await on_success() - except Exception: - await on_failure() - raise - return flow - - -def _wrap_flow_ws_success(pipe_method, on_success, on_failure, f): - @wraps(f) - async def flow(**kwargs): - await pipe_method(f, **kwargs) - await on_success() - return flow - - -def _wrap_flow_ws_failure(pipe_method, on_success, on_failure, f): - @wraps(f) - async def flow(**kwargs): - try: - await pipe_method(f, **kwargs) - except Exception: - await on_failure() - raise - return flow - - -def _wrap_flow_ws_basic(pipe_method, on_success, on_failure, f): - @wraps(f) - async def flow(**kwargs): - return await pipe_method(f, **kwargs) - return flow diff --git a/emmett/routing/dispatchers.py b/emmett/routing/dispatchers.py deleted file mode 100644 index 36e47b90..00000000 --- a/emmett/routing/dispatchers.py +++ /dev/null @@ -1,181 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.routing.dispatchers - -------------------------- - - Provides pipeline dispatchers for routes. - - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause -""" - -import asyncio - -from ..ctx import current - - -class Dispatcher: - __slots__ = ['f', 'flow_open', 'flow_close'] - - def __init__(self, route): - self.f = route.f - self.flow_open = route.pipeline_flow_open - self.flow_close = route.pipeline_flow_close - - async def _parallel_flow(self, flow): - tasks = [asyncio.create_task(method()) for method in flow] - await asyncio.gather(*tasks, return_exceptions=True) - for task in tasks: - if task.exception(): - raise task.exception() - - def dispatch(self, reqargs): - return self.f(**reqargs) - - -class RequestDispatcher(Dispatcher): - __slots__ = ['response_builder'] - - def __init__(self, route, rule, response_builder): - super().__init__(route) - self.response_builder = response_builder - - async def dispatch(self, reqargs, response): - return self.response_builder(await self.f(**reqargs), response) - - -class RequestOpenDispatcher(RequestDispatcher): - __slots__ = [] - - async def dispatch(self, reqargs, response): - await self._parallel_flow(self.flow_open) - return self.response_builder(await self.f(**reqargs), response) - - -class RequestCloseDispatcher(RequestDispatcher): - __slots__ = [] - - async def dispatch(self, reqargs, response): - try: - rv = self.response_builder(await self.f(**reqargs), response) - finally: - await self._parallel_flow(self.flow_close) - return rv - - -class RequestFlowDispatcher(RequestDispatcher): - __slots__ = [] - - async def dispatch(self, reqargs, response): - await self._parallel_flow(self.flow_open) - try: - rv = self.response_builder(await self.f(**reqargs), response) - finally: - await self._parallel_flow(self.flow_close) - return rv - - -class WSOpenDispatcher(Dispatcher): - __slots__ = [] - - async def dispatch(self, reqargs): - await self._parallel_flow(self.flow_open) - await self.f(**reqargs) - - -class WSCloseDispatcher(Dispatcher): - __slots__ = [] - - async def dispatch(self, reqargs): - try: - await self.f(**reqargs) - except asyncio.CancelledError: - await asyncio.shield(self._parallel_flow(self.flow_close)) - return - except Exception: - await self._parallel_flow(self.flow_close) - raise - await asyncio.shield(self._parallel_flow(self.flow_close)) - - -class WSFlowDispatcher(Dispatcher): - __slots__ = [] - - async def dispatch(self, reqargs): - await self._parallel_flow(self.flow_open) - try: - await self.f(**reqargs) - except asyncio.CancelledError: - await asyncio.shield(self._parallel_flow(self.flow_close)) - return - except Exception: - await self._parallel_flow(self.flow_close) - raise - await asyncio.shield(self._parallel_flow(self.flow_close)) - - -class CacheDispatcher(RequestDispatcher): - __slots__ = ['route', 'cache_rule'] - - def __init__(self, route, rule, response_builder): - super().__init__(route, rule, response_builder) - self.route = route - self.cache_rule = rule.cache_rule - - async def get_data(self, reqargs, response): - key = self.cache_rule._build_ctx_key( - self.route, **self.cache_rule._build_ctx( - reqargs, self.route, current - ) - ) - data = self.cache_rule.cache.get(key) - if data is not None: - response.headers.update(data['headers']) - return data['content'] - content = await self.f(**reqargs) - if response.status == 200: - self.cache_rule.cache.set( - key, - {'content': content, 'headers': response.headers}, - self.cache_rule.duration - ) - return content - - async def dispatch(self, reqargs, response): - content = await self.get_data(reqargs, response) - return self.response_builder(content, response) - - -class CacheOpenDispatcher(CacheDispatcher): - __slots__ = [] - - async def dispatch(self, reqargs, response): - await self._parallel_flow(self.flow_open) - return await super().dispatch(reqargs, response) - - -class CacheCloseDispatcher(CacheDispatcher): - __slots__ = [] - - async def dispatch(self, reqargs, response): - try: - content = await self.get_data(reqargs, response) - except Exception: - await self._parallel_flow(self.flow_close) - raise - await self._parallel_flow(self.flow_close) - return self.response_builder(content, response) - - -class CacheFlowDispatcher(CacheDispatcher): - __slots__ = [] - - async def dispatch(self, reqargs, response): - await self._parallel_flow(self.flow_open) - try: - content = await self.get_data(reqargs, response) - except Exception: - await self._parallel_flow(self.flow_close) - raise - await self._parallel_flow(self.flow_close) - return self.response_builder(content, response) diff --git a/emmett/routing/response.py b/emmett/routing/response.py index aa7ae6fc..2b6d1407 100644 --- a/emmett/routing/response.py +++ b/emmett/routing/response.py @@ -1,146 +1,72 @@ # -*- coding: utf-8 -*- """ - emmett.routing.response - ----------------------- +emmett.routing.response +----------------------- - Provides response builders for http routes. +Provides response builders for http routes. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -from typing import Any, Dict, Union +from typing import Any, Dict, Tuple, Union +from emmett_core.http.response import HTTPResponse, HTTPStringResponse +from emmett_core.routing.response import ResponseProcessor from renoir.errors import TemplateMissingError from ..ctx import current from ..helpers import load_component from ..html import asis -from ..http import HTTPResponse, HTTP, HTTPBytes -from ..wrappers.response import Response -from .rules import HTTPRoutingRule from .urls import url -_html_content_type = 'text/html; charset=utf-8' - -class MetaResponseBuilder: - def __init__(self, route: HTTPRoutingRule): - self.route = route - - def __call__(self, output: Any, response: Response) -> HTTPResponse: - raise NotImplementedError - - -class ResponseBuilder(MetaResponseBuilder): - http_cls = HTTP - - def __call__(self, output: Any, response: Response) -> HTTP: - return self.http_cls( - response.status, - output, - headers=response.headers, - cookies=response.cookies - ) - - -class EmptyResponseBuilder(ResponseBuilder): - http_cls = HTTPResponse - - def __call__(self, output: Any, response: Response) -> HTTPResponse: - return self.http_cls( - response.status, - headers=response.headers, - cookies=response.cookies - ) - - -class ResponseProcessor(ResponseBuilder): - def process(self, output: Any, response: Response): - raise NotImplementedError - - def __call__(self, output: Any, response: Response) -> HTTP: - return self.http_cls( - response.status, - self.process(output, response), - headers=response.headers, - cookies=response.cookies - ) - - -class BytesResponseBuilder(MetaResponseBuilder): - http_cls = HTTPBytes - - def __call__(self, output: Any, response: Response) -> HTTPBytes: - return self.http_cls( - response.status, - output, - headers=response.headers, - cookies=response.cookies - ) +_html_content_type = "text/html; charset=utf-8" class TemplateResponseBuilder(ResponseProcessor): - def process( - self, - output: Union[Dict[str, Any], None], - response: Response - ) -> str: - response.headers._data['content-type'] = _html_content_type - base_ctx = { - 'current': current, - 'url': url, - 'asis': asis, - 'load_component': load_component - } + def process(self, output: Union[Dict[str, Any], None], response) -> str: + response.headers._data["content-type"] = _html_content_type + base_ctx = {"current": current, "url": url, "asis": asis, "load_component": load_component} output = base_ctx if output is None else {**base_ctx, **output} try: - return self.route.app.templater.render( - self.route.template, output - ) + return self.route.app.templater.render(self.route.template, output) except TemplateMissingError as exc: - raise HTTP( - 404, - body="{}\n".format(exc.message), - cookies=response.cookies - ) + raise HTTPStringResponse(404, body="{}\n".format(exc.message), cookies=response.cookies) + + +class SnippetResponseBuilder(ResponseProcessor): + def process(self, output: Tuple[str, Union[Dict[str, Any], None]], response) -> str: + response.headers._data["content-type"] = _html_content_type + template, output = output + base_ctx = {"current": current, "url": url, "asis": asis, "load_component": load_component} + output = base_ctx if output is None else {**base_ctx, **output} + return self.route.app.templater._render(template, f"_snippet.{current.request.name}", output) class AutoResponseBuilder(ResponseProcessor): - def process(self, output: Any, response: Response) -> str: - is_template = False + def process(self, output: Any, response) -> str: + is_template, snippet = False, None + if isinstance(output, tuple): + snippet, output = output if isinstance(output, dict): is_template = True - output = { - **{ - 'current': current, - 'url': url, - 'asis': asis, - 'load_component': load_component - }, - **output - } + output = {**{"current": current, "url": url, "asis": asis, "load_component": load_component}, **output} elif output is None: - output = { - 'current': current, - 'url': url, - 'asis': asis, - 'load_component': load_component - } + is_template = True + output = {"current": current, "url": url, "asis": asis, "load_component": load_component} if is_template: - response.headers._data['content-type'] = _html_content_type + response.headers._data["content-type"] = _html_content_type + if snippet is not None: + return self.route.app.templater._render(snippet, f"_snippet.{current.request.name}", output) try: - return self.route.app.templater.render( - self.route.template, output - ) + return self.route.app.templater.render(self.route.template, output) except TemplateMissingError as exc: - raise HTTP( - 404, - body="{}\n".format(exc.message), - cookies=response.cookies - ) + raise HTTPStringResponse(404, body="{}\n".format(exc.message), cookies=response.cookies) elif isinstance(output, str): return output + elif isinstance(output, HTTPResponse): + return output return str(output) diff --git a/emmett/routing/router.py b/emmett/routing/router.py index 0b1b1fdd..5fa1d913 100644 --- a/emmett/routing/router.py +++ b/emmett/routing/router.py @@ -1,371 +1,40 @@ # -*- coding: utf-8 -*- """ - emmett.routing.router - --------------------- +emmett.routing.router +--------------------- - Provides router implementations. +Provides router implementations. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -import re - -from collections import namedtuple -from typing import Any, Callable, Dict, List, Type - -from ..ctx import current -from ..extensions import Signals -from ..http import HTTP -from .response import ( - MetaResponseBuilder, - EmptyResponseBuilder, - ResponseBuilder, - AutoResponseBuilder, - BytesResponseBuilder, - TemplateResponseBuilder +from emmett_core.routing.router import ( + HTTPRouter as _HTTPRouter, + RoutingCtx as RoutingCtx, + RoutingCtxGroup as RoutingCtxGroup, + WebsocketRouter as WebsocketRouter, ) -from .rules import RoutingRule, HTTPRoutingRule, WebsocketRoutingRule - - -RouteRecReq = namedtuple( - "RouteRecReq", ["name", "match", "dispatch"] -) -RouteRecWS = namedtuple( - "RouteRecWS", ["name", "match", "dispatch", "flow_recv", "flow_send"] -) - - -class Router: - __slots__ = [ - '_get_routes_in_for_host', - '_match_lang', - '_prefix_main_len', - '_prefix_main', - '_routes_nohost', - '_routes_str', - 'app', - 'routes_in', - 'routes_out', - 'routes' - ] - - _outputs: Dict[str, Type[MetaResponseBuilder]] = {} - _routing_rule_cls: Type[RoutingRule] = RoutingRule - _routing_signal = Signals.before_routes - _routing_started = False - _routing_stack: List[RoutingRule] = [] - _re_components = re.compile(r'(\()?([^<\w]+)?<(\w+)\:(\w+)>(\)\?)?') - - def __init__(self, app, url_prefix=None): - self.app = app - self.routes = [] - self.routes_in = {'__any__': self._build_routing_dict()} - self.routes_out = {} - self._routes_str = {} - self._routes_nohost = (self.routes_in['__any__'], ) - self._get_routes_in_for_host = self._get_routes_in_for_host_nomatch - main_prefix = url_prefix or '' - if main_prefix: - main_prefix = main_prefix.rstrip('/') - if not main_prefix.startswith('/'): - main_prefix = '/' + main_prefix - if main_prefix == '/': - main_prefix = '' - self._prefix_main = main_prefix - self._prefix_main_len = len(self._prefix_main) - self._set_language_handling() - - def _set_language_handling(self): - self._match_lang = ( - self._match_with_lang if self.app.language_force_on_url - else self._match_no_lang) - - @property - def static_versioning(self): - return ( - self.app.config.static_version_urls and - self.app.config.static_version - ) or '' - - @staticmethod - def _build_routing_dict(): - return {'static': {}, 'match': {}} - @classmethod - def build_route_components(cls, path): - components = [] - params = [] - for match in cls._re_components.findall(path): - params.append(match[1] + "{}") - statics = cls._re_components.sub("{}", path).split("{}") - if not params: - components = statics - else: - components.append(statics[0]) - for idx, _ in enumerate(params): - components.append(params[idx] + statics[idx + 1]) - return components +from .response import AutoResponseBuilder, SnippetResponseBuilder, TemplateResponseBuilder +from .rules import HTTPRoutingRule - def _get_routes_in_for_host_match(self, wrapper): - return ( - self.routes_in.get(wrapper.host, self.routes_in['__any__']), - self.routes_in['__any__']) - def _get_routes_in_for_host_nomatch(self, wrapper): - return self._routes_nohost - - def _match_with_lang(self, wrapper, path): - path, lang = self._split_lang(path) - current.language = wrapper.language = lang - return path - - def _match_no_lang(self, wrapper, path): - wrapper.language = None - return path - - @staticmethod - def remove_trailslash(path): - return path.rstrip('/') or path - - def _split_lang(self, path): - default = self.app.language_default - if len(path) <= 1: - return path, default - clean_path = path.lstrip('/') - if clean_path[2:3] == '/': - lang, new_path = clean_path[:2], clean_path[2:] - if lang != default and lang in self.app._languages_set: - return new_path, lang - return path, default - - def add_route(self, route): - raise NotImplementedError - - def match(self, wrapper): - raise NotImplementedError - - async def dispatch(self): - raise NotImplementedError - - def __call__(self, *args, **kwargs): - if not self.__class__._routing_started: - self.__class__._routing_started = True - self.app.send_signal(self._routing_signal) - return RoutingCtx(self, self._routing_rule_cls, *args, **kwargs) - - @classmethod - def exposing(cls): - return cls._routing_stack[-1] - - -class HTTPRouter(Router): - __slots__ = ['pipeline', 'injectors'] +class HTTPRouter(_HTTPRouter): + __slots__ = ["injectors"] _routing_rule_cls = HTTPRoutingRule - _routing_rec_builder = RouteRecReq - _outputs = { - 'empty': EmptyResponseBuilder, - 'auto': AutoResponseBuilder, - 'bytes': BytesResponseBuilder, - 'str': ResponseBuilder, - 'template': TemplateResponseBuilder + **_HTTPRouter._outputs, + **{ + "auto": AutoResponseBuilder, + "template": TemplateResponseBuilder, + "snippet": SnippetResponseBuilder, + }, } def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.pipeline = [] self.injectors = [] - - @staticmethod - def _build_routing_dict(): - rv = {} - for scheme in ['http', 'https']: - rv[scheme] = {} - for method in [ - 'DELETE', 'GET', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT' - ]: - rv[scheme][method] = {'static': {}, 'match': {}} - return rv - - def add_route_str(self, route): - self._routes_str[route.name] = "%s %s://%s%s%s -> %s" % ( - "|".join(route.methods), - "|".join(route.schemes), - route.hostname or "", - self._prefix_main, - route.path, - route.name - ) - - def add_route(self, route): - self.routes.append(route) - host = route.hostname or '__any__' - if host not in self.routes_in: - self.routes_in[host] = self._build_routing_dict() - self._get_routes_in_for_host = self._get_routes_in_for_host_match - for scheme in route.schemes: - for method in route.methods: - routing_dict = self.routes_in[host][scheme][method] - slot, key = ( - ('static', route.path) if route.is_static else - ('match', route.name) - ) - routing_dict[slot][key] = self._routing_rec_builder( - name=route.name, - match=route.match, - dispatch=route.dispatchers[method].dispatch - ) - self.routes_out[route.name] = { - 'host': route.hostname, - 'path': self.build_route_components(route.path) - } - self.add_route_str(route) - - def match(self, request): - path = self._match_lang( - request, - self.remove_trailslash(request.path) - ) - for routing_dict in self._get_routes_in_for_host(request): - sub_dict = routing_dict[request.scheme][request.method] - element = sub_dict['static'].get(path) - if element: - return element, {} - for element in sub_dict['match'].values(): - match, args = element.match(path) - if match: - return element, args - return None, {} - - async def dispatch(self, request, response): - match, reqargs = self.match(request) - if not match: - raise HTTP(404, body="Resource not found\n") - request.name = match.name - return await match.dispatch(reqargs, response) - - -class WebsocketRouter(Router): - __slots__ = ['pipeline'] - - _routing_rule_cls = WebsocketRoutingRule - _routing_rec_builder = RouteRecWS - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.pipeline = [] - - @staticmethod - def _build_routing_dict(): - rv = {} - for scheme in ['http', 'https', 'ws', 'wss']: - rv[scheme] = {'static': {}, 'match': {}} - return rv - - @staticmethod - def _all_schemes_from_route(schemes): - auto = {'ws': ['ws', 'http'], 'wss': ['wss', 'https']} - rv = [] - for scheme in schemes: - rv.extend(auto[scheme]) - return rv - - def add_route_str(self, route): - self._routes_str[route.name] = "%s://%s%s%s -> %s" % ( - "|".join(route.schemes), - route.hostname or "", - self._prefix_main, - route.path, - route.name - ) - - def add_route(self, route): - self.routes.append(route) - host = route.hostname or '__any__' - if host not in self.routes_in: - self.routes_in[host] = self._build_routing_dict() - self._get_routes_in_for_host = self._get_routes_in_for_host_match - for scheme in self._all_schemes_from_route(route.schemes): - routing_dict = self.routes_in[host][scheme] - slot, key = ( - ('static', route.path) if route.is_static else - ('match', route.name) - ) - routing_dict[slot][key] = self._routing_rec_builder( - name=route.name, - match=route.match, - dispatch=route.dispatcher.dispatch, - flow_recv=route.pipeline_flow_receive, - flow_send=route.pipeline_flow_send - ) - self.routes_out[route.name] = { - 'host': route.hostname, - 'path': self.build_route_components(route.path) - } - self.add_route_str(route) - - def match(self, websocket): - path = self._match_lang( - websocket, - self.remove_trailslash(websocket.path) - ) - for routing_dict in self._get_routes_in_for_host(websocket): - sub_dict = routing_dict[websocket.scheme] - element = sub_dict['static'].get(path) - if element: - return element, {} - for element in sub_dict['match'].values(): - match, args = element.match(path) - if match: - return element, args - return None, {} - - async def dispatch(self, websocket): - match, reqargs = self.match(websocket) - if not match: - raise HTTP(404, body="Resource not found\n") - websocket.name = match.name - websocket._bind_flow( - match.flow_recv, - match.flow_send - ) - await match.dispatch(reqargs) - - -class RoutingCtx: - __slots__ = ['router', 'rule'] - - def __init__( - self, - router: Router, - rule_cls: Type[RoutingRule], - *args, - **kwargs - ): - self.router = router - self.rule = rule_cls(self.router, *args, **kwargs) - self.router._routing_stack.append(self.rule) - - def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: - self.router.app.send_signal(Signals.before_route, route=self.rule, f=f) - rv = self.rule(f) - self.router.app.send_signal(Signals.after_route, route=self.rule) - self.router._routing_stack.pop() - return rv - - -class RoutingCtxGroup: - __slots__ = ['ctxs'] - - def __init__(self, ctxs: List[RoutingCtx]): - self.ctxs = ctxs - - def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: - rv = f - for ctx in self.ctxs: - rv = ctx(f) - return rv diff --git a/emmett/routing/routes.py b/emmett/routing/routes.py index 698d05c4..e1201230 100644 --- a/emmett/routing/routes.py +++ b/emmett/routing/routes.py @@ -1,176 +1,57 @@ # -*- coding: utf-8 -*- """ - emmett.routing.routes - --------------------- +emmett.routing.routes +--------------------- - Provides routes objects. +Provides routes objects. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import re - from functools import wraps import pendulum +from emmett_core.http.response import HTTPResponse +from emmett_core.routing.routes import HTTPRoute as _HTTPRoute -from ..http import HTTP -from .dispatchers import ( - CacheCloseDispatcher, - CacheDispatcher, - CacheFlowDispatcher, - CacheOpenDispatcher, - Dispatcher, - RequestCloseDispatcher, - RequestDispatcher, - RequestFlowDispatcher, - RequestOpenDispatcher, - WSCloseDispatcher, - WSFlowDispatcher, - WSOpenDispatcher -) - -REGEX_INT = re.compile(r'') -REGEX_STR = re.compile(r'') -REGEX_ANY = re.compile(r'') -REGEX_ALPHA = re.compile(r'') -REGEX_DATE = re.compile(r'') -REGEX_FLOAT = re.compile(r'') - -class Route: - __slots__ = [ - 'f', - 'hostname', - 'is_static', - 'match', - 'name', - 'parse_reqargs', - 'path', - 'pipeline_flow_close', - 'pipeline_flow_open', - 'regex', - 'schemes' - ] - _re_condl = re.compile(r'\(.*\)\?') - _re_param = re.compile(r'<(\w+)\:(\w+)>') +class HTTPRoute(_HTTPRoute): + __slots__ = [] def __init__(self, rule, path, idx): - self.name = rule.name if idx == 0 else f"{rule.name}.{idx}" - self.f = rule.f - if not path.startswith('/'): - path = '/' + path - if rule.prefix: - path = (path != '/' and rule.prefix + path) or rule.prefix - self.path = path - self.schemes = tuple(rule.schemes) - self.hostname = rule.hostname - self.regex = re.compile(self.build_regex(self.path)) - self.pipeline_flow_open = rule.pipeline_flow_open - self.pipeline_flow_close = rule.pipeline_flow_close - self.build_matcher() + super().__init__(rule, path, idx) self.build_argparser() - @staticmethod - def build_regex(path): - path = REGEX_INT.sub(r'(?P<\g<1>>\\d+)', path) - path = REGEX_STR.sub(r'(?P<\g<1>>[^/]+)', path) - path = REGEX_ANY.sub(r'(?P<\g<1>>.*)', path) - path = REGEX_ALPHA.sub(r'(?P<\g<1>>[^/\\W\\d_]+)', path) - path = REGEX_DATE.sub(r'(?P<\g<1>>\\d{4}-\\d{2}-\\d{2})', path) - path = REGEX_FLOAT.sub(r'(?P<\g<1>>\\d+\.\\d+)', path) - return f'^({path})$' - - def match_simple(self, path): - return path == self.path, {} - - def match_regex(self, path): - match = self.regex.match(path) - if match: - return True, self.parse_reqargs(match) - return False, {} - - def build_matcher(self): - if ( - self._re_condl.findall(self.path) or - self._re_param.findall(self.path) - ): - matcher, is_static = self.match_regex, False - else: - matcher, is_static = self.match_simple, True - self.match = matcher - self.is_static = is_static - def build_argparser(self): - parsers = { - 'int': self._parse_int_reqarg, - 'float': self._parse_float_reqarg, - 'date': self._parse_date_reqarg - } - opt_parsers = { - 'int': self._parse_int_reqarg_opt, - 'float': self._parse_float_reqarg_opt, - 'date': self._parse_date_reqarg_opt - } + parsers = {"date": self._parse_date_reqarg} + opt_parsers = {"date": self._parse_date_reqarg_opt} pipeline = [] for key in parsers.keys(): optionals = [] - for element in re.compile( - r'\(([^<]+)?<{}\:(\w+)>\)\?'.format(key) - ).findall(self.path): + for element in re.compile(r"\(([^<]+)?<{}\:(\w+)>\)\?".format(key)).findall(self.path): optionals.append(element[1]) - elements = set( - re.compile(r'<{}\:(\w+)>'.format(key)).findall(self.path)) + elements = set(re.compile(r"<{}\:(\w+)>".format(key)).findall(self.path)) args = elements - set(optionals) if args: parser = self._wrap_reqargs_parser(parsers[key], args) pipeline.append(parser) if optionals: - parser = self._wrap_reqargs_parser( - opt_parsers[key], optionals) + parser = self._wrap_reqargs_parser(opt_parsers[key], optionals) pipeline.append(parser) if pipeline: - self.parse_reqargs = self._wrap_reqargs_pipeline(pipeline) - else: - self.parse_reqargs = self._parse_reqargs - - @staticmethod - def _parse_reqargs(match): - return match.groupdict() - - @staticmethod - def _parse_int_reqarg(args, route_args): - for arg in args: - route_args[arg] = int(route_args[arg]) - - @staticmethod - def _parse_int_reqarg_opt(args, route_args): - for arg in args: - if route_args[arg] is None: - continue - route_args[arg] = int(route_args[arg]) - - @staticmethod - def _parse_float_reqarg(args, route_args): - for arg in args: - route_args[arg] = float(route_args[arg]) - - @staticmethod - def _parse_float_reqarg_opt(args, route_args): - for arg in args: - if route_args[arg] is None: - continue - route_args[arg] = float(route_args[arg]) + for key, dispatcher in self.dispatchers.items(): + self.dispatchers[key] = DispacherWrapper(dispatcher, pipeline) @staticmethod def _parse_date_reqarg(args, route_args): try: for arg in args: - route_args[arg] = pendulum.DateTime.strptime( - route_args[arg], "%Y-%m-%d") + dt = route_args[arg] + route_args[arg] = pendulum.datetime(dt.year, dt.month, dt.day) except Exception: - raise HTTP(404) + raise HTTPResponse(404) @staticmethod def _parse_date_reqarg_opt(args, route_args): @@ -178,89 +59,28 @@ def _parse_date_reqarg_opt(args, route_args): for arg in args: if route_args[arg] is None: continue - route_args[arg] = pendulum.DateTime.strptime( - route_args[arg], "%Y-%m-%d") + dt = route_args[arg] + route_args[arg] = pendulum.datetime(dt.year, dt.month, dt.day) except Exception: - raise HTTP(404) + raise HTTPResponse(404) @staticmethod def _wrap_reqargs_parser(parser, args): @wraps(parser) def wrapped(route_args): return parser(args, route_args) - return wrapped - @staticmethod - def _wrap_reqargs_pipeline(parsers): - def wrapped(match): - route_args = match.groupdict() - for parser in parsers: - parser(route_args) - return route_args return wrapped - def build_dispatcher(self, rule): - raise NotImplementedError +class DispacherWrapper: + __slots__ = ["dispatcher", "parsers"] -class HTTPRoute(Route): - __slots__ = ['methods', 'dispatchers'] - - def __init__(self, rule, path, idx): - super().__init__(rule, path, idx) - self.methods = tuple(method.upper() for method in rule.methods) - self.build_dispatchers(rule) - - def build_dispatchers(self, rule): - dispatchers = { - 'base': (RequestDispatcher, CacheDispatcher), - 'open': (RequestOpenDispatcher, CacheOpenDispatcher), - 'close': (RequestCloseDispatcher, CacheCloseDispatcher), - 'flow': (RequestFlowDispatcher, CacheFlowDispatcher) - } - if self.pipeline_flow_open and self.pipeline_flow_close: - dispatcher, cdispatcher = dispatchers['flow'] - elif self.pipeline_flow_open and not self.pipeline_flow_close: - dispatcher, cdispatcher = dispatchers['open'] - elif not self.pipeline_flow_open and self.pipeline_flow_close: - dispatcher, cdispatcher = dispatchers['close'] - else: - dispatcher, cdispatcher = dispatchers['base'] - self.dispatchers = {} - for method in self.methods: - dispatcher_cls = ( - cdispatcher if rule.cache_rule and method in ['HEAD', 'GET'] - else dispatcher - ) - self.dispatchers[method] = dispatcher_cls( - self, - rule, - rule.head_builder if method == 'HEAD' else rule.response_builder - ) - - -class WebsocketRoute(Route): - __slots__ = ['pipeline_flow_receive', 'pipeline_flow_send', 'dispatcher'] - - def __init__(self, rule, path, idx): - super().__init__(rule, path, idx) - self.pipeline_flow_receive = rule.pipeline_flow_receive - self.pipeline_flow_send = rule.pipeline_flow_send - self.build_dispatcher(rule) + def __init__(self, dispatcher, parsers): + self.dispatcher = dispatcher + self.parsers = parsers - def build_dispatcher(self, rule): - dispatchers = { - 'base': Dispatcher, - 'open': WSOpenDispatcher, - 'close': WSCloseDispatcher, - 'flow': WSFlowDispatcher - } - if self.pipeline_flow_open and self.pipeline_flow_close: - dispatcher = dispatchers['flow'] - elif self.pipeline_flow_open and not self.pipeline_flow_close: - dispatcher = dispatchers['open'] - elif not self.pipeline_flow_open and self.pipeline_flow_close: - dispatcher = dispatchers['close'] - else: - dispatcher = dispatchers['base'] - self.dispatcher = dispatcher(self) + def dispatch(self, reqargs, response): + for parser in self.parsers: + parser(reqargs) + return self.dispatcher.dispatch(reqargs, response) diff --git a/emmett/routing/rules.py b/emmett/routing/rules.py index 79d74b7d..3891ad72 100644 --- a/emmett/routing/rules.py +++ b/emmett/routing/rules.py @@ -1,201 +1,71 @@ # -*- coding: utf-8 -*- """ - emmett.routing.rules - -------------------- +emmett.routing.rules +-------------------- - Provides routing rules definition apis. +Provides routing rules definition apis. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations import os - from typing import Any, Callable -from ..cache import RouteCacheRule -from ..pipeline import RequestPipeline, WebsocketPipeline, Pipe -from .routes import HTTPRoute, WebsocketRoute - - -class RoutingRule: - __slots__ = ['router'] - - def __init__(self, router, *args, **kwargs): - self.router = router +from emmett_core.routing.rules import HTTPRoutingRule as _HTTPRoutingRule - @property - def app(self): - return self.router.app - - def build_name(self, f): - filename = os.path.realpath(f.__code__.co_filename) - short = filename[1 + len(self.app.root_path):].rsplit('.', 1)[0] - if not short: - short = filename.rsplit('.', 1)[0] - if short == "__init__": - short = self.app.root_path.rsplit('/', 1)[-1] - #: allow only one naming level if name is not provided - if len(short.split(os.sep)) > 1: - short = short.split(os.sep)[-1] - return '.'.join(short.split(os.sep) + [f.__name__]) - - def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: - raise NotImplementedError +from ..ctx import current +from .routes import HTTPRoute -class HTTPRoutingRule(RoutingRule): - __slots__ = [ - 'cache_rule', - 'f', - 'head_builder', - 'hostname', - 'methods', - 'name', - 'output_type', - 'paths', - 'pipeline_flow_close', - 'pipeline_flow_open', - 'pipeline', - 'prefix', - 'response_builder', - 'schemes', - 'template_folder', - 'template_path', - 'template' - ] +class HTTPRoutingRule(_HTTPRoutingRule): + __slots__ = ["injectors", "template_folder", "template_path", "template"] + current = current + route_cls = HTTPRoute def __init__( - self, router, paths=None, name=None, template=None, pipeline=None, - injectors=None, schemes=None, hostname=None, methods=None, prefix=None, - template_folder=None, template_path=None, cache=None, output='auto' + self, + router, + paths=None, + name=None, + template=None, + pipeline=None, + injectors=None, + schemes=None, + hostname=None, + methods=None, + prefix=None, + template_folder=None, + template_path=None, + cache=None, + output="auto", ): - super().__init__(router) - self.name = name - self.paths = paths - if self.paths is None: - self.paths = [] - if not isinstance(self.paths, (list, tuple)): - self.paths = (self.paths,) - self.schemes = schemes or ('http', 'https') - if not isinstance(self.schemes, (list, tuple)): - self.schemes = (self.schemes,) - self.methods = methods or ('get', 'post', 'head') - if not isinstance(self.methods, (list, tuple)): - self.methods = (self.methods,) - self.hostname = hostname or self.app.config.hostname_default - if prefix: - if not prefix.startswith('/'): - prefix = '/' + prefix - self.prefix = prefix - if output not in self.router._outputs: - raise SyntaxError( - 'Invalid output specified. Allowed values are: {}'.format( - ', '.join(self.router._outputs.keys()))) - self.output_type = output + super().__init__( + router, + paths=paths, + name=name, + pipeline=pipeline, + schemes=schemes, + hostname=hostname, + methods=methods, + prefix=prefix, + cache=cache, + output=output, + ) self.template = template self.template_folder = template_folder self.template_path = template_path or self.app.template_path - self.pipeline = ( - self.router.pipeline + (pipeline or []) + - self.router.injectors + (injectors or [])) - self.cache_rule = None - if cache: - if not isinstance(cache, RouteCacheRule): - raise RuntimeError( - 'route cache argument should be a valid caching rule') - if any(key in self.methods for key in ['get', 'head']): - self.cache_rule = cache - # check pipes are indeed valid pipes - if any(not isinstance(pipe, Pipe) for pipe in self.pipeline): - raise RuntimeError('Invalid pipeline') + self.pipeline = self.pipeline + self.router.injectors + (injectors or []) def _make_builders(self, output_type): builder_cls = self.router._outputs[output_type] - return builder_cls(self), self.router._outputs['empty'](self) + return builder_cls(self), self.router._outputs["empty"](self) def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: - if not self.paths: - self.paths.append("/" + f.__name__) - if not self.name: - self.name = self.build_name(f) - # is it good? - if self.name.endswith("."): - self.name = self.name + f.__name__ - # if not self.template: self.template = f.__name__ + self.app.template_default_extension if self.template_folder: self.template = os.path.join(self.template_folder, self.template) - pipeline_obj = RequestPipeline(self.pipeline) - wrapped_f = pipeline_obj(f) - self.pipeline_flow_open = pipeline_obj._flow_open() - self.pipeline_flow_close = pipeline_obj._flow_close() - self.f = wrapped_f - output_type = pipeline_obj._output_type() or self.output_type - self.response_builder, self.head_builder = self._make_builders(output_type) - for idx, path in enumerate(self.paths): - self.router.add_route(HTTPRoute(self, path, idx)) - return f - - -class WebsocketRoutingRule(RoutingRule): - __slots__ = [ - 'f', - 'hostname', - 'name', - 'paths', - 'pipeline_flow_close', - 'pipeline_flow_open', - 'pipeline_flow_receive', - 'pipeline_flow_send', - 'pipeline', - 'prefix', - 'schemes' - ] - - def __init__( - self, router, paths=None, name=None, pipeline=None, schemes=None, - hostname=None, prefix=None - ): - super().__init__(router) - self.name = name - self.paths = paths - if self.paths is None: - self.paths = [] - if not isinstance(self.paths, (list, tuple)): - self.paths = (self.paths,) - self.schemes = schemes or ('ws', 'wss') - if not isinstance(self.schemes, (list, tuple)): - self.schemes = (self.schemes,) - self.hostname = hostname or self.app.config.hostname_default - if prefix: - if not prefix.startswith('/'): - prefix = '/' + prefix - self.prefix = prefix - self.pipeline = self.router.pipeline + (pipeline or []) - # check pipes are indeed valid pipes - if any(not isinstance(pipe, Pipe) for pipe in self.pipeline): - raise RuntimeError('Invalid pipeline') - - def __call__(self, f: Callable[..., Any]) -> Callable[..., Any]: - if not self.paths: - self.paths.append("/" + f.__name__) - if not self.name: - self.name = self.build_name(f) - # is it good? - if self.name.endswith("."): - self.name = self.name + f.__name__ - # - pipeline_obj = WebsocketPipeline(self.pipeline) - wrapped_f = pipeline_obj(f) - self.pipeline_flow_open = pipeline_obj._flow_open() - self.pipeline_flow_close = pipeline_obj._flow_close() - self.pipeline_flow_receive = pipeline_obj._flow_receive() - self.pipeline_flow_send = pipeline_obj._flow_send() - self.f = wrapped_f - for idx, path in enumerate(self.paths): - self.router.add_route(WebsocketRoute(self, path, idx)) - return f + return super().__call__(f) diff --git a/emmett/routing/urls.py b/emmett/routing/urls.py index 86441682..f0604fa1 100644 --- a/emmett/routing/urls.py +++ b/emmett/routing/urls.py @@ -1,271 +1,17 @@ # -*- coding: utf-8 -*- """ - emmett.routing.urls - ------------------- +emmett.routing.urls +------------------- - Provides url builder apis. +Provides url builder apis. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from urllib.parse import quote as uquote +from emmett_core.routing.urls import Url from ..ctx import current -class UrlBuilder: - __slots__ = ('components', '_args') - - def __init__(self, components=[]): - if not components: - self.components = ['/{}'] - self._args = [''] - else: - self.components = ['{}'] + components[1:] - self._args = [components[0]] - - @property - def path(self): - return self._args[0] - - def arg(self, value): - if not self.components: - self.components.append('/{}') - return self.components.pop(0).format(value) - - def add_prefix(self, args): - if current.app._router_http._prefix_main: - self.components.insert(0, '{}') - args.insert(0, current.app._router_http._prefix_main) - - def add_language(self, args, language): - if language: - self.components.insert(0, '/{}') - args.insert(0, language) - - def path_prefix(self, scheme, host): - if scheme and host: - return '{}://{}'.format(scheme, host) - return '' - - def args(self, args): - rv = '' - for arg in args: - rv += self.arg(arg) - return rv - - def params(self, params): - if params: - return '?' + '&'.join( - '%s=%s' % (uquote(str(k)), uquote(str(v))) - for k, v in params.items() - ) - return '' - - def url(self, scheme, host, language, args, params): - args = self._args + args - self.add_language(args, language) - self.add_prefix(args) - return ( - f'{self.path_prefix(scheme, host)}{self.args(args)}' - f'{self.params(params)}' - ) - - -class HttpUrlBuilder(UrlBuilder): - __slots__ = ('components', '_args') - - def add_static_versioning(self, args): - versioning = current.app._router_http.static_versioning - if (self.path.startswith('/static') and versioning): - self.components.insert(1, "/_{}") - args.insert(1, str(versioning)) - - def anchor(self, anchor): - rv = '' - if anchor: - if not isinstance(anchor, (list, tuple)): - anchor = [anchor] - for element in anchor: - rv += '#{}'.format(element) - return rv - - def url(self, scheme, host, language, args, params, anchor): - args = self._args + args - self.add_static_versioning(args) - self.add_language(args, language) - self.add_prefix(args) - return ( - f'{self.path_prefix(scheme, host)}{self.args(args)}' - f'{self.params(params)}{self.anchor(anchor)}' - ) - - -class Url: - __slots__ = [] - http_to_ws_schemes = {'http': 'ws', 'https': 'wss'} - - def http( - self, path, args=[], params={}, anchor=None, - sign=None, scheme=None, host=None, language=None - ): - if not isinstance(args, (list, tuple)): - args = [args] - # allow user to use url('static', 'file') - if path == 'static': - path = '/static' - # routes urls with 'dot' notation - if '/' not in path: - module = None - # urls like 'function' refers to same module - if '.' not in path: - namespace = current.app.config.url_default_namespace or \ - current.app.name - path = namespace + "." + path - # urls like '.function' refers to main app module - elif path.startswith('.'): - if not hasattr(current, 'request'): - raise RuntimeError( - f'cannot build url("{path}",...) ' - 'without current request' - ) - module = current.request.name.rsplit('.', 1)[0] - path = module + path - # find correct route - try: - url_components = current.app._router_http.routes_out[path]['path'] - url_host = current.app._router_http.routes_out[path]['host'] - builder = HttpUrlBuilder(url_components) - # try to use the correct hostname - if url_host is not None: - try: - if current.request.host != url_host: - scheme = current.request.scheme - host = url_host - except Exception: - pass - except KeyError: - if path.endswith('.static'): - module = module or path.rsplit('.', 1)[0] - builder = HttpUrlBuilder([f'/static/__{module}__']) - else: - raise RuntimeError(f'invalid url("{path}",...)') - # handle classic urls - else: - builder = HttpUrlBuilder([path]) - # add language - lang = None - if current.app.language_force_on_url: - if language: - #: use the given language if is enabled in application - if language in current.app.languages: - lang = language - else: - #: try to use the request language if context exists - if hasattr(current, 'request'): - lang = current.request.language - if lang == current.app.language_default: - lang = None - # # add extension (useless??) - # if extension: - # url = url + '.' + extension - # scheme=True means to use current scheme - if scheme is True: - if not hasattr(current, 'request'): - raise RuntimeError( - f'cannot build url("{path}",...) without current request' - ) - scheme = current.request.scheme - # add scheme and host - if scheme: - if host is None: - if not hasattr(current, 'request'): - raise RuntimeError( - f'cannot build url("{path}",...) ' - 'without current request' - ) - host = current.request.host - # add signature - if sign: - if '_signature' in params: - del params['_signature'] - params['_signature'] = sign( - path, args, params, anchor, scheme, host, language) - return builder.url(scheme, host, lang, args, params, anchor) - - def ws( - self, path, args=[], params={}, scheme=None, host=None, language=None - ): - if not isinstance(args, (list, tuple)): - args = [args] - # routes urls with 'dot' notation - if '/' not in path: - # urls like 'function' refers to same module - if '.' not in path: - namespace = current.app.config.url_default_namespace or current.app.name - path = namespace + "." + path - # urls like '.function' refers to main app module - elif path.startswith('.'): - if not hasattr(current, 'request'): - raise RuntimeError( - f'cannot build url("{path}",...) ' - 'without current request' - ) - module = current.request.name.rsplit('.', 1)[0] - path = module + path - # find correct route - try: - url_components = current.app._router_ws.routes_out[path]['path'] - url_host = current.app._router_ws.routes_out[path]['host'] - builder = UrlBuilder(url_components) - # try to use the correct hostname - if url_host is not None: - # TODO: remap host - try: - if current.request.host != url_host: - scheme = self.http_to_ws_schemes[current.request.scheme] - host = url_host - except Exception: - pass - except KeyError: - raise RuntimeError(f'invalid url("{path}",...)') - # handle classic urls - else: - builder = UrlBuilder([path]) - # add language - lang = None - if current.app.language_force_on_url: - if language: - #: use the given language if is enabled in application - if language in current.app.languages: - lang = language - else: - #: try to use the request language if context exists - if hasattr(current, 'request'): - lang = current.request.language - if lang == current.app.language_default: - lang = None - # scheme=True means to use current scheme - if scheme is True: - if not hasattr(current, 'request'): - raise RuntimeError( - f'cannot build url("{path}",...) without current request' - ) - scheme = self.http_to_ws_schemes[current.request.scheme] - # add scheme and host - if scheme: - if host is None: - if not hasattr(current, 'request'): - raise RuntimeError( - f'cannot build url("{path}",...) ' - 'without current request' - ) - host = current.request.host - return builder.url(scheme, host, lang, args, params) - - def __call__(self, *args, **kwargs): - return self.http(*args, **kwargs) - - -url = Url() +url = Url(current) diff --git a/emmett/rsgi/handlers.py b/emmett/rsgi/handlers.py index abd9de12..de7373ec 100644 --- a/emmett/rsgi/handlers.py +++ b/emmett/rsgi/handlers.py @@ -1,180 +1,50 @@ # -*- coding: utf-8 -*- """ - emmett.rsgi.handlers - -------------------- +emmett.rsgi.handlers +-------------------- - Provides RSGI handlers. +Provides RSGI handlers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -import asyncio import os -import re - -from typing import Awaitable, Callable, Optional, Tuple +from typing import Awaitable, Callable +from emmett_core.http.response import HTTPResponse +from emmett_core.protocols.rsgi.handlers import HTTPHandler as _HTTPHandler, WSHandler as _WSHandler +from emmett_core.utils import cachedprop from granian.rsgi import ( HTTPProtocol, - ProtocolClosed, Scope, - WebsocketMessageType, - WebsocketProtocol ) -from ..ctx import RequestContext, WSContext, current -from ..debug import smart_traceback, debug_handler -from ..http import HTTPResponse, HTTPFile, HTTP -from ..utils import cachedprop +from ..ctx import current +from ..debug import debug_handler, smart_traceback from ..wrappers.response import Response - -from .helpers import WSTransport from .wrappers import Request, Websocket -REGEX_STATIC = re.compile( - r'^/static/(?P__[\w\-\.]+__/)?(?P_\d+\.\d+\.\d+/)?(?P.*?)$' -) -REGEX_STATIC_LANG = re.compile( - r'^/(?P\w{2}/)?static/(?P__[\w\-\.]__+/)?(?P_\d+\.\d+\.\d+/)?(?P.*?)$' -) - - -class Handler: - __slots__ = ['app'] - - def __init__(self, app): - self.app = app - - -class RequestHandler(Handler): - __slots__ = ['router'] - def __init__(self, app): - super().__init__(app) - self._bind_router() - self._configure_methods() - - def _bind_router(self): - raise NotImplementedError - - def _configure_methods(self): - raise NotImplementedError - - -class HTTPHandler(RequestHandler): - __slots__ = ['pre_handler', 'static_handler', 'static_matcher', '__dict__'] - - def _bind_router(self): - self.router = self.app._router_http - - def _configure_methods(self): - self.static_matcher = ( - self._static_lang_matcher if self.app.language_force_on_url else - self._static_nolang_matcher - ) - self.static_handler = ( - self._static_handler if self.app.config.handle_static else - self.dynamic_handler - ) - self.pre_handler = ( - self._prefix_handler if self.router._prefix_main else - self.static_handler - ) - - async def __call__( - self, - scope: Scope, - protocol: HTTPProtocol - ): - try: - http = await self.pre_handler(scope, protocol, scope.path) - except asyncio.TimeoutError: - self.app.log.warn( - f"Timeout sending response: ({scope.path})" - ) - if coro := http.rsgi(protocol): - await coro +class HTTPHandler(_HTTPHandler): + __slots__ = [] + wapper_cls = Request + response_cls = Response @cachedprop def error_handler(self) -> Callable[[], Awaitable[str]]: - return ( - self._debug_handler if self.app.debug else self.exception_handler - ) - - @cachedprop - def exception_handler(self) -> Callable[[], Awaitable[str]]: - return self.app.error_handlers.get(500, self._exception_handler) - - @staticmethod - async def _http_response(code: int) -> HTTPResponse: - return HTTP(code) - - def _prefix_handler( - self, - scope: Scope, - protocol: HTTPProtocol, - path: str - ) -> Awaitable[HTTPResponse]: - if not path.startswith(self.router._prefix_main): - return self._http_response(404) - path = path[self.router._prefix_main_len:] or '/' - return self.static_handler(scope, protocol, path) - - def _static_lang_matcher( - self, path: str - ) -> Tuple[Optional[str], Optional[str]]: - match = REGEX_STATIC_LANG.match(path) - if match: - lang, mname, version, file_name = match.group('l', 'm', 'v', 'f') - if mname: - mod = self.app._modules.get(mname) - spath = mod._static_path if mod else self.app.static_path - else: - spath = self.app.static_path - static_file = os.path.join(spath, file_name) - if lang: - lang_file = os.path.join(spath, lang, file_name) - if os.path.exists(lang_file): - static_file = lang_file - return static_file, version - return None, None - - def _static_nolang_matcher( - self, path: str - ) -> Tuple[Optional[str], Optional[str]]: - if path.startswith('/static/'): - mname, version, file_name = REGEX_STATIC.match(path).group('m', 'v', 'f') - if mname: - mod = self.app._modules.get(mname[2:-3]) - static_file = os.path.join(mod._static_path, file_name) if mod else None - elif file_name: - static_file = os.path.join(self.app.static_path, file_name) - else: - static_file = None - return static_file, version - return None, None + return self._debug_handler if self.app.debug else self.exception_handler - async def _static_response(self, file_path: str) -> HTTPFile: - return HTTPFile(file_path) - - def _static_handler( - self, - scope: Scope, - protocol: HTTPProtocol, - path: str - ) -> Awaitable[HTTPResponse]: + def _static_handler(self, scope: Scope, protocol: HTTPProtocol, path: str) -> Awaitable[HTTPResponse]: #: handle internal assets - if path.startswith('/__emmett__'): + if path.startswith("/__emmett__"): file_name = path[12:] if not file_name: return self._http_response(404) - static_file = os.path.join( - os.path.dirname(__file__), '..', 'assets', file_name - ) - if os.path.splitext(static_file)[1] == 'html': + static_file = os.path.join(os.path.dirname(__file__), "..", "assets", file_name) + if os.path.splitext(static_file)[1] == "html": return self._http_response(404) return self._static_response(static_file) #: handle app assets @@ -183,138 +53,10 @@ def _static_handler( return self._static_response(static_file) return self.dynamic_handler(scope, protocol, path) - async def dynamic_handler( - self, - scope: Scope, - protocol: HTTPProtocol, - path: str - ) -> HTTPResponse: - request = Request( - scope, - path, - protocol, - max_content_length=self.app.config.request_max_content_length, - body_timeout=self.app.config.request_body_timeout - ) - response = Response() - ctx = RequestContext(self.app, request, response) - ctx_token = current._init_(ctx) - try: - http = await self.router.dispatch(request, response) - except HTTPResponse as http_exception: - http = http_exception - #: render error with handlers if in app - error_handler = self.app.error_handlers.get(http.status_code) - if error_handler: - http = HTTP( - http.status_code, - await error_handler(), - headers=response.headers, - cookies=response.cookies - ) - except Exception: - self.app.log.exception('Application exception:') - http = HTTP( - 500, - await self.error_handler(), - headers=response.headers - ) - finally: - current._close_(ctx_token) - return http - async def _debug_handler(self) -> str: - current.response.headers._data['content-type'] = ( - 'text/html; charset=utf-8' - ) + current.response.headers._data["content-type"] = "text/html; charset=utf-8" return debug_handler(smart_traceback(self.app)) - async def _exception_handler(self) -> str: - current.response.headers._data['content-type'] = 'text/plain' - return 'Internal error' - - -class WSHandler(RequestHandler): - __slots__ = ['pre_handler', '__dict__'] - - def _bind_router(self): - self.router = self.app._router_ws - - def _configure_methods(self): - self.pre_handler = ( - self._prefix_handler if self.router._prefix_main else - self.dynamic_handler - ) - - async def __call__( - self, - scope: Scope, - protocol: WebsocketProtocol - ): - transport = WSTransport(protocol) - task_transport = asyncio.create_task(self.handle_transport(transport)) - task_request = asyncio.create_task(self.handle_request(scope, transport)) - _, pending = await asyncio.wait( - [task_request, task_transport], return_when=asyncio.FIRST_COMPLETED - ) - for task in pending: - task.cancel() - self._close_connection(transport) - - async def handle_transport(self, transport: WSTransport): - await transport.accepted.wait() - try: - while True: - msg = await transport.transport.receive() - if msg.kind == WebsocketMessageType.close: - transport.interrupted = True - break - await transport.input.put(msg) - except ProtocolClosed: - transport.interrupted = True - - def handle_request( - self, - scope: Scope, - transport: WSTransport - ): - return self.pre_handler(scope, transport, scope.path) - - async def _empty_awaitable(self): - return - - def _prefix_handler( - self, - scope: Scope, - transport: WSTransport, - path: str - ) -> Awaitable[None]: - if not path.startswith(self.router._prefix_main): - transport.status = 404 - return self._empty_awaitable() - path = path[self.router._prefix_main_len:] or '/' - return self.dynamic_handler(scope, transport, path) - - async def dynamic_handler( - self, - scope: Scope, - transport: WSTransport, - path: str - ): - ctx = WSContext(self.app, Websocket(scope, path, transport)) - ctx_token = current._init_(ctx) - try: - await self.router.dispatch(ctx.websocket) - except HTTPResponse as http: - transport.status = http.status_code - except asyncio.CancelledError: - if not transport.interrupted: - self.app.log.exception('Application exception:') - except Exception: - transport.status = 500 - self.app.log.exception('Application exception:') - finally: - current._close_(ctx_token) - def _close_connection(self, transport: WSTransport): - transport.protocol.close(transport.status) +class WSHandler(_WSHandler): + wrapper_cls = Websocket diff --git a/emmett/rsgi/helpers.py b/emmett/rsgi/helpers.py deleted file mode 100644 index f1f25971..00000000 --- a/emmett/rsgi/helpers.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.rsgi.helpers - ------------------- - - Provides RSGI helpers - - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause -""" - -import asyncio - -from granian.rsgi import WebsocketProtocol - - -class WSTransport: - __slots__ = [ - 'protocol', 'transport', - 'accepted', 'interrupted', - 'input', 'status', 'noop' - ] - - def __init__( - self, - protocol: WebsocketProtocol - ) -> None: - self.protocol = protocol - self.transport = None - self.accepted = asyncio.Event() - self.input = asyncio.Queue() - self.interrupted = False - self.status = 200 - self.noop = asyncio.Event() - - async def init(self): - self.transport = await self.protocol.accept() - self.accepted.set() - - @property - def receive(self): - return self.input.get diff --git a/emmett/rsgi/wrappers.py b/emmett/rsgi/wrappers.py index 0de63011..9996a295 100644 --- a/emmett/rsgi/wrappers.py +++ b/emmett/rsgi/wrappers.py @@ -1,158 +1,26 @@ # -*- coding: utf-8 -*- """ - emmett.rsgi.wrappers - -------------------- +emmett.rsgi.wrappers +-------------------- - Provides RSGI request and websocket wrappers +Provides RSGI request and websocket wrappers - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -import asyncio +import pendulum +from emmett_core.protocols.rsgi.wrappers import Request as _Request, Websocket as Websocket +from emmett_core.utils import cachedprop -from datetime import datetime -from typing import Any, Dict, List, Union, Optional -from urllib.parse import parse_qs -from granian.rsgi import Scope, HTTPProtocol, ProtocolClosed - -from .helpers import WSTransport -from ..datastructures import sdict -from ..http import HTTP -from ..utils import cachedprop -from ..wrappers.helpers import regex_client -from ..wrappers.request import Request as _Request -from ..wrappers.websocket import Websocket as _Websocket - - -class RSGIIngressMixin: - def __init__( - self, - scope: Scope, - path: str, - protocol: Union[HTTPProtocol, WSTransport] - ): - self._scope = scope - self._proto = protocol - self.scheme = scope.scheme - self.path = path - - @property - def headers(self): - return self._scope.headers - - @cachedprop - def host(self) -> str: - if self._scope.http_version[0] == '1': - return self.headers.get('host') - return self._scope.authority +class Request(_Request): + __slots__ = [] @cachedprop - def query_params(self) -> sdict[str, Union[str, List[str]]]: - rv: sdict[str, Any] = sdict() - for key, values in parse_qs( - self._scope.query_string, keep_blank_values=True - ).items(): - if len(values) == 1: - rv[key] = values[0] - continue - rv[key] = values - return rv - - -class Request(RSGIIngressMixin, _Request): - __slots__ = ['_scope', '_proto'] - - def __init__( - self, - scope: Scope, - path: str, - protocol: HTTPProtocol, - max_content_length: Optional[int] = None, - body_timeout: Optional[int] = None - ): - super().__init__(scope, path, protocol) - self.max_content_length = max_content_length - self.body_timeout = body_timeout - self._now = datetime.utcnow() - self.method = scope.method - - @property - def _multipart_headers(self): - return dict(self.headers.items()) + def now(self) -> pendulum.DateTime: + return pendulum.instance(self._now) @cachedprop - async def body(self) -> bytes: - if ( - self.max_content_length and - self.content_length > self.max_content_length - ): - raise HTTP(413, 'Request entity too large') - try: - rv = await asyncio.wait_for(self._proto(), timeout=self.body_timeout) - except asyncio.TimeoutError: - raise HTTP(408, 'Request timeout') - return rv - - @cachedprop - def client(self) -> str: - g = regex_client.search(self.headers.get('x-forwarded-for', '')) - client = ( - (g.group() or '').split(',')[0] if g else ( - self._scope.client[0] if self._scope.client else None - ) - ) - if client in (None, '', 'unknown', 'localhost'): - client = '::1' if self.host.startswith('[') else '127.0.0.1' - return client # type: ignore - - async def push_promise(self, path: str): - raise NotImplementedError("RSGI protocol doesn't support HTTP2 push.") - - -class Websocket(RSGIIngressMixin, _Websocket): - __slots__ = ['_scope', '_proto'] - - def __init__( - self, - scope: Scope, - path: str, - protocol: WSTransport - ): - super().__init__(scope, path, protocol) - self._flow_receive = None - self._flow_send = None - self.receive = self._accept_and_receive - self.send = self._accept_and_send - - async def accept( - self, - headers: Optional[Dict[str, str]] = None, - subprotocol: Optional[str] = None - ): - if self._proto.transport: - return - await self._proto.init() - self.receive = self._wrapped_receive - self.send = self._wrapped_send - - async def _wrapped_receive(self) -> Any: - data = (await self._proto.receive()).data - for method in self._flow_receive: - data = method(data) - return data - - async def _wrapped_send(self, data: Any): - for method in self._flow_send: - data = method(data) - trx = ( - self._proto.transport.send_str if isinstance(data, str) else - self._proto.transport.send_bytes - ) - try: - await trx(data) - except ProtocolClosed: - if not self._proto.interrupted: - raise - await self._proto.noop.wait() + def now_local(self) -> pendulum.DateTime: + return self.now.in_timezone(pendulum.local_timezone()) # type: ignore diff --git a/emmett/security.py b/emmett/security.py index 3cbfbdc0..01cb4d84 100644 --- a/emmett/security.py +++ b/emmett/security.py @@ -1,30 +1,21 @@ # -*- coding: utf-8 -*- """ - emmett.security - --------------- +emmett.security +--------------- - Miscellaneous security helpers. +Miscellaneous security helpers. - :copyright: 2014 Giovanni Barillari - - Based on the code of web2py (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro - - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import hashlib import hmac -import os -import random -import struct -import threading import time -import uuid as uuidm - from collections import OrderedDict +from uuid import uuid4 -from emmett_crypto import kdf +from emmett_core.cryptography import kdf # TODO: check bytes conversions from ._shortcuts import to_bytes @@ -40,17 +31,17 @@ def _clean(self): def gen_token(self): self._clean() - token = str(uuid()) + token = str(uuid4()) self[token] = int(time.time()) return token def md5_hash(text): - """ Generate a md5 hash with the given text """ + """Generate a md5 hash with the given text""" return hashlib.md5(text).hexdigest() -def simple_hash(text, key='', salt='', digest_alg='md5'): +def simple_hash(text, key="", salt="", digest_alg="md5"): """ Generates hash with the given text using the specified digest hashing algorithm @@ -59,14 +50,10 @@ def simple_hash(text, key='', salt='', digest_alg='md5'): raise RuntimeError("simple_hash with digest_alg=None") elif not isinstance(digest_alg, str): # manual approach h = digest_alg(text + key + salt) - elif digest_alg.startswith('pbkdf2'): # latest and coolest! - iterations, keylen, alg = digest_alg[7:-1].split(',') + elif digest_alg.startswith("pbkdf2"): # latest and coolest! + iterations, keylen, alg = digest_alg[7:-1].split(",") return kdf.pbkdf2_hex( - text, - salt, - iterations=int(iterations), - keylen=int(keylen), - hash_algorithm=kdf.PBKDF2_HMAC[alg] + text, salt, iterations=int(iterations), keylen=int(keylen), hash_algorithm=kdf.PBKDF2_HMAC[alg] ) elif key: # use hmac digest_alg = get_digest(digest_alg) @@ -101,89 +88,10 @@ def get_digest(value): DIGEST_ALG_BY_SIZE = { - 128 / 4: 'md5', - 160 / 4: 'sha1', - 224 / 4: 'sha224', - 256 / 4: 'sha256', - 384 / 4: 'sha384', - 512 / 4: 'sha512', + 128 / 4: "md5", + 160 / 4: "sha1", + 224 / 4: "sha224", + 256 / 4: "sha256", + 384 / 4: "sha384", + 512 / 4: "sha512", } - - -def _init_urandom(): - """ - This function and the web2py_uuid follow from the following discussion: - http://groups.google.com/group/web2py-developers/browse_thread/thread/7fd5789a7da3f09 - - At startup web2py compute a unique ID that identifies the machine by adding - uuid.getnode() + int(time.time() * 1e3) - - This is a 48-bit number. It converts the number into 16 8-bit tokens. - It uses this value to initialize the entropy source ('/dev/urandom') - and to seed random. - - If os.random() is not supported, it falls back to using random and issues - a warning. - """ - node_id = uuidm.getnode() - microseconds = int(time.time() * 1e6) - ctokens = [((node_id + microseconds) >> ((i % 6) * 8)) % - 256 for i in range(16)] - random.seed(node_id + microseconds) - try: - os.urandom(1) - have_urandom = True - try: - # try to add process-specific entropy - frandom = open('/dev/urandom', 'wb') - try: - frandom.write(bytes([]).join(bytes([t]) for t in ctokens)) - finally: - frandom.close() - except IOError: - # works anyway - pass - except NotImplementedError: - have_urandom = False - packed = bytes([]).join(bytes([x]) for x in ctokens) - unpacked_ctokens = struct.unpack('=QQ', packed) - return unpacked_ctokens, have_urandom - - -_UNPACKED_CTOKENS, _HAVE_URANDOM = _init_urandom() - - -def fast_urandom16(urandom=[], locker=threading.RLock()): - """ - this is 4x faster than calling os.urandom(16) and prevents - the "too many files open" issue with concurrent access to os.urandom() - """ - try: - return urandom.pop() - except IndexError: - try: - locker.acquire() - ur = os.urandom(16 * 1024) - urandom += [ur[i:i + 16] for i in range(16, 1024 * 16, 16)] - return ur[0:16] - finally: - locker.release() - - -def uuid(ctokens=_UNPACKED_CTOKENS): - """ - It works like uuid.uuid4 except that tries to use os.urandom() if possible - and it XORs the output with the tokens uniquely associated with - this machine. - """ - rand_longs = (random.getrandbits(64), random.getrandbits(64)) - if _HAVE_URANDOM: - urand_longs = struct.unpack('=QQ', fast_urandom16()) - byte_s = struct.pack('=QQ', - rand_longs[0] ^ urand_longs[0] ^ ctokens[0], - rand_longs[1] ^ urand_longs[1] ^ ctokens[1]) - else: - byte_s = struct.pack('=QQ', - rand_longs[0] ^ ctokens[0], - rand_longs[1] ^ ctokens[1]) - return str(uuidm.UUID(bytes=byte_s, version=4)) diff --git a/emmett/serializers.py b/emmett/serializers.py index 62be271c..e5a73bb0 100644 --- a/emmett/serializers.py +++ b/emmett/serializers.py @@ -1,97 +1,28 @@ # -*- coding: utf-8 -*- """ - emmett.serializers - ------------------ +emmett.serializers +------------------ - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from functools import partial -from typing import Any, Callable, Dict, Union +from emmett_core.serializers import Serializers as Serializers -from .html import tag, htmlescape - -try: - import orjson - _json_impl = orjson.dumps - _json_opts = { - "option": orjson.OPT_NON_STR_KEYS | orjson.OPT_NAIVE_UTC - } - _json_type = "bytes" -except ImportError: - import rapidjson - _json_impl = rapidjson.dumps - _json_opts = { - "datetime_mode": rapidjson.DM_ISO8601 | rapidjson.DM_NAIVE_IS_UTC, - "number_mode": rapidjson.NM_NATIVE | rapidjson.NM_DECIMAL - } - _json_type = "str" - -_json_safe_table = { - 'u2028': [r'\u2028', '\\u2028'], - 'u2029': [r'\u2029', '\\u2029'] -} - - -class Serializers: - _registry_: Dict[str, Callable[[Any], Union[bytes, str]]] = {} - - @classmethod - def register_for(cls, target): - def wrap(f): - cls._registry_[target] = f - return f - return wrap - - @classmethod - def get_for(cls, target): - return cls._registry_[target] - - -def _json_default(obj): - if hasattr(obj, '__json__'): - return obj.__json__() - raise TypeError - - -json = partial( - _json_impl, - default=_json_default, - **_json_opts -) - - -def json_safe(value): - rv = json(value) - for val, rep in _json_safe_table.values(): - rv.replace(val, rep) - return rv +from .html import htmlescape, tag def xml_encode(value, key=None, quote=True): - if hasattr(value, '__xml__'): + if hasattr(value, "__xml__"): return value.__xml__(key, quote) if isinstance(value, dict): - return tag[key]( - *[ - tag[k](xml_encode(v, None, quote)) - for k, v in value.items() - ]) + return tag[key](*[tag[k](xml_encode(v, None, quote)) for k, v in value.items()]) if isinstance(value, list): - return tag[key]( - *[ - tag[item](xml_encode(item, None, quote)) - for item in value - ]) + return tag[key](*[tag[item](xml_encode(item, None, quote)) for item in value]) return htmlescape(value) -Serializers.register_for('json')(json) - - -@Serializers.register_for('xml') -def xml(value, encoding='UTF-8', key='document', quote=True): - rv = ('' % encoding) + \ - str(xml_encode(value, key, quote)) +@Serializers.register_for("xml") +def xml(value, encoding="UTF-8", key="document", quote=True): + rv = ('' % encoding) + str(xml_encode(value, key, quote)) return rv diff --git a/emmett/server.py b/emmett/server.py deleted file mode 100644 index 92bd1f49..00000000 --- a/emmett/server.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.server - ------------- - - Provides server wrapper over granian - - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause -""" - -from typing import Optional - -from granian import Granian - - -def run( - interface, - app, - host='127.0.0.1', - port=8000, - loop='auto', - loop_opt=False, - log_level=None, - log_access=False, - workers=1, - threads=1, - threading_mode='workers', - backlog=1024, - backpressure=None, - http='auto', - enable_websockets=True, - ssl_certfile: Optional[str] = None, - ssl_keyfile: Optional[str] = None -): - app_path = ":".join([app[0], app[1] or "app"]) - runner = Granian( - app_path, - address=host, - port=port, - interface=interface, - workers=workers, - threads=threads, - threading_mode=threading_mode, - loop=loop, - loop_opt=loop_opt, - http=http, - websockets=enable_websockets, - backlog=backlog, - backpressure=backpressure, - log_level=log_level, - log_access=log_access, - ssl_cert=ssl_certfile, - ssl_key=ssl_keyfile - ) - runner.serve() diff --git a/emmett/sessions.py b/emmett/sessions.py index e96dc512..7a62c697 100644 --- a/emmett/sessions.py +++ b/emmett/sessions.py @@ -1,403 +1,23 @@ # -*- coding: utf-8 -*- """ - emmett.sessions - --------------- +emmett.sessions +--------------- - Provides session managers for applications. +Provides session managers for applications. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -import os -import pickle -import tempfile -import time -import zlib - -from typing import Any, Dict, Optional, Type, TypeVar - -from emmett_crypto import symmetric as crypto_symmetric +from emmett_core.sessions import SessionManager as _SessionManager from .ctx import current -from .datastructures import sdict, SessionData -from .pipeline import Pipe -from .security import uuid -from .wrappers import IngressWrapper - - -class SessionPipe(Pipe): - def __init__( - self, - expire: int = 3600, - secure: bool = False, - samesite: str = "Lax", - domain: Optional[str] = None, - cookie_name: Optional[str] = None, - cookie_data: Optional[Dict[str, Any]] = None - ): - self.expire = expire - self.secure = secure - self.samesite = samesite - self.domain = domain - self.cookie_name = ( - cookie_name or f'emt_session_data_{current.app.name}' - ) - self.cookie_data = cookie_data or {} - - def _load_session(self, wrapper: IngressWrapper): - raise NotImplementedError - - def _new_session(self) -> SessionData: - raise NotImplementedError - - def _pack_session(self, expiration: int): - current.response.cookies[self.cookie_name] = self._session_cookie_data() - cookie_data = current.response.cookies[self.cookie_name] - cookie_data['path'] = "/" - cookie_data['expires'] = expiration - cookie_data['samesite'] = self.samesite - if self.secure: - cookie_data['secure'] = True - if self.domain is not None: - cookie_data['domain'] = self.domain - for key, val in self.cookie_data.items(): - cookie_data[key] = val - - def _session_cookie_data(self) -> str: - raise NotImplementedError - - async def open_request(self): - if self.cookie_name in current.request.cookies: - current.session = self._load_session(current.request) - if not current.session: - current.session = self._new_session() - - async def open_ws(self): - if self.cookie_name in current.websocket.cookies: - current.session = self._load_session(current.websocket) - if not current.session: - current.session = self._new_session() - - async def close_request(self): - expiration = current.session._expiration or self.expire - self._pack_session(expiration) - - def clear(self): - pass - - -class CookieSessionPipe(SessionPipe): - def __init__( - self, - key, - expire=3600, - secure=False, - samesite="Lax", - domain=None, - cookie_name=None, - cookie_data=None, - encryption_mode="modern", - compression_level=0 - ): - super().__init__( - expire=expire, - secure=secure, - samesite=samesite, - domain=domain, - cookie_name=cookie_name, - cookie_data=cookie_data - ) - self.key = key - if encryption_mode != "modern": - raise ValueError("Unsupported encryption_mode") - self.compression_level = compression_level - - def _encrypt_data(self) -> str: - data = pickle.dumps(sdict(current.session)) - if self.compression_level: - data = zlib.compress(data, self.compression_level) - return crypto_symmetric.encrypt_b64(data, self.key) - - def _decrypt_data(self, data: str) -> SessionData: - try: - ddata = crypto_symmetric.decrypt_b64(data, self.key) - if self.compression_level: - ddata = zlib.decompress(ddata) - rv = pickle.loads(ddata) - except Exception: - rv = None - return SessionData(rv, expires=self.expire) - - def _load_session(self, wrapper: IngressWrapper) -> SessionData: - cookie_data = wrapper.cookies[self.cookie_name].value - return self._decrypt_data(cookie_data) - - def _new_session(self) -> SessionData: - return SessionData(expires=self.expire) - - def _session_cookie_data(self) -> str: - return self._encrypt_data() - - def clear(self): - raise NotImplementedError( - f"{self.__class__.__name__} doesn't support sessions clearing. " - f"Change the 'key' parameter to invalidate existing ones." - ) - - -class BackendStoredSessionPipe(SessionPipe): - def _new_session(self): - return SessionData(sid=uuid()) - - def _session_cookie_data(self) -> str: - return current.session._sid - - def _load_session(self, wrapper: IngressWrapper) -> Optional[SessionData]: - sid = wrapper.cookies[self.cookie_name].value - data = self._load(sid) - if data is not None: - return SessionData(data, sid=sid) - return None - - def _delete_session(self): - pass - - def _save_session(self, expiration: int): - pass - - def _load(self, sid: str): - return None - async def close_request(self): - if not current.session: - self._delete_session() - if current.session._modified: - #: if we got here means we want to destroy session definitely - if self.cookie_name in current.response.cookies: - del current.response.cookies[self.cookie_name] - return - expiration = current.session._expiration or self.expire - self._save_session(expiration) - self._pack_session(expiration) - - -class FileSessionPipe(BackendStoredSessionPipe): - _fs_transaction_suffix = '.__emt_sess' - _fs_mode = 0o600 - - def __init__( - self, - expire=3600, - secure=False, - samesite="Lax", - domain=None, - cookie_name=None, - cookie_data=None, - filename_template='emt_%s.sess' - ): - super().__init__( - expire=expire, - secure=secure, - samesite=samesite, - domain=domain, - cookie_name=cookie_name, - cookie_data=cookie_data - ) - assert not filename_template.endswith(self._fs_transaction_suffix), \ - 'filename templates cannot end with %s' % \ - self._fs_transaction_suffix - self._filename_template = filename_template - self._path = os.path.join(current.app.root_path, 'sessions') - #: create required paths if needed - if not os.path.exists(self._path): - os.mkdir(self._path) - - def _delete_session(self): - fn = self._get_filename(current.session._sid) - try: - os.unlink(fn) - except OSError: - pass - - def _save_session(self, expiration): - if current.session._modified: - self._store(current.session, expiration) - - def _get_filename(self, sid): - return os.path.join(self._path, self._filename_template % str(sid)) - - def _load(self, sid): - try: - with open(self._get_filename(sid), 'rb') as f: - exp = pickle.load(f) - val = pickle.load(f) - except IOError: - return None - if exp < time.time(): - return None - return val - - def _store(self, session, expiration): - fn = self._get_filename(session._sid) - now = time.time() - exp = now + expiration - fd, tmp = tempfile.mkstemp( - suffix=self._fs_transaction_suffix, dir=self._path) - f = os.fdopen(fd, 'wb') - try: - pickle.dump(exp, f, 1) - pickle.dump(sdict(session), f, pickle.HIGHEST_PROTOCOL) - finally: - f.close() - try: - os.rename(tmp, fn) - os.chmod(fn, self._fs_mode) - except Exception: - pass - - def clear(self): - for element in os.listdir(self._path): - try: - os.unlink(os.path.join(self._path, element)) - except Exception: - pass - - -class RedisSessionPipe(BackendStoredSessionPipe): - def __init__( - self, - redis, - prefix="emtsess:", - expire=3600, - secure=False, - samesite="Lax", - domain=None, - cookie_name=None, - cookie_data=None - ): - super().__init__( - expire=expire, - secure=secure, - samesite=samesite, - domain=domain, - cookie_name=cookie_name, - cookie_data=cookie_data - ) - self.redis = redis - self.prefix = prefix - - def _delete_session(self): - self.redis.delete(self.prefix + current.session._sid) - - def _save_session(self, expiration): - if current.session._modified: - self.redis.setex( - self.prefix + current.session._sid, - expiration, - current.session._dump - ) - else: - self.redis.expire(self.prefix + current.session._sid, expiration) - - def _load(self, sid): - data = self.redis.get(self.prefix + sid) - return pickle.loads(data) if data else data - - def clear(self): - self.redis.delete(self.prefix + "*") - - -TSessionPipe = TypeVar("TSessionPipe", bound=SessionPipe) - - -class SessionManager: - _pipe: Optional[SessionPipe] = None +class SessionManager(_SessionManager): @classmethod - def _build_pipe( - cls, - handler_cls: Type[TSessionPipe], - *args: Any, - **kwargs: Any - ) -> TSessionPipe: - cls._pipe = handler_cls(*args, **kwargs) + def _build_pipe(cls, handler_cls, *args, **kwargs): + cls._pipe = handler_cls(current, *args, **kwargs) return cls._pipe - - @classmethod - def cookies( - cls, - key: str, - expire: int = 3600, - secure: bool = False, - samesite: str = "Lax", - domain: Optional[str] = None, - cookie_name: Optional[str] = None, - cookie_data: Optional[Dict[str, Any]] = None, - encryption_mode: str = "modern", - compression_level: int = 0 - ) -> CookieSessionPipe: - return cls._build_pipe( - CookieSessionPipe, - key, - expire=expire, - secure=secure, - samesite=samesite, - domain=domain, - cookie_name=cookie_name, - cookie_data=cookie_data, - encryption_mode=encryption_mode, - compression_level=compression_level - ) - - @classmethod - def files( - cls, - expire: int = 3600, - secure: bool = False, - samesite: str = "Lax", - domain: Optional[str] = None, - cookie_name: Optional[str] = None, - cookie_data: Optional[Dict[str, Any]] = None, - filename_template: str = 'emt_%s.sess' - ) -> FileSessionPipe: - return cls._build_pipe( - FileSessionPipe, - expire=expire, - secure=secure, - samesite=samesite, - domain=domain, - cookie_name=cookie_name, - cookie_data=cookie_data, - filename_template=filename_template - ) - - @classmethod - def redis( - cls, - redis: Any, - prefix: str = "emtsess:", - expire: int = 3600, - secure: bool = False, - samesite: str = "Lax", - domain: Optional[str] = None, - cookie_name: Optional[str] = None, - cookie_data: Optional[Dict[str, Any]] = None - ) -> RedisSessionPipe: - return cls._build_pipe( - RedisSessionPipe, - redis, - prefix=prefix, - expire=expire, - secure=secure, - samesite=samesite, - domain=domain, - cookie_name=cookie_name, - cookie_data=cookie_data - ) - - @classmethod - def clear(cls): - cls._pipe.clear() diff --git a/emmett/templating/lexers.py b/emmett/templating/lexers.py index d7b5c5b6..801807f6 100644 --- a/emmett/templating/lexers.py +++ b/emmett/templating/lexers.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.templating.lexers - ------------------------ +emmett.templating.lexers +------------------------ - Provides the Emmett lexers for Renoir engine. +Provides the Emmett lexers for Renoir engine. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from renoir import Lexer @@ -17,10 +17,8 @@ class HelpersLexer(Lexer): helpers = [ - '', - '' + '', + '', ] def process(self, ctx, value): @@ -30,17 +28,12 @@ def process(self, ctx, value): class MetaLexer(Lexer): def process(self, ctx, value): - ctx.python_node('for name, value in current.response._meta_tmpl():') - ctx.variable( - "'' % (name, value)", - escape=False) - ctx.python_node('pass') - ctx.python_node( - 'for name, value in current.response._meta_tmpl_prop():') - ctx.variable( - "'' % (name, value)", - escape=False) - ctx.python_node('pass') + ctx.python_node("for name, value in current.response._meta_tmpl():") + ctx.variable('\'\' % (name, value)', escape=False) + ctx.python_node("pass") + ctx.python_node("for name, value in current.response._meta_tmpl_prop():") + ctx.variable('\'\' % (name, value)', escape=False) + ctx.python_node("pass") class StaticLexer(Lexer): @@ -48,20 +41,16 @@ class StaticLexer(Lexer): def process(self, ctx, value): file_name = value.split("?")[0] - surl = url('static', file_name) + surl = url("static", file_name) file_ext = file_name.rsplit(".", 1)[-1] - if file_ext == 'js': - s = u'' % surl + if file_ext == "js": + s = '' % surl elif file_ext == "css": - s = u'' % surl + s = '' % surl else: s = None if s: ctx.html(s) -lexers = { - 'include_helpers': HelpersLexer(), - 'include_meta': MetaLexer(), - 'include_static': StaticLexer() -} +lexers = {"include_helpers": HelpersLexer(), "include_meta": MetaLexer(), "include_static": StaticLexer()} diff --git a/emmett/templating/templater.py b/emmett/templating/templater.py index cf088534..a4ce70ad 100644 --- a/emmett/templating/templater.py +++ b/emmett/templating/templater.py @@ -1,16 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.templating.templater - --------------------------- +emmett.templating.templater +--------------------------- - Provides the Emmett implementation for Renoir engine. +Provides the Emmett implementation for Renoir engine. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import os - from functools import reduce from typing import Optional, Tuple @@ -21,7 +20,7 @@ class Templater(Renoir): def __init__(self, **kwargs): - kwargs['lexers'] = lexers + kwargs["lexers"] = lexers super().__init__(**kwargs) self._namespaces = {} @@ -46,9 +45,7 @@ def register_namespace(self, namespace: str, path: Optional[str] = None): path = path or self.path self._namespaces[namespace] = path - def _get_namespace_path_elements( - self, file_name: str, path: Optional[str] - ) -> Tuple[str, str]: + def _get_namespace_path_elements(self, file_name: str, path: Optional[str]) -> Tuple[str, str]: if ":" in file_name: namespace, file_name = file_name.split(":") path = self._namespaces.get(namespace, self.path) @@ -60,9 +57,7 @@ def _preload(self, file_name: str, path: Optional[str] = None): path, file_name = self._get_namespace_path_elements(file_name, path) file_extension = os.path.splitext(file_name)[1] return reduce( - lambda args, loader: loader(args[0], args[1]), - self.loaders.get(file_extension, []), - (path, file_name) + lambda args, loader: loader(args[0], args[1]), self.loaders.get(file_extension, []), (path, file_name) ) def _no_preload(self, file_name: str, path: Optional[str] = None): diff --git a/emmett/testing.py b/emmett/testing.py new file mode 100644 index 00000000..5fc9746c --- /dev/null +++ b/emmett/testing.py @@ -0,0 +1,35 @@ +from emmett_core.protocols.rsgi.test_client.client import ( + ClientContext as _ClientContext, + ClientHTTPHandlerMixin, + EmmettTestClient as _EmmettTestClient, +) + +from .ctx import current +from .rsgi.handlers import HTTPHandler +from .wrappers.response import Response + + +class ClientContextResponse(Response): + def __init__(self, original_response: Response): + super().__init__() + self.status = original_response.status + self.headers._data.update(original_response.headers._data) + self.cookies.update(original_response.cookies.copy()) + self.__dict__.update(original_response.__dict__) + + +class ClientContext(_ClientContext): + _response_wrap_cls = ClientContextResponse + + def __init__(self, ctx): + super().__init__(ctx) + self.T = current.T + + +class ClientHTTPHandler(ClientHTTPHandlerMixin, HTTPHandler): + _client_ctx_cls = ClientContext + + +class EmmettTestClient(_EmmettTestClient): + _current = current + _handler_cls = ClientHTTPHandler diff --git a/emmett/testing/__init__.py b/emmett/testing/__init__.py deleted file mode 100644 index 7dacddbf..00000000 --- a/emmett/testing/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .client import EmmettTestClient diff --git a/emmett/testing/client.py b/emmett/testing/client.py deleted file mode 100644 index cd5dad65..00000000 --- a/emmett/testing/client.py +++ /dev/null @@ -1,365 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.testing.client - --------------------- - - Provides base classes for testing suite. - - :copyright: 2014 Giovanni Barillari - - Several parts of this code comes from Werkzeug. - :copyright: (c) 2015 by Armin Ronacher. - - :license: BSD-3-Clause -""" - -import asyncio -import copy -import types - -from io import BytesIO - -from ..asgi.handlers import HTTPHandler -from ..asgi.wrappers import Request -from ..ctx import RequestContext, current -from ..http import HTTP, HTTPResponse -from ..wrappers.response import Response -from ..utils import cachedprop -from .env import ScopeBuilder -from .helpers import TestCookieJar, Headers -from .urls import get_host, url_parse, url_unparse - - -class ClientContextResponse(Response): - def __init__(self, original_response: Response): - super().__init__() - self.status = original_response.status - self.headers._data.update(original_response.headers._data) - self.cookies.update(original_response.cookies.copy()) - self.__dict__.update(original_response.__dict__) - - -class ClientContext: - def __init__(self, ctx): - self.request = Request(ctx.request._scope, None, None) - self.response = ClientContextResponse(ctx.response) - self.session = copy.deepcopy(ctx.session) - self.T = current.T - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, tb): - pass - - -class ClientHTTPHandler(HTTPHandler): - async def dynamic_handler(self, scope, receive, send): - request = Request( - scope, - receive, - send, - max_content_length=self.app.config.request_max_content_length, - body_timeout=self.app.config.request_body_timeout - ) - response = Response() - ctx = RequestContext(self.app, request, response) - ctx_token = current._init_(ctx) - try: - http = await self.router.dispatch(request, response) - except HTTPResponse as http_exception: - http = http_exception - #: render error with handlers if in app - error_handler = self.app.error_handlers.get(http.status_code) - if error_handler: - http = HTTP( - http.status_code, - await error_handler(), - headers=response.headers, - cookies=response.cookies - ) - except Exception: - self.app.log.exception('Application exception:') - http = HTTP( - 500, - await self.error_handler(), - headers=response.headers - ) - finally: - scope['emt.ctx'] = ClientContext(ctx) - current._close_(ctx_token) - return http - - -class ClientResponse(object): - def __init__(self, ctx, raw, status, headers): - self.context = ctx - self.raw = raw - self.status = status - self.headers = headers - self._close = lambda: None - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, tb): - pass - - @cachedprop - def data(self): - self._ensure_sequence() - rv = b''.join(self.iter_encoded()) - return rv.decode('utf8') - - @property - def is_sequence(self): - return isinstance(self.raw, (tuple, list)) - - def _ensure_sequence(self, mutable=False): - if self.is_sequence: - # if we need a mutable object, we ensure it's a list. - if mutable and not isinstance(self.response, list): - self.response = list(self.response) - return - self.make_sequence() - - def make_sequence(self): - if not self.is_sequence: - close = getattr(self.raw, 'close', None) - self.raw = list(self.iter_encoded()) - self._close = close - - def close(self): - if hasattr(self.raw, 'close'): - self.response.close() - self._close() - - def iter_encoded(self, charset='utf8'): - for item in self.raw: - if isinstance(item, str): - yield item.encode(charset) - else: - yield item - - -class EmmettTestClient(object): - """This class allows to send requests to a wrapped application.""" - - def __init__( - self, application, response_wrapper=ClientResponse, use_cookies=True, - allow_subdomain_redirects=False - ): - self.application = application - self.response_wrapper = response_wrapper - if use_cookies: - self.cookie_jar = TestCookieJar() - else: - self.cookie_jar = None - self.allow_subdomain_redirects = allow_subdomain_redirects - - def run_asgi_app(self, scope, body): - if self.cookie_jar is not None: - self.cookie_jar.inject_asgi(scope) - rv = run_asgi_app(self.application, scope, body) - if self.cookie_jar is not None: - self.cookie_jar.extract_asgi(scope, Headers(rv['headers'])) - return rv - - def resolve_redirect(self, response, new_loc, scope, headers): - """Resolves a single redirect and triggers the request again - directly on this redirect client. - """ - scheme, netloc, script_root, qs, anchor = url_parse(new_loc) - base_url = url_unparse((scheme, netloc, '', '', '')).rstrip('/') + '/' - - cur_name = netloc.split(':', 1)[0].split('.') - real_name = get_host(scope, headers).rsplit(':', 1)[0].split('.') - - if len(cur_name) == 1 and not cur_name[0]: - allowed = True - else: - if self.allow_subdomain_redirects: - allowed = cur_name[-len(real_name):] == real_name - else: - allowed = cur_name == real_name - - if not allowed: - raise RuntimeError('%r does not support redirect to ' - 'external targets' % self.__class__) - - status_code = response['status'] - if status_code == 307: - method = scope['method'] - else: - method = 'GET' - - # For redirect handling we temporarily disable the response - # wrapper. This is not threadsafe but not a real concern - # since the test client must not be shared anyways. - old_response_wrapper = self.response_wrapper - self.response_wrapper = None - try: - return self.open( - path=script_root, base_url=base_url, query_string=qs, - method=method, as_tuple=True) - finally: - self.response_wrapper = old_response_wrapper - - def open(self, *args, **kwargs): - as_tuple = kwargs.pop('as_tuple', False) - follow_redirects = kwargs.pop('follow_redirects', False) - scope, body = None, b'' - if not kwargs and len(args) == 1: - if isinstance(args[0], ScopeBuilder): - scope, body = args[0].get_data() - if scope is None: - builder = ScopeBuilder(*args, **kwargs) - try: - scope, body = builder.get_data() - finally: - builder.close() - - response = self.run_asgi_app(scope, body) - - # handle redirects - redirect_chain = [] - while 1: - status_code = response['status'] - if ( - status_code not in (301, 302, 303, 305, 307) or - not follow_redirects - ): - break - headers = Headers(response['headers']) - new_location = headers['location'] - if new_location.startswith('/'): - new_location = ( - scope['scheme'] + "://" + - scope['server'][0] + new_location) - new_redirect_entry = (new_location, status_code) - if new_redirect_entry in redirect_chain: - raise Exception('loop detected') - redirect_chain.append(new_redirect_entry) - scope, response = self.resolve_redirect( - response, new_location, scope, headers) - - if self.response_wrapper is not None: - response = self.response_wrapper( - scope['emt.ctx'], response['body'], response['status'], - Headers(response['headers'])) - if as_tuple: - return scope, response - return response - - def get(self, *args, **kw): - """Like open but method is enforced to GET.""" - kw['method'] = 'GET' - return self.open(*args, **kw) - - def patch(self, *args, **kw): - """Like open but method is enforced to PATCH.""" - kw['method'] = 'PATCH' - return self.open(*args, **kw) - - def post(self, *args, **kw): - """Like open but method is enforced to POST.""" - kw['method'] = 'POST' - return self.open(*args, **kw) - - def head(self, *args, **kw): - """Like open but method is enforced to HEAD.""" - kw['method'] = 'HEAD' - return self.open(*args, **kw) - - def put(self, *args, **kw): - """Like open but method is enforced to PUT.""" - kw['method'] = 'PUT' - return self.open(*args, **kw) - - def delete(self, *args, **kw): - """Like open but method is enforced to DELETE.""" - kw['method'] = 'DELETE' - return self.open(*args, **kw) - - def options(self, *args, **kw): - """Like open but method is enforced to OPTIONS.""" - kw['method'] = 'OPTIONS' - return self.open(*args, **kw) - - def trace(self, *args, **kw): - """Like open but method is enforced to TRACE.""" - kw['method'] = 'TRACE' - return self.open(*args, **kw) - - def __repr__(self): - return '<%s %r>' % ( - self.__class__.__name__, - self.application - ) - - -def run_asgi_app(app, scope, body=b''): - try: - loop = asyncio.get_event_loop() - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - request_complete = False - response_started = False - response_complete = False - raw = {"body": BytesIO()} - - async def receive(): - nonlocal request_complete - - if request_complete: - while not response_complete: - await asyncio.sleep(0.0001) - return {"type": "http.disconnect"} - - if isinstance(body, str): - body_bytes = body.encode("utf-8") - elif body is None: - body_bytes = b"" - elif isinstance(body, types.GeneratorType): - try: - chunk = body.send(None) - if isinstance(chunk, str): - chunk = chunk.encode("utf-8") - return { - "type": "http.request", "body": chunk, "more_body": True} - except StopIteration: - request_complete = True - return {"type": "http.request", "body": b""} - else: - body_bytes = body - - request_complete = True - return {"type": "http.request", "body": body_bytes} - - async def send(message): - nonlocal response_started, response_complete - - if message["type"] == "http.response.start": - raw["version"] = 11 - raw["status"] = message["status"] - raw["headers"] = [ - (key.decode(), value.decode()) - for key, value in message["headers"] - ] - raw["preload_content"] = False - response_started = True - elif message["type"] == "http.response.body": - body = message.get("body", b"") - more_body = message.get("more_body", False) - if scope['method'] != "HEAD": - raw["body"].write(body) - if not more_body: - raw["body"].seek(0) - response_complete = True - - handler = ClientHTTPHandler(app) - loop.run_until_complete(handler(scope, receive, send)) - - return raw diff --git a/emmett/testing/env.py b/emmett/testing/env.py deleted file mode 100644 index ab12ddee..00000000 --- a/emmett/testing/env.py +++ /dev/null @@ -1,313 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.testing.env - ------------------ - - Provides environment class for testing suite. - - :copyright: 2014 Giovanni Barillari - - Several parts of this code comes from Werkzeug. - :copyright: (c) 2015 by Armin Ronacher. - - :license: BSD-3-Clause -""" - -import cgi -import sys - -from io import BytesIO - -from ..datastructures import sdict -from .helpers import Headers, filesdict, stream_encode_multipart -from .urls import iri_to_uri, url_fix, url_parse, url_unparse, url_encode - - -class ScopeBuilder(object): - #: This class creates an ASGI environment for testing purposes. - - #: the server protocol to use. defaults to HTTP/1.1 - server_protocol = '1.1' - - def __init__( - self, path='/', base_url=None, query_string=None, method='GET', - input_stream=None, content_type=None, content_length=None, - errors_stream=None, headers=None, data=None, environ_base=None, - environ_overrides=None, charset='utf-8' - ): - if query_string is None and '?' in path: - path, query_string = path.split('?', 1) - self.charset = charset - self.path = iri_to_uri(path) - if base_url is not None: - base_url = url_fix(iri_to_uri(base_url, charset), charset) - self.base_url = base_url - if isinstance(query_string, (bytes, str)): - self.query_string = query_string - else: - if query_string is None: - query_string = sdict() - elif not isinstance(query_string, dict): - query_string = self._parse_querystring(query_string) - self.args = query_string - self.method = method - if headers is None: - headers = Headers() - elif not isinstance(headers, Headers): - headers = Headers(headers) - self.headers = headers - if content_type is not None: - self.content_type = content_type - if errors_stream is None: - errors_stream = sys.stderr - self.errors_stream = errors_stream - self.environ_base = environ_base - self.environ_overrides = environ_overrides - self.input_stream = input_stream - self.content_length = content_length - self.closed = False - - if data: - if input_stream is not None: - raise TypeError('can\'t provide input stream and data') - if isinstance(data, str): - data = data.encode(self.charset) - if isinstance(data, bytes): - self.input_stream = BytesIO(data) - if self.content_length is None: - self.content_length = len(data) - else: - for key, values in data.items(): - if not isinstance(values, list): - values = [values] - for v in values: - if isinstance(v, (tuple)) or hasattr(v, 'read'): - self._add_file_from_data(key, v) - else: - if self.form[key] is None: - self.form[key] = [] - self.form[key].append(v) - - @staticmethod - def _parse_querystring(query_string): - dget = cgi.parse_qs(query_string, keep_blank_values=1) - params = sdict(dget) - for key, value in params.items(): - if isinstance(value, list) and len(value) == 1: - params[key] = value[0] - return params - - def _add_file_from_data(self, key, value): - if isinstance(value, tuple): - self.files.add_file(key, *value) - else: - self.files.add_file(key, value) - - def _get_base_url(self): - return url_unparse( - (self.url_scheme, self.host, self.script_root, '', '') - ).rstrip('/') + '/' - - def _set_base_url(self, value): - if value is None: - scheme = 'http' - netloc = 'localhost' - script_root = '' - else: - scheme, netloc, script_root, qs, anchor = url_parse(value) - if qs or anchor: - raise ValueError( - 'base url must not contain a query string or fragment') - self.script_root = script_root.rstrip('/') - self.host = netloc - self.url_scheme = scheme - - base_url = property(_get_base_url, _set_base_url) - del _get_base_url, _set_base_url - - def _get_content_type(self): - ct = self.headers.get('Content-Type') - if ct is None and not self._input_stream: - if self._files: - return 'multipart/form-data' - elif self._form: - return 'application/x-www-form-urlencoded' - return None - return ct - - def _set_content_type(self, value): - if value is None: - self.headers.pop('Content-Type', None) - else: - self.headers['Content-Type'] = value - - content_type = property(_get_content_type, _set_content_type) - del _get_content_type, _set_content_type - - def _get_content_length(self): - return self.headers.get('Content-Length', type=int) - - def _set_content_length(self, value): - if value is None: - self.headers.pop('Content-Length', None) - else: - self.headers['Content-Length'] = str(value) - - content_length = property(_get_content_length, _set_content_length) - del _get_content_length, _set_content_length - - def form_property(name, storage): - key = '_' + name - - def getter(self): - if self._input_stream is not None: - raise AttributeError('an input stream is defined') - rv = getattr(self, key) - if rv is None: - rv = storage() - setattr(self, key, rv) - - return rv - - def setter(self, value): - self._input_stream = None - setattr(self, key, value) - return property(getter, setter) - - form = form_property('form', sdict) - files = form_property('files', filesdict) - del form_property - - def _get_input_stream(self): - return self._input_stream - - def _set_input_stream(self, value): - self._input_stream = value - self._form = self._files = None - - input_stream = property(_get_input_stream, _set_input_stream, doc=''' - An optional input stream. If you set this it will clear - :attr:`form` and :attr:`files`.''') - del _get_input_stream, _set_input_stream - - def _get_query_string(self): - if self._query_string is None: - if self._args is not None: - return url_encode(self._args, charset=self.charset) - return '' - return self._query_string - - def _set_query_string(self, value): - self._query_string = value - self._args = None - - query_string = property(_get_query_string, _set_query_string) - del _get_query_string, _set_query_string - - def _get_args(self): - if self._query_string is not None: - raise AttributeError('a query string is defined') - if self._args is None: - self._args = sdict() - return self._args - - def _set_args(self, value): - self._query_string = None - self._args = value - - args = property(_get_args, _set_args) - del _get_args, _set_args - - @property - def server_name(self): - """The server name (read-only, use :attr:`host` to set)""" - return self.host.split(':', 1)[0] - - @property - def server_port(self): - """The server port as integer (read-only, use :attr:`host` to set)""" - pieces = self.host.split(':', 1) - if len(pieces) == 2 and pieces[1].isdigit(): - return int(pieces[1]) - elif self.url_scheme == 'https': - return 443 - return 80 - - def __del__(self): - try: - self.close() - except Exception: - pass - - def close(self): - """Closes all files. If you put real :class:`file` objects into the - :attr:`files` dict you can call this method to automatically close - them all in one go. - """ - if self.closed: - return - try: - files = self.files.values() - except AttributeError: - files = () - for f in files: - try: - f.close() - except Exception: - pass - self.closed = True - - def get_data(self): - input_stream = self.input_stream - content_length = self.content_length - content_type = self.content_type - - if input_stream is not None: - start_pos = input_stream.tell() - input_stream.seek(0, 2) - end_pos = input_stream.tell() - input_stream.seek(start_pos) - content_length = end_pos - start_pos - elif content_type == 'multipart/form-data': - values = sdict() - for d in [self.files, self.form]: - for key, val in d.items(): - values[key] = val - input_stream, content_length, boundary = \ - stream_encode_multipart(values, charset=self.charset) - content_type += '; boundary="%s"' % boundary - elif content_type == 'application/x-www-form-urlencoded': - values = url_encode(self.form, charset=self.charset) - values = values.encode('ascii') - content_length = len(values) - input_stream = BytesIO(values) - else: - input_stream = BytesIO() - - result = {'headers': [(b'host', self.host.encode('utf-8'))]} - if self.content_type: - result['headers'].extend([ - (b'content-type', content_type.encode('utf-8')), - (b'content-length', str(content_length or 0).encode('utf-8'))]) - if self.environ_base: - result.update(self.environ_base) - - result.update({ - 'type': 'http', - 'http_version': self.server_protocol, - 'server': (self.server_name, self.server_port), - 'client': (self.server_name, self.server_port), - 'method': self.method, - 'scheme': self.url_scheme, - 'root_path': self.script_root, - 'path': self.path, - 'query_string': self.query_string.encode('utf-8'), - 'emt.path': self.path, - 'emt.input': None - }) - for key, value in self.headers.to_asgi_list(): - result['headers'].append( - (key.encode('utf-8'), value.encode('utf-8'))) - if self.environ_overrides: - result.update(self.environ_overrides) - return result, input_stream.getvalue() diff --git a/emmett/testing/helpers.py b/emmett/testing/helpers.py deleted file mode 100644 index 1d35e8ec..00000000 --- a/emmett/testing/helpers.py +++ /dev/null @@ -1,377 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.testing.helpers - ---------------------- - - Provides helpers for testing suite. - - :copyright: 2014 Giovanni Barillari - - Several parts of this code comes from Werkzeug. - :copyright: (c) 2015 by Armin Ronacher. - - :license: BSD-3-Clause -""" - -import codecs -import mimetypes -import re -import sys - -from http.cookiejar import CookieJar -from io import BytesIO -from urllib.request import Request as U2Request - -# TODO: check conversions -from .._shortcuts import to_bytes, to_unicode -from ..datastructures import sdict -from .urls import get_host, uri_to_iri, url_quote - -_quoted_string_re = r'"[^"\\]*(?:\\.[^"\\]*)*"' -_option_header_piece_re = re.compile( - r';\s*(%s|[^\s;,=]+)\s*(?:=\s*(%s|[^;,]+)?)?\s*' % - (_quoted_string_re, _quoted_string_re) -) -_option_header_start_mime_type = re.compile(r',\s*([^;,\s]+)([;,]\s*.+)?') - - -class _TestCookieHeaders(object): - def __init__(self, headers): - self.headers = headers - - def getheaders(self, name): - headers = [] - name = name.lower() - for k, v in self.headers: - if k.lower() == name: - headers.append(v) - return headers - - def get_all(self, name, default=None): - rv = [] - for k, v in self.headers: - if k.lower() == name.lower(): - rv.append(v) - return rv or default or [] - - -class _TestCookieResponse(object): - def __init__(self, headers): - self.headers = _TestCookieHeaders(headers) - - def info(self): - return self.headers - - -class TestCookieJar(CookieJar): - def inject_asgi(self, scope): - cvals = [] - for cookie in self: - cvals.append('%s=%s' % (cookie.name, cookie.value)) - if cvals: - scope['headers'].append( - (b'cookie', '; '.join(cvals).encode('utf-8'))) - - def extract_asgi(self, scope, headers): - self.extract_cookies( - _TestCookieResponse(headers), - U2Request(get_current_url(scope)), - ) - - -class Headers(dict): - def __init__(self, headers=[]): - super(Headers, self).__init__() - self._list = [] - for header in headers: - key, value = header[0].lower(), header[1] - self._list.append((key, value)) - self[key] = value - - def __getitem__(self, name): - return super(Headers, self).__getitem__(name.lower()) - - def __setitem__(self, name, value): - return super(Headers, self).__setitem__(name.lower(), value) - - def __iter__(self): - return iter(self._list) - - def get(self, name, d=None, type=None): - rv = super(Headers, self).get(name.lower) - if rv is None: - return d - if type is None: - return rv - try: - rv = type(rv) - except ValueError: - pass - return rv - - def to_asgi_list(self): - return list(self) - - -class _FileHandler(object): - def __init__(self, stream=None, filename=None, name=None, - content_type=None, content_length=None, - headers=None): - self.name = name - self.stream = stream or BytesIO() - - # if no filename is provided we can attempt to get the filename - # from the stream object passed. There we have to be careful to - # skip things like , etc. Python marks these - # special filenames with angular brackets. - if filename is None: - filename = getattr(stream, 'name', None) - #s = make_literal_wrapper(filename) - if filename and filename[0] == '<' and filename[-1] == '>': - filename = None - - # On Python 3 we want to make sure the filename is always unicode. - # This might not be if the name attribute is bytes due to the - # file being opened from the bytes API. - if isinstance(filename, bytes): - filename = filename.decode(get_filesystem_encoding(), - 'replace') - - self.filename = filename - if headers is None: - headers = Headers() - self.headers = headers - if content_type is not None: - headers['Content-Type'] = content_type - if content_length is not None: - headers['Content-Length'] = str(content_length) - - def _parse_content_type(self): - if not hasattr(self, '_parsed_content_type'): - self._parsed_content_type = \ - parse_options_header(self.content_type) - - @property - def content_type(self): - return self.headers.get('content-type') - - @property - def content_length(self): - return int(self.headers.get('content-length') or 0) - - @property - def mimetype(self): - self._parse_content_type() - return self._parsed_content_type[0].lower() - - @property - def mimetype_params(self): - self._parse_content_type() - return self._parsed_content_type[1] - - def save(self, dst, buffer_size=16384): - """Save the file to a destination path or file object. If the - destination is a file object you have to close it yourself after the - call. The buffer size is the number of bytes held in memory during - the copy process. It defaults to 16KB. - For secure file saving also have a look at :func:`secure_filename`. - :param dst: a filename or open file object the uploaded file - is saved to. - :param buffer_size: the size of the buffer. This works the same as - the `length` parameter of - :func:`shutil.copyfileobj`. - """ - from shutil import copyfileobj - close_dst = False - if isinstance(dst, str): - dst = open(dst, 'wb') - close_dst = True - try: - copyfileobj(self.stream, dst, buffer_size) - finally: - if close_dst: - dst.close() - - def close(self): - """Close the underlying file if possible.""" - try: - self.stream.close() - except Exception: - pass - - def __nonzero__(self): - return bool(self.filename) - __bool__ = __nonzero__ - - def __getattr__(self, name): - return getattr(self.stream, name) - - def __iter__(self): - return iter(self.readline, '') - - def __repr__(self): - return '<%s: %r (%r)>' % ( - self.__class__.__name__, - self.filename, - self.content_type - ) - - -class filesdict(sdict): - def add_file(self, name, file, filename=None, content_type=None): - if isinstance(file, _FileHandler): - value = file - else: - if isinstance(file, str): - if filename is None: - filename = file - file = open(file, 'rb') - if filename and content_type is None: - content_type = mimetypes.guess_type(filename)[0] or \ - 'application/octet-stream' - value = _FileHandler(file, filename, name, content_type) - self[name] = value - - -def get_current_url( - scope, root_only=False, strip_querystring=False, host_only=False, - trusted_hosts=None -): - headers = Headers(scope['headers']) - components = [ - scope['scheme'], '://', get_host(scope, headers, trusted_hosts)] - if host_only: - return uri_to_iri(''.join(components) + '/') - components.extend([ - url_quote(to_bytes(scope['root_path'])).rstrip('/'), '/']) - if not root_only: - components.append(url_quote(to_bytes(scope['path']).lstrip(b'/'))) - if not strip_querystring: - if scope['query_string']: - components.extend(['?', to_unicode(scope['query_string'])]) - return uri_to_iri(''.join(components)) - - -def _is_ascii_encoding(encoding): - if encoding is None: - return False - try: - return codecs.lookup(encoding).name == 'ascii' - except LookupError: - return False - - -def get_filesystem_encoding(): - rv = sys.getfilesystemencoding() - if (sys.platform.startswith('linux') or 'bsd' in sys.platform) and not rv \ - or _is_ascii_encoding(rv): - return 'utf-8' - return rv - - -def unquote_header_value(value, is_filename=False): - if value and value[0] == value[-1] == '"': - # this is not the real unquoting, but fixing this so that the - # RFC is met will result in bugs with internet explorer and - # probably some other browsers as well. IE for example is - # uploading files with "C:\foo\bar.txt" as filename - value = value[1:-1] - - # if this is a filename and the starting characters look like - # a UNC path, then just return the value without quotes. Using the - # replace sequence below on a UNC path has the effect of turning - # the leading double slash into a single slash and then - # _fix_ie_filename() doesn't work correctly. See #458. - if not is_filename or value[:2] != '\\\\': - return value.replace('\\\\', '\\').replace('\\"', '"') - return value - - -def parse_options_header(value, multiple=False): - if not value: - return '', {} - result = [] - value = "," + value.replace("\n", ",") - while value: - match = _option_header_start_mime_type.match(value) - if not match: - break - result.append(match.group(1)) # mimetype - options = {} - # Parse options - rest = match.group(2) - while rest: - optmatch = _option_header_piece_re.match(rest) - if not optmatch: - break - option, option_value = optmatch.groups() - option = unquote_header_value(option) - if option_value is not None: - option_value = unquote_header_value( - option_value, - option == 'filename') - options[option] = option_value - rest = rest[optmatch.end():] - result.append(options) - if multiple is False: - return tuple(result) - value = rest - - return tuple(result) - - -def stream_encode_multipart(values, threshold=1024 * 500, boundary=None, - charset='utf-8'): - """Encode a dict of values (either strings or file descriptors or - :class:`FileStorage` objects.) into a multipart encoded string stored - in a file descriptor. - """ - if boundary is None: - from time import time - from random import random - boundary = '---------------EmmettFormPart_%s%s' % (time(), random()) - _closure = [BytesIO(), 0, False] - - write_binary = _closure[0].write - - def write(string): - write_binary(string.encode(charset)) - - for key, values in values.items(): - if not isinstance(values, list): - values = [values] - for value in values: - write('--%s\r\nContent-Disposition: form-data; name="%s"' % - (boundary, key)) - reader = getattr(value, 'read', None) - if reader is not None: - filename = getattr(value, 'filename', - getattr(value, 'name', None)) - content_type = getattr(value, 'content_type', None) - if content_type is None: - content_type = filename and \ - mimetypes.guess_type(filename)[0] or \ - 'application/octet-stream' - if filename is not None: - write('; filename="%s"\r\n' % filename) - else: - write('\r\n') - write('Content-Type: %s\r\n\r\n' % content_type) - while 1: - chunk = reader(16384) - if not chunk: - break - write_binary(chunk) - else: - if not isinstance(value, str): - value = str(value) - else: - value = to_bytes(value, charset) - write('\r\n\r\n') - write_binary(value) - write('\r\n') - write('--%s--\r\n' % boundary) - - length = int(_closure[0].tell()) - _closure[0].seek(0) - return _closure[0], length, boundary diff --git a/emmett/testing/urls.py b/emmett/testing/urls.py deleted file mode 100644 index 4687070e..00000000 --- a/emmett/testing/urls.py +++ /dev/null @@ -1,535 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.testing.urls - ------------------- - - Provides url helpers for testing suite. - - :copyright: 2014 Giovanni Barillari - - Several parts of this code comes from Werkzeug. - :copyright: (c) 2015 by Armin Ronacher. - - :license: BSD-3-Clause -""" - -import os -import re - -from collections import namedtuple - -# TODO: check conversions -from .._shortcuts import to_unicode -from ..datastructures import sdict - - -_always_safe = ( - b'abcdefghijklmnopqrstuvwxyz' - b'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.-+') -_hexdigits = '0123456789ABCDEFabcdef' -_hextobyte = dict( - ((a + b).encode(), int(a + b, 16)) - for a in _hexdigits for b in _hexdigits -) - -_scheme_re = re.compile(r'^[a-zA-Z0-9+-.]+$') - -_URLTuple = namedtuple( - '_URLTuple', - ['scheme', 'netloc', 'path', 'query', 'fragment']) - - -class BaseURL(_URLTuple): - __slots__ = () - - def replace(self, **kwargs): - return self._replace(**kwargs) - - @property - def host(self): - return self._split_host()[0] - - @property - def ascii_host(self): - rv = self.host - if rv is not None and isinstance(rv, str): - try: - rv = _encode_idna(rv) - except UnicodeError: - rv = rv.encode('ascii', 'ignore') - return to_unicode(rv, 'ascii', 'ignore') - - @property - def port(self): - try: - rv = int(to_unicode(self._split_host()[1])) - if 0 <= rv <= 65535: - return rv - except (ValueError, TypeError): - pass - - @property - def auth(self): - return self._split_netloc()[0] - - @property - def username(self): - rv = self._split_auth()[0] - if rv is not None: - return url_unquote(rv) - - @property - def raw_username(self): - return self._split_auth()[0] - - @property - def password(self): - rv = self._split_auth()[1] - if rv is not None: - return url_unquote(rv) - - @property - def raw_password(self): - return self._split_auth()[1] - - def to_url(self): - return url_unparse(self) - - def decode_netloc(self): - rv = _decode_idna(self.host or '') - - if ':' in rv: - rv = '[%s]' % rv - port = self.port - if port is not None: - rv = '%s:%d' % (rv, port) - auth = ':'.join(filter(None, [ - url_unquote( - self.raw_username or '', errors='strict', unsafe='/:%@'), - url_unquote( - self.raw_password or '', errors='strict', unsafe='/:%@'), - ])) - if auth: - rv = '%s@%s' % (auth, rv) - return rv - - def to_uri_tuple(self): - return url_parse(iri_to_uri(self).encode('ascii')) - - def to_iri_tuple(self): - return url_parse(uri_to_iri(self)) - - def get_file_location(self, pathformat=None): - if self.scheme != 'file': - return None, None - - path = url_unquote(self.path) - host = self.netloc or None - - if pathformat is None: - if os.name == 'nt': - pathformat = 'windows' - else: - pathformat = 'posix' - - if pathformat == 'windows': - if path[:1] == '/' and path[1:2].isalpha() and path[2:3] in '|:': - path = path[1:2] + ':' + path[3:] - windows_share = path[:3] in ('\\' * 3, '/' * 3) - import ntpath - path = ntpath.normpath(path) - # Windows shared drives are represented as ``\\host\\directory``. - # That results in a URL like ``file://///host/directory``, and a - # path like ``///host/directory``. We need to special-case this - # because the path contains the hostname. - if windows_share and host is None: - parts = path.lstrip('\\').split('\\', 1) - if len(parts) == 2: - host, path = parts - else: - host = parts[0] - path = '' - elif pathformat == 'posix': - import posixpath - path = posixpath.normpath(path) - else: - raise TypeError('Invalid path format %s' % repr(pathformat)) - - if host in ('127.0.0.1', '::1', 'localhost'): - host = None - - return host, path - - def _split_netloc(self): - if self._at in self.netloc: - return self.netloc.split(self._at, 1) - return None, self.netloc - - def _split_auth(self): - auth = self._split_netloc()[0] - if not auth: - return None, None - if self._colon not in auth: - return auth, None - return auth.split(self._colon, 1) - - def _split_host(self): - rv = self._split_netloc()[1] - if not rv: - return None, None - - if not rv.startswith(self._lbracket): - if self._colon in rv: - return rv.split(self._colon, 1) - return rv, None - - idx = rv.find(self._rbracket) - if idx < 0: - return rv, None - - host = rv[1:idx] - rest = rv[idx + 1:] - if rest.startswith(self._colon): - return host, rest[1:] - return host, None - - -class URL(BaseURL): - __slots__ = () - _at = '@' - _colon = ':' - _lbracket = '[' - _rbracket = ']' - - def __str__(self): - return self.to_url() - - def encode_netloc(self): - rv = self.ascii_host or '' - if ':' in rv: - rv = '[%s]' % rv - port = self.port - if port is not None: - rv = '%s:%d' % (rv, port) - auth = ':'.join(filter(None, [ - url_quote(self.raw_username or '', 'utf-8', 'strict', '/:%'), - url_quote(self.raw_password or '', 'utf-8', 'strict', '/:%'), - ])) - if auth: - rv = '%s@%s' % (auth, rv) - return to_unicode(rv) - - def encode(self, charset='utf-8', errors='replace'): - return BytesURL( - self.scheme.encode('ascii'), - self.encode_netloc(), - self.path.encode(charset, errors), - self.query.encode(charset, errors), - self.fragment.encode(charset, errors) - ) - - -class BytesURL(BaseURL): - __slots__ = () - _at = b'@' - _colon = b':' - _lbracket = b'[' - _rbracket = b']' - - def __str__(self): - return self.to_url().decode('utf-8', 'replace') - - def encode_netloc(self): - return self.netloc - - def decode(self, charset='utf-8', errors='replace'): - return URL( - self.scheme.decode('ascii'), - self.decode_netloc(), - self.path.decode(charset, errors), - self.query.decode(charset, errors), - self.fragment.decode(charset, errors) - ) - - -def url_quote(string, charset='utf-8', errors='strict', safe='/:', unsafe=''): - if not isinstance(string, (str, bytes, bytearray)): - string = str(string) - if isinstance(string, str): - string = string.encode(charset, errors) - if isinstance(safe, str): - safe = safe.encode(charset, errors) - if isinstance(unsafe, str): - unsafe = unsafe.encode(charset, errors) - safe = frozenset(bytearray(safe) + _always_safe) - \ - frozenset(bytearray(unsafe)) - rv = bytearray() - for char in bytearray(string): - if char in safe: - rv.append(char) - else: - rv.extend(('%%%02X' % char).encode('ascii')) - return to_unicode(bytes(rv)) - - -def _unquote_to_bytes(string, unsafe=''): - if isinstance(string, str): - string = string.encode('utf-8') - if isinstance(unsafe, str): - unsafe = unsafe.encode('utf-8') - unsafe = frozenset(bytearray(unsafe)) - bits = iter(string.split(b'%')) - result = bytearray(next(bits, b'')) - for item in bits: - try: - char = _hextobyte[item[:2]] - if char in unsafe: - raise KeyError() - result.append(char) - result.extend(item[2:]) - except KeyError: - result.extend(b'%') - result.extend(item) - return bytes(result) - - -def url_unquote(string, charset='utf-8', errors='replace', unsafe=''): - rv = _unquote_to_bytes(string, unsafe) - if charset is not None: - rv = rv.decode(charset, errors) - return rv - - -def url_quote_plus(string, charset='utf-8', errors='strict', safe=''): - return url_quote( - string, charset, errors, safe + ' ', '+').replace(' ', '+') - - -def url_unquote_plus(s, charset='utf-8', errors='replace'): - if isinstance(s, str): - s = s.replace(u'+', u' ') - else: - s = s.replace(b'+', b' ') - return url_unquote(s, charset, errors) - - -def url_parse(url, scheme=None, allow_fragments=True): - #s = make_literal_wrapper(url) - is_text_based = isinstance(url, str) - if scheme is None: - scheme = '' - netloc = query = fragment = '' - i = url.find(':') - if i > 0 and _scheme_re.match(to_unicode(url[:i], errors='replace')): - # make sure "iri" is not actually a port number (in which case - # "scheme" is really part of the path) - rest = url[i + 1:] - if not rest or any(c not in '0123456789' for c in rest): - # not a port number - scheme, url = url[:i].lower(), rest - - if url[:2] == '//': - delim = len(url) - for c in '/?#': - wdelim = url.find(c, 2) - if wdelim >= 0: - delim = min(delim, wdelim) - netloc, url = url[2:delim], url[delim:] - if ('[' in netloc and ']' not in netloc) or \ - (']' in netloc and '[' not in netloc): - raise ValueError('Invalid IPv6 URL') - - if allow_fragments and '#' in url: - url, fragment = url.split('#', 1) - if '?' in url: - url, query = url.split('?', 1) - - result_type = is_text_based and URL or BytesURL - return result_type(scheme, netloc, url, query, fragment) - - -def url_unparse(components): - scheme, netloc, path, query, fragment = components - # normalize_string_tuple(components) - #s = make_literal_wrapper(scheme) - url = '' - - # We generally treat file:///x and file:/x the same which is also - # what browsers seem to do. This also allows us to ignore a schema - # register for netloc utilization or having to differenciate between - # empty and missing netloc. - if netloc or (scheme and path.startswith('/')): - if path and path[:1] != '/': - path = '/' + path - url = '//' + (netloc or '') + path - elif path: - url += path - if scheme: - url = scheme + ':' + url - if query: - url = url + '?' + query - if fragment: - url = url + '#' + fragment - return url - - -def _url_encode_impl(obj, charset, encode_keys, sort, key): - iterable = sdict() - for key, values in obj.items(): - if not isinstance(values, list): - values = [values] - iterable[key] = values - if sort: - iterable = sorted(iterable, key=key) - for key, values in iterable.items(): - for value in values: - if value is None: - continue - if not isinstance(key, bytes): - key = str(key).encode(charset) - if not isinstance(value, bytes): - value = str(value).encode(charset) - yield url_quote_plus(key) + '=' + url_quote_plus(value) - - -def url_encode(obj, charset='utf-8', encode_keys=False, sort=False, key=None, - separator=b'&'): - separator = to_unicode(separator, 'ascii') - return separator.join( - _url_encode_impl(obj, charset, encode_keys, sort, key)) - - -def uri_to_iri(uri, charset='utf-8', errors='replace'): - if isinstance(uri, tuple): - uri = url_unparse(uri) - uri = url_parse(to_unicode(uri, charset)) - path = url_unquote(uri.path, charset, errors, '%/;?') - query = url_unquote(uri.query, charset, errors, '%;/?:@&=+,$#') - fragment = url_unquote(uri.fragment, charset, errors, '%;/?:@&=+,$#') - return url_unparse((uri.scheme, uri.decode_netloc(), - path, query, fragment)) - - -def iri_to_uri(iri, charset='utf-8', errors='strict'): - if isinstance(iri, tuple): - iri = url_unparse(iri) - iri = url_parse(to_unicode(iri, charset, errors)) - - netloc = iri.encode_netloc() - path = url_quote(iri.path, charset, errors, '/:~+%') - query = url_quote(iri.query, charset, errors, '%&[]:;$*()+,!?*/=') - fragment = url_quote(iri.fragment, charset, errors, '=%&[]:;$()+,!?*/') - - return to_unicode(url_unparse((iri.scheme, netloc, path, query, fragment))) - - -def url_fix(s, charset='utf-8'): - # First step is to switch to unicode processing and to convert - # backslashes (which are invalid in URLs anyways) to slashes. This is - # consistent with what Chrome does. - s = to_unicode(s, charset, 'replace').replace('\\', '/') - - # For the specific case that we look like a malformed windows URL - # we want to fix this up manually: - if ( - s.startswith('file://') and s[7:8].isalpha() and - s[8:10] in (':/', '|/') - ): - s = 'file:///' + s[7:] - - url = url_parse(s) - path = url_quote(url.path, charset, safe='/%+$!*\'(),') - qs = url_quote_plus(url.query, charset, safe=':&%=+$!*\'(),') - anchor = url_quote_plus(url.fragment, charset, safe=':&%=+$!*\'(),') - return to_unicode( - url_unparse((url.scheme, url.encode_netloc(), path, qs, anchor))) - - -def _encode_idna(domain): - # If we're given bytes, make sure they fit into ASCII - if not isinstance(domain, str): - domain.decode('ascii') - return domain - - # Otherwise check if it's already ascii, then return - try: - return domain.encode('ascii') - except UnicodeError: - pass - - # Otherwise encode each part separately - parts = domain.split('.') - for idx, part in enumerate(parts): - parts[idx] = part.encode('idna') - return b'.'.join(parts) - - -def _decode_idna(domain): - # If the input is a string try to encode it to ascii to - # do the idna decoding. if that fails because of an - # unicode error, then we already have a decoded idna domain - if isinstance(domain, str): - try: - domain = domain.encode('ascii') - except UnicodeError: - return domain - - # Decode each part separately. If a part fails, try to - # decode it with ascii and silently ignore errors. This makes - # most sense because the idna codec does not have error handling - parts = domain.split(b'.') - for idx, part in enumerate(parts): - try: - parts[idx] = part.decode('idna') - except UnicodeError: - parts[idx] = part.decode('ascii', 'ignore') - - return '.'.join(parts) - - -def _host_is_trusted(hostname, trusted_list): - if not hostname: - return False - - if isinstance(trusted_list, str): - trusted_list = [trusted_list] - - def _normalize(hostname): - if ':' in hostname: - hostname = hostname.rsplit(':', 1)[0] - return _encode_idna(hostname) - - try: - hostname = _normalize(hostname) - except UnicodeError: - return False - for ref in trusted_list: - if ref.startswith('.'): - ref = ref[1:] - suffix_match = True - else: - suffix_match = False - try: - ref = _normalize(ref) - except UnicodeError: - return False - if ref == hostname: - return True - if suffix_match and hostname.endswith('.' + ref): - return True - return False - - -def get_host(scope, headers, trusted_hosts=None): - if 'x-forwarded-host' in headers: - rv = headers['x-forwarded-host'].split(',', 1)[0].strip() - elif 'host' in headers: - rv = headers['host'] - else: - rv = scope['server'][0] - if ( - (scope['scheme'], scope['server'][1]) not in - (('https', '443'), ('http', '80')) - ): - rv += ':{}'.format(scope['server'][1]) - return rv diff --git a/emmett/tools/__init__.py b/emmett/tools/__init__.py index eb311265..c4deab59 100644 --- a/emmett/tools/__init__.py +++ b/emmett/tools/__init__.py @@ -1,4 +1,4 @@ -from .service import ServicePipe from .auth import Auth -from .mailer import Mailer from .decorators import requires, service +from .mailer import Mailer +from .service import ServicePipe diff --git a/emmett/tools/auth/apis.py b/emmett/tools/auth/apis.py index 2fd065e9..66e61e99 100644 --- a/emmett/tools/auth/apis.py +++ b/emmett/tools/auth/apis.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.apis - ---------------------- +emmett.tools.auth.apis +---------------------- - Provides the interface for the auth system. +Provides the interface for the auth system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations @@ -14,52 +14,42 @@ from datetime import timedelta from typing import Any, Callable, Dict, List, Optional, Type, Union +from emmett_core.routing.cache import RouteCacheRule from pydal.helpers.classes import Reference as _RecordReference -from ...cache import RouteCacheRule from ...datastructures import sdict from ...locals import now, session -from ...orm.objects import Row from ...orm.models import Model -from ...pipeline import Pipe, Injector -from .ext import AuthExtension +from ...orm.objects import Row +from ...pipeline import Injector, Pipe from .exposer import AuthModule +from .ext import AuthExtension class Auth: - def __init__( - self, - app, - db, - user_model=None, - group_model=None, - membership_model=None, - permission_model=None - ): + def __init__(self, app, db, user_model=None, group_model=None, membership_model=None, permission_model=None): self.ext = app.use_extension(AuthExtension) self.ext.bind_auth(self) - self.ext.use_database( - db, user_model, group_model, membership_model, permission_model - ) + self.ext.use_database(db, user_model, group_model, membership_model, permission_model) self.ext.init_forms() self.pipe = AuthPipe(self) def module( self, import_name: str, - name: str = 'auth', - template_folder: str = 'auth', + name: str = "auth", + template_folder: str = "auth", template_path: Optional[str] = None, static_folder: Optional[str] = None, static_path: Optional[str] = None, - url_prefix: Optional[str] = 'auth', + url_prefix: Optional[str] = "auth", hostname: Optional[str] = None, cache: Optional[RouteCacheRule] = None, root_path: Optional[str] = None, pipeline: Optional[List[Pipe]] = None, injectors: Optional[List[Injector]] = None, module_class: Type[AuthModule] = AuthModule, - **kwargs: Any + **kwargs: Any, ) -> AuthModule: return module_class.from_app( self.ext.app, @@ -75,7 +65,7 @@ def module( root_path=root_path, pipeline=pipeline or [], injectors=injectors or [], - opts=kwargs + opts=kwargs, ) @property @@ -83,7 +73,7 @@ def models(self) -> Dict[str, Model]: return self.ext.config.models def group_for_role(self, role: str) -> Row: - return self.models['group'].get(role=role) + return self.models["group"].get(role=role) #: context @property @@ -106,7 +96,7 @@ def has_membership( self, group: Optional[Union[str, int, Row]] = None, user: Optional[Union[Row, int]] = None, - role: Optional[str] = None + role: Optional[str] = None, ) -> bool: rv = False if not group and role: @@ -116,25 +106,28 @@ def has_membership( if not user and self.user: user = self.user.id if group and user: - if self.models['membership'].where( - lambda m: - (m.table[self.ext.relation_names['user']] == user) & - (m.table[self.ext.relation_names['group']] == group) - ).count(): + if ( + self.models["membership"] + .where( + lambda m: (m.table[self.ext.relation_names["user"]] == user) + & (m.table[self.ext.relation_names["group"]] == group) + ) + .count() + ): rv = True return rv def has_permission( self, - name: str = 'any', + name: str = "any", table_name: Optional[str] = None, record_id: Optional[int] = None, user: Optional[Union[int, Row]] = None, - group: Optional[Union[str, int, Row]] = None + group: Optional[Union[str, int, Row]] = None, ) -> bool: - permission = self.models['permission'] + permission = self.models["permission"] parent = None - query = (permission.name == name) + query = permission.name == name if table_name: query = query & (permission.table_name == table_name) if record_id: @@ -144,98 +137,72 @@ def has_permission( if not user and not group: return False if user is not None: - parent = self.models['user'].get(id=user) + parent = self.models["user"].get(id=user) elif group is not None: if isinstance(group, str): group = self.group_for_role(group) - parent = self.models['group'].get(id=group) + parent = self.models["group"].get(id=group) if not parent: return False - return ( - parent[self.ext.relation_names['permission'] + 's'].where( - query - ).count() > 0 - ) + return parent[self.ext.relation_names["permission"] + "s"].where(query).count() > 0 #: operations - def create_group(self, role: str, description: str = '') -> _RecordReference: - res = self.models['group'].create( - role=role, description=description - ) + def create_group(self, role: str, description: str = "") -> _RecordReference: + res = self.models["group"].create(role=role, description=description) return res.id def delete_group(self, group: Union[str, int, Row]) -> int: if isinstance(group, str): group = self.group_for_role(group) - return self.ext.db(self.models['group'].id == group).delete() + return self.ext.db(self.models["group"].id == group).delete() - def add_membership( - self, - group: Union[str, int, Row], - user: Optional[Row] = None - ) -> _RecordReference: + def add_membership(self, group: Union[str, int, Row], user: Optional[Row] = None) -> _RecordReference: if isinstance(group, int): - group = self.models['group'].get(group) + group = self.models["group"].get(group) elif isinstance(group, str): group = self.group_for_role(group) if not user and self.user: user = self.user.id - res = getattr( - group, self.ext.relation_names['user'] + 's' - ).add(user) + res = getattr(group, self.ext.relation_names["user"] + "s").add(user) return res.id - def remove_membership( - self, - group: Union[str, int, Row], - user: Optional[Row] = None - ): + def remove_membership(self, group: Union[str, int, Row], user: Optional[Row] = None): if isinstance(group, int): - group = self.models['group'].get(group) + group = self.models["group"].get(group) elif isinstance(group, str): group = self.group_for_role(group) if not user and self.user: user = self.user.id - return getattr( - group, self.ext.relation_names['user'] + 's' - ).remove(user) + return getattr(group, self.ext.relation_names["user"] + "s").remove(user) def login(self, email: str, password: str): - user = self.models['user'].get(email=email) - if user and user.get('password', False): - password = self.models['user'].password.validate(password)[0] + user = self.models["user"].get(email=email) + if user and user.get("password", False): + password = self.models["user"].password.validate(password)[0] if not user.registration_key and password == user.password: self.ext.login_user(user) return user return None def change_user_status(self, user: Union[int, Row], status: str) -> int: - return self.ext.db(self.models['user'].id == user).update( - registration_key=status - ) + return self.ext.db(self.models["user"].id == user).update(registration_key=status) def disable_user(self, user: Union[int, Row]) -> int: - return self.change_user_status(user, 'disabled') + return self.change_user_status(user, "disabled") def block_user(self, user: Union[int, Row]) -> int: - return self.change_user_status(user, 'blocked') + return self.change_user_status(user, "blocked") def allow_user(self, user: Union[int, Row]) -> int: - return self.change_user_status(user, '') + return self.change_user_status(user, "") #: emails decorators - def registration_mail( - self, - f: Callable[[Row, Dict[str, Any]], bool] - ) -> Callable[[Row, Dict[str, Any]], bool]: - self.ext.mails['registration'] = f + def registration_mail(self, f: Callable[[Row, Dict[str, Any]], bool]) -> Callable[[Row, Dict[str, Any]], bool]: + self.ext.mails["registration"] = f return f - def reset_password_mail( - self, - f: Callable[[Row, Dict[str, Any]], bool] - ) -> Callable[[Row, Dict[str, Any]], bool]: - self.ext.mails['reset_password'] = f + def reset_password_mail(self, f: Callable[[Row, Dict[str, Any]], bool]) -> Callable[[Row, Dict[str, Any]], bool]: + self.ext.mails["reset_password"] = f return f @@ -255,26 +222,20 @@ def session_open(self): return #: is session expired? visit_dt = now().as_naive_datetime() - if ( - authsess.last_visit + timedelta(seconds=authsess.expiration) < - visit_dt - ): + if authsess.last_visit + timedelta(seconds=authsess.expiration) < visit_dt: del session.auth #: does session need re-sync with db? elif authsess.last_dbcheck + timedelta(seconds=300) < visit_dt: if self.auth.user: #: is user still valid? - dbrow = self.auth.models['user'].get(self.auth.user.id) + dbrow = self.auth.models["user"].get(self.auth.user.id) if dbrow and not dbrow.registration_key: self.auth.ext.login_user(dbrow, authsess.remember) else: del session.auth else: #: set last_visit if make sense - if ( - (visit_dt - authsess.last_visit).seconds > - min(authsess.expiration / 10, 600) - ): + if (visit_dt - authsess.last_visit).seconds > min(authsess.expiration / 10, 600): authsess.last_visit = visit_dt def session_close(self): diff --git a/emmett/tools/auth/exposer.py b/emmett/tools/auth/exposer.py index c998dd7a..54d73fbf 100644 --- a/emmett/tools/auth/exposer.py +++ b/emmett/tools/auth/exposer.py @@ -1,20 +1,21 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.exposer - ------------------------- +emmett.tools.auth.exposer +------------------------- - Provides the routes layer for the auth system. +Provides the routes layer for the auth system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations from typing import Any, List, Optional +from emmett_core.routing.cache import RouteCacheRule + from ...app import App, AppModule -from ...cache import RouteCacheRule from ...helpers import flash, stream_dbfile from ...http import redirect from ...locals import request, session @@ -38,7 +39,7 @@ def __init__( root_path: Optional[str] = None, pipeline: Optional[List[Pipe]] = None, injectors: Optional[List[Injector]] = None, - **kwargs: Any + **kwargs: Any, ): super().__init__( app=app, @@ -54,7 +55,7 @@ def __init__( root_path=root_path, pipeline=pipeline, injectors=injectors, - **kwargs + **kwargs, ) self.init() @@ -63,30 +64,27 @@ def init(self): self.auth = self.ext.auth self.config = self.ext.config self._callbacks = { - 'after_login': self._after_login, - 'after_logout': self._after_logout, - 'after_registration': self._after_registration, - 'after_profile': self._after_profile, - 'after_email_verification': self._after_email_verification, - 'after_password_retrieval': self._after_password_retrieval, - 'after_password_reset': self._after_password_reset, - 'after_password_change': self._after_password_change + "after_login": self._after_login, + "after_logout": self._after_logout, + "after_registration": self._after_registration, + "after_profile": self._after_profile, + "after_email_verification": self._after_email_verification, + "after_password_retrieval": self._after_password_retrieval, + "after_password_reset": self._after_password_reset, + "after_password_change": self._after_password_change, } auth_pipe = [] if not self.config.inject_pipe else [self.auth.pipe] - requires_login = [ - RequirePipe( - lambda: self.auth.is_logged(), - lambda: redirect(self.url('login')))] + requires_login = [RequirePipe(lambda: self.auth.is_logged(), lambda: redirect(self.url("login")))] self._methods_pipelines = { - 'login': [], - 'logout': auth_pipe + requires_login, - 'registration': [], - 'profile': auth_pipe + requires_login, - 'email_verification': [], - 'password_retrieval': [], - 'password_reset': [], - 'password_change': auth_pipe + requires_login, - 'download': [] + "login": [], + "logout": auth_pipe + requires_login, + "registration": [], + "profile": auth_pipe + requires_login, + "email_verification": [], + "password_retrieval": [], + "password_reset": [], + "password_change": auth_pipe + requires_login, + "download": [], # 'not_authorized': [] } self.enabled_routes = list(self.config.enabled_routes) @@ -105,63 +103,57 @@ def init(self): self.ext.bind_exposer(self) def _template_for(self, key): - tname = 'auth' if self.config.single_template else key - return '{}{}'.format(tname, self.app.template_default_extension) + tname = "auth" if self.config.single_template else key + return "{}{}".format(tname, self.app.template_default_extension) def url(self, path, *args, **kwargs): path = "{}.{}".format(self.name, path) return url(path, *args, **kwargs) def _flash(self, message): - return flash(message, 'auth') + return flash(message, "auth") #: routes async def _login(self): def _validate_form(form): - row = self.config.models['user'].get(email=form.params.email) + row = self.config.models["user"].get(email=form.params.email) if row: #: verify password if form.params.password == row.password: - res['user'] = row + res["user"] = row return - form.errors.email = self.config.messages['login_invalid'] + form.errors.email = self.config.messages["login_invalid"] - rv = {'message': None} + rv = {"message": None} res = {} - rv['form'] = await self.ext.forms.login(onvalidation=_validate_form) - if rv['form'].accepted: + rv["form"] = await self.ext.forms.login(onvalidation=_validate_form) + if rv["form"].accepted: messages = self.config.messages - if res['user'].registration_key == 'pending': - rv['message'] = messages['approval_pending'] - elif res['user'].registration_key in ('disabled', 'blocked'): - rv['message'] = messages['login_disabled'] - elif ( - res['user'].registration_key is not None and - res['user'].registration_key.strip() - ): - rv['message'] = messages['verification_pending'] - if rv['message']: - self.flash(rv['message']) + if res["user"].registration_key == "pending": + rv["message"] = messages["approval_pending"] + elif res["user"].registration_key in ("disabled", "blocked"): + rv["message"] = messages["login_disabled"] + elif res["user"].registration_key is not None and res["user"].registration_key.strip(): + rv["message"] = messages["verification_pending"] + if rv["message"]: + self.flash(rv["message"]) else: - self.ext.login_user( - res['user'], rv['form'].params.get('remember', False)) - self.ext.log_event( - self.config.messages['login_log'], {'id': res['user'].id}) + self.ext.login_user(res["user"], rv["form"].params.get("remember", False)) + self.ext.log_event(self.config.messages["login_log"], {"id": res["user"].id}) redirect_after = (await request.body_params)._after if redirect_after: redirect(redirect_after) - self._callbacks['after_login'](rv['form']) + self._callbacks["after_login"](rv["form"]) return rv async def _logout(self): - self.ext.log_event( - self.config.messages['logout_log'], {'id': self.auth.user.id}) + self.ext.log_event(self.config.messages["logout_log"], {"id": self.auth.user.id}) session.auth = None - self.flash(self.config.messages['logged_out']) + self.flash(self.config.messages["logged_out"]) redirect_after = request.query_params._after if redirect_after: redirect(redirect_after) - self._callbacks['after_logout']() + self._callbacks["after_logout"]() async def _registration(self): def _validate_form(form): @@ -170,122 +162,105 @@ def _validate_form(form): form.errors.password2 = "password mismatch" return del form.params.password2 - res['id'] = self.config.models['user'].table.insert( - **form.params) + res["id"] = self.config.models["user"].table.insert(**form.params) - rv = {'message': None} + rv = {"message": None} res = {} - rv['form'] = await self.ext.forms.registration( - onvalidation=_validate_form) - if rv['form'].accepted: + rv["form"] = await self.ext.forms.registration(onvalidation=_validate_form) + if rv["form"].accepted: logged_in = False - row = self.config.models['user'].get(res['id']) + row = self.config.models["user"].get(res["id"]) if self.config.registration_verification: - email_data = { - 'link': self.url( - 'email_verification', row.registration_key, - scheme=True)} - if not self.ext.mails['registration'](row, email_data): - rv['message'] = self.config.messages['mail_failure'] + email_data = {"link": self.url("email_verification", row.registration_key, scheme=True)} + if not self.ext.mails["registration"](row, email_data): + rv["message"] = self.config.messages["mail_failure"] self.ext.db.rollback() - self.flash(rv['message']) + self.flash(rv["message"]) return rv - rv['message'] = self.config.messages['mail_success'] - self.flash(rv['message']) + rv["message"] = self.config.messages["mail_success"] + self.flash(rv["message"]) elif self.config.registration_approval: - rv['message'] = self.config.messages['approval_pending'] - self.flash(rv['message']) + rv["message"] = self.config.messages["approval_pending"] + self.flash(rv["message"]) else: - rv['message'] = self.config.messages['registration_success'] - self.flash(rv['message']) + rv["message"] = self.config.messages["registration_success"] + self.flash(rv["message"]) self.ext.login_user(row) logged_in = True - self.ext.log_event( - self.config.messages['registration_log'], - {'id': res['id']}) + self.ext.log_event(self.config.messages["registration_log"], {"id": res["id"]}) redirect_after = (await request.body_params)._after if redirect_after: redirect(redirect_after) - self._callbacks['after_registration'](rv['form'], row, logged_in) + self._callbacks["after_registration"](rv["form"], row, logged_in) return rv async def _profile(self): - rv = {'message': None, 'form': await self.ext.forms.profile()} - if rv['form'].accepted: - self.auth.user.update( - self.config.models['user'].table._filter_fields( - rv['form'].params)) - rv['message'] = self.config.messages['profile_updated'] - self.flash(rv['message']) - self.ext.log_event( - self.config.messages['profile_log'], {'id': self.auth.user.id}) + rv = {"message": None, "form": await self.ext.forms.profile()} + if rv["form"].accepted: + self.auth.user.update(self.config.models["user"].table._filter_fields(rv["form"].params)) + rv["message"] = self.config.messages["profile_updated"] + self.flash(rv["message"]) + self.ext.log_event(self.config.messages["profile_log"], {"id": self.auth.user.id}) redirect_after = (await request.body_params)._after if redirect_after: redirect(redirect_after) - self._callbacks['after_profile'](rv['form']) + self._callbacks["after_profile"](rv["form"]) return rv async def _email_verification(self, key): - rv = {'message': None} - user = self.config.models['user'].get(registration_key=key) + rv = {"message": None} + user = self.config.models["user"].get(registration_key=key) if not user: - redirect(self.url('login')) + redirect(self.url("login")) if self.config.registration_approval: - user.update_record(registration_key='pending') - rv['message'] = self.config.messages['approval_pending'] - self.flash(rv['message']) + user.update_record(registration_key="pending") + rv["message"] = self.config.messages["approval_pending"] + self.flash(rv["message"]) else: - user.update_record(registration_key='') - rv['message'] = self.config.messages['verification_success'] - self.flash(rv['message']) + user.update_record(registration_key="") + rv["message"] = self.config.messages["verification_success"] + self.flash(rv["message"]) #: make sure session has same user.registration_key as db record if self.auth.user: self.auth.user.registration_key = user.registration_key - self.ext.log_event( - self.config.messages['email_verification_log'], {'id': user.id}) + self.ext.log_event(self.config.messages["email_verification_log"], {"id": user.id}) redirect_after = request.query_params._after if redirect_after: redirect(redirect_after) - self._callbacks['after_email_verification'](user) + self._callbacks["after_email_verification"](user) return rv async def _password_retrieval(self): def _validate_form(form): messages = self.config.messages - row = self.config.models['user'].get(email=form.params.email) + row = self.config.models["user"].get(email=form.params.email) if not row: form.errors.email = "invalid email" return - if row.registration_key == 'pending': - form.errors.email = messages['approval_pending'] + if row.registration_key == "pending": + form.errors.email = messages["approval_pending"] return - if row.registration_key == 'blocked': - form.errors.email = messages['login_disabled'] + if row.registration_key == "blocked": + form.errors.email = messages["login_disabled"] return - res['user'] = row + res["user"] = row - rv = {'message': None} + rv = {"message": None} res = {} - rv['form'] = await self.ext.forms.password_retrieval( - onvalidation=_validate_form) - if rv['form'].accepted: - user = res['user'] + rv["form"] = await self.ext.forms.password_retrieval(onvalidation=_validate_form) + if rv["form"].accepted: + user = res["user"] reset_key = self.ext.generate_reset_key(user) - email_data = { - 'link': self.url( - 'password_reset', reset_key, scheme=True)} - if not self.ext.mails['reset_password'](user, email_data): - rv['message'] = self.config.messages['mail_failure'] - rv['message'] = self.config.messages['mail_success'] - self.flash(rv['message']) - self.ext.log_event( - self.config.messages['password_retrieval_log'], - {'id': user.id}, - user=user) + email_data = {"link": self.url("password_reset", reset_key, scheme=True)} + if not self.ext.mails["reset_password"](user, email_data): + rv["message"] = self.config.messages["mail_failure"] + rv["message"] = self.config.messages["mail_success"] + self.flash(rv["message"]) + self.ext.log_event(self.config.messages["password_retrieval_log"], {"id": user.id}, user=user) redirect_after = (await request.body_params)._after if redirect_after: redirect(redirect_after) - self._callbacks['after_password_retrieval'](user) + self._callbacks["after_password_retrieval"](user) return rv async def _password_reset(self, key): @@ -294,63 +269,49 @@ def _validate_form(form): form.errors.password = "password mismatch" form.errors.password2 = "password mismatch" - rv = {'message': None} + rv = {"message": None} redirect_after = request.query_params._after user = self.ext.get_user_by_reset_key(key) if not user: - rv['message'] = self.config.messages['reset_key_invalid'] - self.flash(rv['message']) + rv["message"] = self.config.messages["reset_key_invalid"] + self.flash(rv["message"]) if redirect_after: redirect(redirect_after) - self._callbacks['after_password_reset'](user) + self._callbacks["after_password_reset"](user) return rv - rv['form'] = await self.ext.forms.password_reset( - onvalidation=_validate_form) - if rv['form'].accepted: - user.update_record( - password=str(rv['form'].params.password), - registration_key='', - reset_password_key='' - ) - rv['message'] = self.config.messages['password_changed'] - self.flash(rv['message']) - self.ext.log_event( - self.config.messages['password_reset_log'], - {'id': user.id}, - user=user) + rv["form"] = await self.ext.forms.password_reset(onvalidation=_validate_form) + if rv["form"].accepted: + user.update_record(password=str(rv["form"].params.password), registration_key="", reset_password_key="") + rv["message"] = self.config.messages["password_changed"] + self.flash(rv["message"]) + self.ext.log_event(self.config.messages["password_reset_log"], {"id": user.id}, user=user) if redirect_after: redirect(redirect_after) - self._callbacks['after_password_reset'](user) + self._callbacks["after_password_reset"](user) return rv async def _password_change(self): def _validate_form(form): messages = self.config.messages if form.params.old_password != row.password: - form.errors.old_password = messages['invalid_password'] + form.errors.old_password = messages["invalid_password"] return - if ( - form.params.new_password.password != - form.params.new_password2.password - ): + if form.params.new_password.password != form.params.new_password2.password: form.errors.new_password = "password mismatch" form.errors.new_password2 = "password mismatch" - rv = {'message': None} - row = self.config.models['user'].get(self.auth.user.id) - rv['form'] = await self.ext.forms.password_change( - onvalidation=_validate_form) - if rv['form'].accepted: - row.update_record(password=str(rv['form'].params.new_password)) - rv['message'] = self.config.messages['password_changed'] - self.flash(rv['message']) - self.ext.log_event( - self.config.messages['password_change_log'], - {'id': row.id}) + rv = {"message": None} + row = self.config.models["user"].get(self.auth.user.id) + rv["form"] = await self.ext.forms.password_change(onvalidation=_validate_form) + if rv["form"].accepted: + row.update_record(password=str(rv["form"].params.new_password)) + rv["message"] = self.config.messages["password_changed"] + self.flash(rv["message"]) + self.ext.log_event(self.config.messages["password_change_log"], {"id": row.id}) redirect_after = request.query_params._after if redirect_after: redirect(redirect_after) - self._callbacks['after_password_change']() + self._callbacks["after_password_change"]() return rv def _download(self, file_name): @@ -358,70 +319,96 @@ def _download(self, file_name): #: routes decorators def login(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['login'] + (pipeline or []) + pipeline = self._methods_pipelines["login"] + (pipeline or []) return self.route( - self.config['routes_paths']['login'], name='login', - template=template or self._template_for('login'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["login"], + name="login", + template=template or self._template_for("login"), + pipeline=pipeline, + injectors=injectors or [], + ) def logout(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['logout'] + (pipeline or []) + pipeline = self._methods_pipelines["logout"] + (pipeline or []) return self.route( - self.config['routes_paths']['logout'], name='logout', - template=template or self._template_for('logout'), - pipeline=pipeline, injectors=injectors or [], methods='get') + self.config["routes_paths"]["logout"], + name="logout", + template=template or self._template_for("logout"), + pipeline=pipeline, + injectors=injectors or [], + methods="get", + ) def registration(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['registration'] + (pipeline or []) + pipeline = self._methods_pipelines["registration"] + (pipeline or []) return self.route( - self.config['routes_paths']['registration'], name='registration', - template=template or self._template_for('registration'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["registration"], + name="registration", + template=template or self._template_for("registration"), + pipeline=pipeline, + injectors=injectors or [], + ) def profile(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['profile'] + (pipeline or []) + pipeline = self._methods_pipelines["profile"] + (pipeline or []) return self.route( - self.config['routes_paths']['profile'], name='profile', - template=template or self._template_for('profile'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["profile"], + name="profile", + template=template or self._template_for("profile"), + pipeline=pipeline, + injectors=injectors or [], + ) def email_verification(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['email_verification'] + (pipeline or []) + pipeline = self._methods_pipelines["email_verification"] + (pipeline or []) return self.route( - self.config['routes_paths']['email_verification'], - name='email_verification', - template=template or self._template_for('email_verification'), - pipeline=pipeline, injectors=injectors or [], methods='get') + self.config["routes_paths"]["email_verification"], + name="email_verification", + template=template or self._template_for("email_verification"), + pipeline=pipeline, + injectors=injectors or [], + methods="get", + ) def password_retrieval(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['password_retrieval'] + (pipeline or []) + pipeline = self._methods_pipelines["password_retrieval"] + (pipeline or []) return self.route( - self.config['routes_paths']['password_retrieval'], - name='password_retrieval', - template=template or self._template_for('password_retrieval'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["password_retrieval"], + name="password_retrieval", + template=template or self._template_for("password_retrieval"), + pipeline=pipeline, + injectors=injectors or [], + ) def password_reset(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['password_reset'] + (pipeline or []) + pipeline = self._methods_pipelines["password_reset"] + (pipeline or []) return self.route( - self.config['routes_paths']['password_reset'], - name='password_reset', - template=template or self._template_for('password_reset'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["password_reset"], + name="password_reset", + template=template or self._template_for("password_reset"), + pipeline=pipeline, + injectors=injectors or [], + ) def password_change(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['password_change'] + (pipeline or []) + pipeline = self._methods_pipelines["password_change"] + (pipeline or []) return self.route( - self.config['routes_paths']['password_change'], - name='password_change', - template=template or self._template_for('password_change'), - pipeline=pipeline, injectors=injectors or []) + self.config["routes_paths"]["password_change"], + name="password_change", + template=template or self._template_for("password_change"), + pipeline=pipeline, + injectors=injectors or [], + ) def download(self, template=None, pipeline=None, injectors=None): - pipeline = self._methods_pipelines['download'] + (pipeline or []) + pipeline = self._methods_pipelines["download"] + (pipeline or []) return self.route( - self.config['routes_paths']['download'], name='download', - pipeline=pipeline, injectors=injectors or [], methods='get') + self.config["routes_paths"]["download"], + name="download", + pipeline=pipeline, + injectors=injectors or [], + methods="get", + ) #: callbacks def _after_login(self, form): @@ -452,33 +439,33 @@ def _after_password_change(self): #: callbacks decorators def after_login(self, f): - self._callbacks['after_login'] = f + self._callbacks["after_login"] = f return f def after_logout(self, f): - self._callbacks['after_logout'] = f + self._callbacks["after_logout"] = f return f def after_registration(self, f): - self._callbacks['after_registration'] = f + self._callbacks["after_registration"] = f return f def after_profile(self, f): - self._callbacks['after_profile'] = f + self._callbacks["after_profile"] = f return f def after_email_verification(self, f): - self._callbacks['after_email_verification'] = f + self._callbacks["after_email_verification"] = f return f def after_password_retrieval(self, f): - self._callbacks['after_password_retrieval'] = f + self._callbacks["after_password_retrieval"] = f return f def after_password_reset(self, f): - self._callbacks['after_password_reset'] = f + self._callbacks["after_password_reset"] = f return f def after_password_change(self, f): - self._callbacks['after_password_change'] = f + self._callbacks["after_password_change"] = f return f diff --git a/emmett/tools/auth/ext.py b/emmett/tools/auth/ext.py index 4477af19..fd9dcd36 100644 --- a/emmett/tools/auth/ext.py +++ b/emmett/tools/auth/ext.py @@ -1,112 +1,114 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.ext - --------------------- +emmett.tools.auth.ext +--------------------- - Provides the main auth layer. +Provides the main auth layer. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ import time - from functools import wraps import click +from ..._shortcuts import uuid from ...cli import pass_script_info from ...datastructures import sdict from ...extensions import Extension, Signals, listen_signal from ...language.helpers import Tstr from ...locals import T, now, session from ...orm.helpers import decamelize -from ...security import uuid from .forms import AuthForms -from .models import ( - AuthModel, AuthUser, AuthGroup, AuthMembership, AuthPermission, AuthEvent -) +from .models import AuthEvent, AuthGroup, AuthMembership, AuthModel, AuthPermission, AuthUser class AuthExtension(Extension): - namespace = 'auth' + namespace = "auth" default_config = { - 'models': { - 'user': AuthUser, - 'group': AuthGroup, - 'membership': AuthMembership, - 'permission': AuthPermission, - 'event': AuthEvent + "models": { + "user": AuthUser, + "group": AuthGroup, + "membership": AuthMembership, + "permission": AuthPermission, + "event": AuthEvent, }, - 'hmac_key': None, - 'hmac_alg': 'pbkdf2(2000,20,sha512)', - 'inject_pipe': False, - 'log_events': True, - 'flash_messages': True, - 'csrf': True, - 'enabled_routes': [ - 'login', 'logout', 'registration', 'profile', 'email_verification', - 'password_retrieval', 'password_reset', 'password_change', - 'download'], - 'disabled_routes': [], - 'routes_paths': { - 'login': '/login', - 'logout': '/logout', - 'registration': '/registration', - 'profile': '/profile', - 'email_verification': '/email_verification/', - 'password_retrieval': '/password_retrieval', - 'password_reset': '/password_reset/', - 'password_change': '/password_change', - 'download': '/download/' + "hmac_key": None, + "hmac_alg": "pbkdf2(2000,20,sha512)", + "inject_pipe": False, + "log_events": True, + "flash_messages": True, + "csrf": True, + "enabled_routes": [ + "login", + "logout", + "registration", + "profile", + "email_verification", + "password_retrieval", + "password_reset", + "password_change", + "download", + ], + "disabled_routes": [], + "routes_paths": { + "login": "/login", + "logout": "/logout", + "registration": "/registration", + "profile": "/profile", + "email_verification": "/email_verification/", + "password_retrieval": "/password_retrieval", + "password_reset": "/password_reset/", + "password_change": "/password_change", + "download": "/download/", }, - 'single_template': False, - 'password_min_length': 6, - 'remember_option': True, - 'session_expiration': 3600, - 'session_long_expiration': 3600 * 24 * 30, - 'registration_verification': True, - 'registration_approval': False + "single_template": False, + "password_min_length": 6, + "remember_option": True, + "session_expiration": 3600, + "session_long_expiration": 3600 * 24 * 30, + "registration_verification": True, + "registration_approval": False, } default_messages = { - 'approval_pending': 'Registration is pending approval', - 'verification_pending': 'Registration needs verification', - 'login_disabled': 'The account is locked', - 'login_invalid': 'Invalid credentials', - 'logged_out': 'Logged out successfully', - 'registration_success': 'Registration completed', - 'profile_updated': 'Profile updated successfully', - 'verification_success': 'Account verification completed', - 'password_changed': 'Password changed successfully', - 'mail_failure': 'Something went wrong with the email, try again later', - 'mail_success': 'We sent you an email, check your inbox', - 'reset_key_invalid': 'The reset link was invalid or expired', - 'login_button': 'Sign in', - 'registration_button': 'Register', - 'profile_button': 'Save', - 'remember_button': 'Remember me', - 'password_retrieval_button': 'Retrieve password', - 'password_reset_button': 'Reset password', - 'password_change_button': 'Change password', - 'login_log': 'User {id} logged in', - 'logout_log': 'User {id} logged out', - 'registration_log': 'User {id} registered', - 'profile_log': 'User {id} updated profile', - 'email_verification_log': 'Verification email sent to user {id}', - 'password_retrieval_log': 'User {id} asked for password retrieval', - 'password_reset_log': 'User {id} reset the password', - 'password_change_log': 'User {id} changed the password', - 'old_password': 'Current password', - 'new_password': 'New password', - 'verify_password': 'Confirm password', - 'registration_email_subject': 'Email verification', - 'registration_email_text': - 'Hello {email}! Click on the link {link} to verify your email', - 'reset_password_email_subject': 'Password reset requested', - 'reset_password_email_text': - 'A password reset was requested for your account, ' - 'click on the link {link} to proceed' + "approval_pending": "Registration is pending approval", + "verification_pending": "Registration needs verification", + "login_disabled": "The account is locked", + "login_invalid": "Invalid credentials", + "logged_out": "Logged out successfully", + "registration_success": "Registration completed", + "profile_updated": "Profile updated successfully", + "verification_success": "Account verification completed", + "password_changed": "Password changed successfully", + "mail_failure": "Something went wrong with the email, try again later", + "mail_success": "We sent you an email, check your inbox", + "reset_key_invalid": "The reset link was invalid or expired", + "login_button": "Sign in", + "registration_button": "Register", + "profile_button": "Save", + "remember_button": "Remember me", + "password_retrieval_button": "Retrieve password", + "password_reset_button": "Reset password", + "password_change_button": "Change password", + "login_log": "User {id} logged in", + "logout_log": "User {id} logged out", + "registration_log": "User {id} registered", + "profile_log": "User {id} updated profile", + "email_verification_log": "Verification email sent to user {id}", + "password_retrieval_log": "User {id} asked for password retrieval", + "password_reset_log": "User {id} reset the password", + "password_change_log": "User {id} changed the password", + "old_password": "Current password", + "new_password": "New password", + "verify_password": "Confirm password", + "registration_email_subject": "Email verification", + "registration_email_text": "Hello {email}! Click on the link {link} to verify your email", + "reset_password_email_subject": "Password reset requested", + "reset_password_email_text": "A password reset was requested for your account, " + "click on the link {link} to proceed", } def __init__(self, app, env, config): @@ -117,45 +119,38 @@ def __init__(self, app, env, config): AuthModel.auth = self def __init_messages(self): - self.config.messages = self.config.get('messages', sdict()) + self.config.messages = self.config.get("messages", sdict()) for key, dval in self.default_messages.items(): self.config.messages[key] = T(self.config.messages.get(key, dval)) def __init_mails(self): - self.mails = { - 'registration': self._registration_email, - 'reset_password': self._reset_password_email - } + self.mails = {"registration": self._registration_email, "reset_password": self._reset_password_email} def __register_commands(self): - @self.app.cli.group('auth', short_help='Auth commands') + @self.app.cli.group("auth", short_help="Auth commands") def cli_group(): pass - @cli_group.command('generate_key', short_help='Generate an auth key') + @cli_group.command("generate_key", short_help="Generate an auth key") @pass_script_info def generate_key(info): click.echo(uuid()) def __ensure_config(self): - for key in ( - set(self.default_config['routes_paths'].keys()) - - set(self.config['routes_paths'].keys()) - ): - self.config['routes_paths'][key] = \ - self.default_config['routes_paths'][key] + for key in set(self.default_config["routes_paths"].keys()) - set(self.config["routes_paths"].keys()): + self.config["routes_paths"][key] = self.default_config["routes_paths"][key] def __get_relnames(self): rv = {} def_names = { - 'user': 'user', - 'group': 'auth_group', - 'membership': 'auth_membership', - 'permission': 'auth_permission', - 'event': 'auth_event' + "user": "user", + "group": "auth_group", + "membership": "auth_membership", + "permission": "auth_permission", + "event": "auth_event", } - for m in ['user', 'group', 'membership', 'permission', 'event']: - if self.config.models[m] == self.default_config['models'][m]: + for m in ["user", "group", "membership", "permission", "event"]: + if self.config.models[m] == self.default_config["models"][m]: rv[m] = def_names[m] else: rv[m] = decamelize(self.config.models[m].__name__) @@ -168,13 +163,12 @@ def on_load(self): "An auto-generated 'hmac_key' was added to the auth " "configuration.\nPlase add your own key to the configuration. " "You can generate a key using the auth command.\n" - "> emmett -a {your_app_name} auth generate_key") + "> emmett -a {your_app_name} auth generate_key" + ) self.config.hmac_key = uuid() - self._hmac_key = self.config.hmac_alg + ':' + self.config.hmac_key - if 'MailExtension' not in self.app.ext: - self.app.log.warn( - "No mailer seems to be configured. The auth features " - "requiring mailer won't work.") + self._hmac_key = self.config.hmac_alg + ":" + self.config.hmac_key + if "MailExtension" not in self.app.ext: + self.app.log.warn("No mailer seems to be configured. The auth features " "requiring mailer won't work.") self.__ensure_config() self.relation_names = self.__get_relnames() @@ -193,23 +187,16 @@ def __set_model_for_key(self, key, model): if not model: return _model_bases = { - 'user': AuthModel, - 'group': AuthGroup, - 'membership': AuthMembership, - 'permission': AuthPermission + "user": AuthModel, + "group": AuthGroup, + "membership": AuthMembership, + "permission": AuthPermission, } if not issubclass(model, _model_bases[key]): - raise RuntimeError(f'{model.__name__} is an invalid {key} auth model') + raise RuntimeError(f"{model.__name__} is an invalid {key} auth model") self.config.models[key] = model - def use_database( - self, - db, - user_model=None, - group_model=None, - membership_model=None, - permission_model=None - ): + def use_database(self, db, user_model=None, group_model=None, membership_model=None, permission_model=None): self.db = db self.__set_model_for_key("user", user_model) self.__set_model_for_key("group", group_model) @@ -218,13 +205,11 @@ def use_database( self.define_models() def __set_models_labels(self): - for model in self.default_config['models'].values(): + for model in self.default_config["models"].values(): for supmodel in list(reversed(model.__mro__))[1:]: - if not supmodel.__module__.startswith( - 'emmett.tools.auth.models' - ): + if not supmodel.__module__.startswith("emmett.tools.auth.models"): continue - if not hasattr(supmodel, 'form_labels'): + if not hasattr(supmodel, "form_labels"): continue current_labels = {} for key, val in supmodel.form_labels.items(): @@ -236,66 +221,58 @@ def define_models(self): names = self.relation_names models = self.config.models #: AuthUser - user_model = models['user'] + user_model = models["user"] many_refs = [ - {names['membership'] + 's': models['membership'].__name__}, - {names['event'] + 's': models['event'].__name__}, - {names['group'] + 's': {'via': names['membership'] + 's'}}, - {names['permission'] + 's': {'via': names['group'] + 's'}} + {names["membership"] + "s": models["membership"].__name__}, + {names["event"] + "s": models["event"].__name__}, + {names["group"] + "s": {"via": names["membership"] + "s"}}, + {names["permission"] + "s": {"via": names["group"] + "s"}}, ] - if getattr(user_model, '_auto_relations', True): + if getattr(user_model, "_auto_relations", True): for el in many_refs: key = list(el)[0] user_model._all_hasmany_ref_[key] = el - if user_model.validation.get('password') is None: - user_model.validation['password'] = { - 'len': {'gte': self.config.password_min_length}, - 'crypt': {'key': self._hmac_key} + if user_model.validation.get("password") is None: + user_model.validation["password"] = { + "len": {"gte": self.config.password_min_length}, + "crypt": {"key": self._hmac_key}, } #: AuthGroup - group_model = models['group'] - if not hasattr(group_model, 'format'): - setattr(group_model, 'format', '%(role)s (%(id)s)') + group_model = models["group"] + if not hasattr(group_model, "format"): + group_model.format = "%(role)s (%(id)s)" many_refs = [ - {names['membership'] + 's': models['membership'].__name__}, - {names['permission'] + 's': models['permission'].__name__}, - {names['user'] + 's': {'via': names['membership'] + 's'}} + {names["membership"] + "s": models["membership"].__name__}, + {names["permission"] + "s": models["permission"].__name__}, + {names["user"] + "s": {"via": names["membership"] + "s"}}, ] - if getattr(group_model, '_auto_relations', True): + if getattr(group_model, "_auto_relations", True): for el in many_refs: key = list(el)[0] group_model._all_hasmany_ref_[key] = el #: AuthMembership - membership_model = models['membership'] - belongs_refs = [ - {names['user']: models['user'].__name__}, - {names['group']: models['group'].__name__} - ] - if getattr(membership_model, '_auto_relations', True): + membership_model = models["membership"] + belongs_refs = [{names["user"]: models["user"].__name__}, {names["group"]: models["group"].__name__}] + if getattr(membership_model, "_auto_relations", True): for el in belongs_refs: key = list(el)[0] membership_model._all_belongs_ref_[key] = el #: AuthPermission - permission_model = models['permission'] - belongs_refs = [ - {names['group']: models['group'].__name__} - ] - if getattr(permission_model, '_auto_relations', True): + permission_model = models["permission"] + belongs_refs = [{names["group"]: models["group"].__name__}] + if getattr(permission_model, "_auto_relations", True): for el in belongs_refs: key = list(el)[0] permission_model._all_belongs_ref_[key] = el #: AuthEvent - event_model = models['event'] - belongs_refs = [ - {names['user']: models['user'].__name__} - ] - if getattr(event_model, '_auto_relations', True): + event_model = models["event"] + belongs_refs = [{names["user"]: models["user"].__name__}] + if getattr(event_model, "_auto_relations", True): for el in belongs_refs: key = list(el)[0] event_model._all_belongs_ref_[key] = el self.db.define_models( - models['user'], models['group'], models['membership'], - models['permission'], models['event'] + models["user"], models["group"], models["membership"], models["permission"], models["event"] ) self.model_names = sdict() for key, value in models.items(): @@ -304,25 +281,23 @@ def define_models(self): def init_forms(self): self.forms = sdict() for key, (method, fields_method) in AuthForms.map().items(): - self.forms[key] = _wrap_form( - method, fields_method(self.auth), self.auth) + self.forms[key] = _wrap_form(method, fields_method(self.auth), self.auth) def login_user(self, user, remember=False): try: del user.password except Exception: pass - expiration = remember and self.config.session_long_expiration or \ - self.config.session_expiration + expiration = remember and self.config.session_long_expiration or self.config.session_expiration session.auth = sdict( user=user, last_visit=now().as_naive_datetime(), last_dbcheck=now().as_naive_datetime(), expiration=expiration, - remember=remember + remember=remember, ) - def log_event(self, description, data={}, origin='auth', user=None): + def log_event(self, description, data={}, origin="auth", user=None): if not self.config.log_events or not description: return try: @@ -332,42 +307,43 @@ def log_event(self, description, data={}, origin='auth', user=None): # log messages should not be translated if isinstance(description, Tstr): description = description.text - self.config.models['event'].table.insert( - description=str(description % data), - origin=origin, user=user_id) + self.config.models["event"].table.insert(description=str(description % data), origin=origin, user=user_id) def generate_reset_key(self, user): - key = str(int(time.time())) + '-' + uuid() + key = str(int(time.time())) + "-" + uuid() user.update_record(reset_password_key=key) return key def get_user_by_reset_key(self, key): try: - generated_at = int(key.split('-')[0]) + generated_at = int(key.split("-")[0]) if time.time() - generated_at > 60 * 60 * 24: raise ValueError - user = self.config.models['user'].get(reset_password_key=key) + user = self.config.models["user"].get(reset_password_key=key) except ValueError: user = None return user def _registration_email(self, user, data): - data['email'] = user.email + data["email"] = user.email return self.app.ext.MailExtension.send_mail( recipients=user.email, - subject=str(self.config.messages['registration_email_subject']), - body=str(self.config.messages['registration_email_text'] % data)) + subject=str(self.config.messages["registration_email_subject"]), + body=str(self.config.messages["registration_email_text"] % data), + ) def _reset_password_email(self, user, data): - data['email'] = user.email + data["email"] = user.email return self.app.ext.MailExtension.send_mail( recipients=user.email, - subject=str(self.config.messages['reset_password_email_subject']), - body=str(self.config.messages['reset_password_email_text'] % data)) + subject=str(self.config.messages["reset_password_email_subject"]), + body=str(self.config.messages["reset_password_email_text"] % data), + ) def _wrap_form(f, fields, auth): @wraps(f) def wrapped(*args, **kwargs): return f(auth, fields, *args, **kwargs) + return wrapped diff --git a/emmett/tools/auth/forms.py b/emmett/tools/auth/forms.py index 5a161b5d..d77ce600 100644 --- a/emmett/tools/auth/forms.py +++ b/emmett/tools/auth/forms.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.forms - ----------------------- +emmett.tools.auth.forms +----------------------- - Provides the forms for the authorization system. +Provides the forms for the authorization system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from ...forms import Form, ModelForm @@ -23,6 +23,7 @@ def wrap(f): cls._registry_[target] = f cls._fields_registry_[target] = fields return f + return wrap @classmethod @@ -38,37 +39,25 @@ def map(cls): def login_fields(auth): - model = auth.models['user'] + model = auth.models["user"] rv = { - 'email': Field( - validation={'is': 'email', 'presence': True}, - label=model.email.label), - 'password': Field( - 'password', validation=model.password._requires, - label=model.password.label) + "email": Field(validation={"is": "email", "presence": True}, label=model.email.label), + "password": Field("password", validation=model.password._requires, label=model.password.label), } if auth.ext.config.remember_option: - rv['remember'] = Field( - 'bool', default=True, - label=auth.ext.config.messages['remember_button']) + rv["remember"] = Field("bool", default=True, label=auth.ext.config.messages["remember_button"]) return rv def registration_fields(auth): - rw_data = auth.models['user']._instance_()._merged_form_rw_ - user_table = auth.models['user'].table - all_fields = [ - (field_name, user_table[field_name].clone()) - for field_name in rw_data['registration']['writable'] - ] - for i, (field_name, field) in enumerate(all_fields): - if field_name == 'password': + rw_data = auth.models["user"]._instance_()._merged_form_rw_ + user_table = auth.models["user"].table + all_fields = [(field_name, user_table[field_name].clone()) for field_name in rw_data["registration"]["writable"]] + for i, (field_name, _) in enumerate(all_fields): + if field_name == "password": all_fields.insert( - i + 1, ( - 'password2', - Field( - 'password', - label=auth.ext.config.messages['verify_password']))) + i + 1, ("password2", Field("password", label=auth.ext.config.messages["verify_password"])) + ) break rv = {} for i, (field_name, field) in enumerate(all_fields): @@ -79,112 +68,81 @@ def registration_fields(auth): def profile_fields(auth): - rw_data = auth.models['user']._instance_()._merged_form_rw_ - return rw_data['profile'] + rw_data = auth.models["user"]._instance_()._merged_form_rw_ + return rw_data["profile"] def password_retrieval_fields(auth): rv = { - 'email': Field( - validation={'is': 'email', 'presence': True, 'lower': True}), + "email": Field(validation={"is": "email", "presence": True, "lower": True}), } return rv def password_reset_fields(auth): - password_field = auth.ext.config.models['user'].password + password_field = auth.ext.config.models["user"].password rv = { - 'password': Field( - 'password', validation=password_field._requires, - label=auth.ext.config.messages['new_password']), - 'password2': Field( - 'password', label=auth.ext.config.messages['verify_password']) + "password": Field( + "password", validation=password_field._requires, label=auth.ext.config.messages["new_password"] + ), + "password2": Field("password", label=auth.ext.config.messages["verify_password"]), } return rv def password_change_fields(auth): - password_validation = auth.ext.config.models['user'].password._requires + password_validation = auth.ext.config.models["user"].password._requires rv = { - 'old_password': Field( - 'password', validation=password_validation, - label=auth.ext.config.messages['old_password']), - 'new_password': Field( - 'password', validation=password_validation, - label=auth.ext.config.messages['new_password']), - 'new_password2': Field( - 'password', label=auth.ext.config.messages['verify_password']) + "old_password": Field( + "password", validation=password_validation, label=auth.ext.config.messages["old_password"] + ), + "new_password": Field( + "password", validation=password_validation, label=auth.ext.config.messages["new_password"] + ), + "new_password2": Field("password", label=auth.ext.config.messages["verify_password"]), } return rv -@AuthForms.register_for('login', fields=login_fields) +@AuthForms.register_for("login", fields=login_fields) def login_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['login_button'], 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["login_button"], "keepvalues": True} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) -@AuthForms.register_for('registration', fields=registration_fields) +@AuthForms.register_for("registration", fields=registration_fields) def registration_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['registration_button'], - 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["registration_button"], "keepvalues": True} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) -@AuthForms.register_for('profile', fields=profile_fields) +@AuthForms.register_for("profile", fields=profile_fields) def profile_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['profile_button'], - 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["profile_button"], "keepvalues": True} opts.update(**kwargs) return ModelForm( - auth.models['user'], - record_id=auth.user.id, - fields=fields, - upload=auth.ext.exposer.url('download'), - **opts + auth.models["user"], record_id=auth.user.id, fields=fields, upload=auth.ext.exposer.url("download"), **opts ) -@AuthForms.register_for('password_retrieval', fields=password_retrieval_fields) +@AuthForms.register_for("password_retrieval", fields=password_retrieval_fields) def password_retrieval_form(auth, fields, **kwargs): - opts = {'submit': auth.ext.config.messages['password_retrieval_button']} + opts = {"submit": auth.ext.config.messages["password_retrieval_button"]} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) -@AuthForms.register_for('password_reset', fields=password_reset_fields) +@AuthForms.register_for("password_reset", fields=password_reset_fields) def password_reset_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['password_reset_button'], - 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["password_reset_button"], "keepvalues": True} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) -@AuthForms.register_for('password_change', fields=password_change_fields) +@AuthForms.register_for("password_change", fields=password_change_fields) def password_change_form(auth, fields, **kwargs): - opts = { - 'submit': auth.ext.config.messages['password_change_button'], - 'keepvalues': True} + opts = {"submit": auth.ext.config.messages["password_change_button"], "keepvalues": True} opts.update(**kwargs) - return Form( - fields, - **opts - ) + return Form(fields, **opts) diff --git a/emmett/tools/auth/models.py b/emmett/tools/auth/models.py index fe665f93..868f9a83 100644 --- a/emmett/tools/auth/models.py +++ b/emmett/tools/auth/models.py @@ -1,18 +1,18 @@ # -*- coding: utf-8 -*- """ - emmett.tools.auth.models - ------------------------ +emmett.tools.auth.models +------------------------ - Provides models for the authorization system. +Provides models for the authorization system. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ +from ..._shortcuts import uuid from ...ctx import current from ...locals import now, request -from ...orm import Model, Field, before_insert, rowmethod -from ...security import uuid +from ...orm import Field, Model, before_insert, rowmethod class TimestampedModel(Model): @@ -21,8 +21,7 @@ class TimestampedModel(Model): class AuthModel(Model): - _additional_inheritable_dict_attrs_ = [ - 'form_registration_rw', 'form_profile_rw'] + _additional_inheritable_dict_attrs_ = ["form_registration_rw", "form_profile_rw"] auth = None @classmethod @@ -33,8 +32,7 @@ def _init_inheritable_dicts_(cls): else: attr_name, default = attr, {} if not isinstance(default, dict): - raise SyntaxError( - "{} is not a dictionary".format(attr_name)) + raise SyntaxError("{} is not a dictionary".format(attr_name)) setattr(cls, attr_name, default) @classmethod @@ -57,33 +55,32 @@ def _merge_inheritable_dicts_(cls, models): setattr(cls, attr_name, attrs) def __super_method(self, name): - return getattr(super(AuthModel, self), '_Model__' + name) + return getattr(super(AuthModel, self), "_Model__" + name) def _define_(self): self.__hide_all() - self.__super_method('define_indexes')() - self.__super_method('define_validation')() - self.__super_method('define_access')() - self.__super_method('define_defaults')() - self.__super_method('define_updates')() - self.__super_method('define_representation')() - self.__super_method('define_computations')() - self.__super_method('define_callbacks')() - self.__super_method('define_scopes')() - self.__super_method('define_query_helpers')() - self.__super_method('define_form_utils')() + self.__super_method("define_indexes")() + self.__super_method("define_validation")() + self.__super_method("define_access")() + self.__super_method("define_defaults")() + self.__super_method("define_updates")() + self.__super_method("define_representation")() + self.__super_method("define_computations")() + self.__super_method("define_callbacks")() + self.__super_method("define_scopes")() + self.__super_method("define_query_helpers")() + self.__super_method("define_form_utils")() self.__define_authform_utils() self.setup() - #def __define_extra_fields(self): + # def __define_extra_fields(self): # self.auth.settings.extra_fields['auth_user'] = self.fields def __hide_all(self): - alwaysvisible = ['first_name', 'last_name', 'password', 'email'] + alwaysvisible = ["first_name", "last_name", "password", "email"] for field in self.table.fields: if field not in alwaysvisible: - self.table[field].writable = self.table[field].readable = \ - False + self.table[field].writable = self.table[field].readable = False def __base_visibility(self): rv = {} @@ -93,8 +90,9 @@ def __base_visibility(self): def __define_authform_utils(self): settings = { - 'form_registration_rw': {'writable': [], 'readable': []}, - 'form_profile_rw': {'writable': [], 'readable': []}} + "form_registration_rw": {"writable": [], "readable": []}, + "form_profile_rw": {"writable": [], "readable": []}, + } for config_dict in settings.keys(): rw_data = self.__base_visibility() rw_data.update(**self.fields_rw) @@ -105,17 +103,18 @@ def __define_authform_utils(self): else: readable = writable = value if readable: - settings[config_dict]['readable'].append(key) + settings[config_dict]["readable"].append(key) if writable: - settings[config_dict]['writable'].append(key) - setattr(self, '_merged_form_rw_', { - 'registration': settings['form_registration_rw'], - 'profile': settings['form_profile_rw']}) + settings[config_dict]["writable"].append(key) + self._merged_form_rw_ = { + "registration": settings["form_registration_rw"], + "profile": settings["form_profile_rw"], + } class AuthUserBasic(AuthModel, TimestampedModel): tablename = "auth_users" - format = '%(email)s (%(id)s)' + format = "%(email)s (%(id)s)" #: injected by Auth # has_many( # {'auth_memberships': 'AuthMembership'}, @@ -126,55 +125,48 @@ class AuthUserBasic(AuthModel, TimestampedModel): email = Field(length=255, unique=True) password = Field.password(length=512) - registration_key = Field(length=512, rw=False, default='') - reset_password_key = Field(length=512, rw=False, default='') - registration_id = Field(length=512, rw=False, default='') + registration_key = Field(length=512, rw=False, default="") + reset_password_key = Field(length=512, rw=False, default="") + registration_id = Field(length=512, rw=False, default="") - form_labels = { - 'email': 'E-mail', - 'password': 'Password' - } + form_labels = {"email": "E-mail", "password": "Password"} - form_profile_rw = { - 'email': (True, False), - 'password': False - } + form_profile_rw = {"email": (True, False), "password": False} @before_insert def set_registration_key(self, fields): - if self.auth.config.registration_verification and not \ - fields.get('registration_key'): - fields['registration_key'] = uuid() + if self.auth.config.registration_verification and not fields.get("registration_key"): + fields["registration_key"] = uuid() elif self.auth.config.registration_approval: - fields['registration_key'] = 'pending' + fields["registration_key"] = "pending" - @rowmethod('disable') + @rowmethod("disable") def _set_disabled(self, row): - return row.update_record(registration_key='disabled') + return row.update_record(registration_key="disabled") - @rowmethod('block') + @rowmethod("block") def _set_blocked(self, row): - return row.update_record(registration_key='blocked') + return row.update_record(registration_key="blocked") - @rowmethod('allow') + @rowmethod("allow") def _set_allowed(self, row): - return row.update_record(registration_key='') + return row.update_record(registration_key="") class AuthUser(AuthUserBasic): - format = '%(first_name)s %(last_name)s (%(id)s)' + format = "%(first_name)s %(last_name)s (%(id)s)" first_name = Field(length=128, notnull=True) last_name = Field(length=128, notnull=True) form_labels = { - 'first_name': 'First name', - 'last_name': 'Last name', + "first_name": "First name", + "last_name": "Last name", } class AuthGroup(TimestampedModel): - format = '%(role)s (%(id)s)' + format = "%(role)s (%(id)s)" #: injected by Auth # has_many( # {'auth_memberships': 'AuthMembership'}, @@ -182,13 +174,10 @@ class AuthGroup(TimestampedModel): # {'users': {'via': 'memberships'}} # ) - role = Field(length=255, default='', unique=True) + role = Field(length=255, default="", unique=True) description = Field.text() - form_labels = { - 'role': 'Role', - 'description': 'Description' - } + form_labels = {"role": "Role", "description": "Description"} class AuthMembership(TimestampedModel): @@ -201,19 +190,13 @@ class AuthPermission(TimestampedModel): #: injected by Auth # belongs_to({'auth_group': 'AuthGroup'}) - name = Field(length=512, default='default', notnull=True) + name = Field(length=512, default="default", notnull=True) table_name = Field(length=512) record_id = Field.int(default=0) - validation = { - 'record_id': {'in': {'range': (0, 10**9)}} - } + validation = {"record_id": {"in": {"range": (0, 10**9)}}} - form_labels = { - 'name': 'Name', - 'table_name': 'Object or table name', - 'record_id': 'Record ID' - } + form_labels = {"name": "Name", "table_name": "Object or table name", "record_id": "Record ID"} class AuthEvent(TimestampedModel): @@ -225,15 +208,10 @@ class AuthEvent(TimestampedModel): description = Field.text(notnull=True) default_values = { - 'client_ip': lambda: - request.client if hasattr(current, 'request') else 'unavailable', - 'origin': 'auth', - 'description': '' + "client_ip": lambda: request.client if hasattr(current, "request") else "unavailable", + "origin": "auth", + "description": "", } #: labels injected by Auth - form_labels = { - 'client_ip': 'Client IP', - 'origin': 'Origin', - 'description': 'Description' - } + form_labels = {"client_ip": "Client IP", "origin": "Origin", "description": "Description"} diff --git a/emmett/tools/decorators.py b/emmett/tools/decorators.py index 5303bd33..a59ae77e 100644 --- a/emmett/tools/decorators.py +++ b/emmett/tools/decorators.py @@ -1,57 +1,35 @@ # -*- coding: utf-8 -*- """ - emmett.tools.decorators - ----------------------- +emmett.tools.decorators +----------------------- - Provides requires and service decorators. +Provides requires and service decorators. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from ..routing.router import Router +from emmett_core.pipeline.dyn import ( + ServicePipeBuilder as _ServicePipeBuilder, + requires as _requires, + service as _service, +) +from ..pipeline import RequirePipe +from .service import JSONServicePipe, XMLServicePipe -class Decorator(object): - def build_pipe(self): - pass - def __call__(self, func): - obj = Router.exposing() - obj.pipeline.append(self.build_pipe()) - return func +class ServicePipeBuilder(_ServicePipeBuilder): + _pipe_cls = {"json": JSONServicePipe, "xml": XMLServicePipe} -class requires(Decorator): - def __init__(self, condition=None, otherwise=None): - if condition is None or otherwise is None: - raise SyntaxError( - 'requires usage: @requires(condition, otherwise)' - ) - if not callable(otherwise) and not isinstance(otherwise, str): - raise SyntaxError( - "requires 'otherwise' param must be string or callable" - ) - self.condition = condition - self.otherwise = otherwise +class requires(_requires): + _pipe_cls = RequirePipe - def build_pipe(self): - from ..pipeline import RequirePipe - return RequirePipe(self.condition, self.otherwise) - -class service(Decorator): - def __init__(self, procedure): - self.procedure = procedure - - @staticmethod - def json(f): - return service('json')(f) +class service(_service): + _inner_builder = ServicePipeBuilder() @staticmethod def xml(f): - return service('xml')(f) - - def build_pipe(self): - from .service import ServicePipe - return ServicePipe(self.procedure) + return service("xml")(f) diff --git a/emmett/tools/mailer.py b/emmett/tools/mailer.py index 6aaa45e4..3c478fbc 100644 --- a/emmett/tools/mailer.py +++ b/emmett/tools/mailer.py @@ -1,73 +1,77 @@ # -*- coding: utf-8 -*- """ - emmett.tools.mailer - ------------------- +emmett.tools.mailer +------------------- - Provides mail facilities for Emmett. +Provides mail facilities for Emmett. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the code of flask-mail - :copyright: (c) 2010 by Dan Jacob. +Based on the code of flask-mail +:copyright: (c) 2010 by Dan Jacob. - :license: BSD-3-Clause +:license: BSD-3-Clause """ import smtplib import time - from contextlib import contextmanager from email import charset as _charsetreg, policy from email.encoders import encode_base64 +from email.header import Header from email.mime.base import MIMEBase from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from email.header import Header -from email.utils import formatdate, formataddr, make_msgid, parseaddr +from email.utils import formataddr, formatdate, make_msgid, parseaddr from functools import wraps +from emmett_core.utils import cachedprop + from ..extensions import Extension from ..libs.contenttype import contenttype -from ..utils import cachedprop -_charsetreg.add_charset('utf-8', _charsetreg.SHORTEST, None, 'utf-8') + +_charsetreg.add_charset("utf-8", _charsetreg.SHORTEST, None, "utf-8") message_policy = policy.SMTP def _has_newline(line): - if line and ('\r' in line or '\n' in line): + if line and ("\r" in line or "\n" in line): return True return False -def sanitize_subject(subject, encoding='utf-8'): +def sanitize_subject(subject, encoding="utf-8"): try: - subject.encode('ascii') + subject.encode("ascii") except UnicodeEncodeError: subject = Header(subject, encoding).encode() return subject -def sanitize_address(address, encoding='utf-8'): +def sanitize_address(address, encoding="utf-8"): if isinstance(address, str): - address = parseaddr(address) + try: + address = parseaddr(address, strict=False) + except Exception: + address = parseaddr(address) name, address = address name = Header(name, encoding).encode() try: - address.encode('ascii') + address.encode("ascii") except UnicodeEncodeError: - if '@' in address: - localpart, domain = address.split('@', 1) + if "@" in address: + localpart, domain = address.split("@", 1) localpart = str(Header(localpart, encoding)) - domain = domain.encode('idna').decode('ascii') - address = '@'.join([localpart, domain]) + domain = domain.encode("idna").decode("ascii") + address = "@".join([localpart, domain]) else: address = Header(address, encoding).encode() return formataddr((name, address)) -def sanitize_addresses(addresses, encoding='utf-8'): - return map(lambda address: sanitize_address(address, encoding), addresses) +def sanitize_addresses(addresses, encoding="utf-8"): + return map(lambda address: sanitize_address(address, encoding), addresses) # noqa: C417 class MailServer(object): @@ -98,32 +102,43 @@ def send(self, message): self.host.sendmail( sanitize_address(message.sender), list(sanitize_addresses(message.all_recipients)), - str(message).encode('utf8'), + str(message).encode("utf8"), message.mail_options, - message.rcpt_options) + message.rcpt_options, + ) return True class Attachment(object): - def __init__( - self, filename=None, data=None, content_type=None, disposition=None, - headers=None - ): + def __init__(self, filename=None, data=None, content_type=None, disposition=None, headers=None): if not content_type and filename: content_type = contenttype(filename).split(";")[0] self.filename = filename self.content_type = content_type or contenttype(filename).split(";")[0] self.data = data - self.disposition = disposition or 'attachment' + self.disposition = disposition or "attachment" self.headers = headers or {} class Mail(object): def __init__( - self, ext, subject='', recipients=None, body=None, html=None, - alts=None, sender=None, cc=None, bcc=None, attachments=None, - reply_to=None, date=None, charset='utf-8', extra_headers=None, - mail_options=None, rcpt_options=None + self, + ext, + subject="", + recipients=None, + body=None, + html=None, + alts=None, + sender=None, + cc=None, + bcc=None, + attachments=None, + reply_to=None, + date=None, + charset="utf-8", + extra_headers=None, + mail_options=None, + rcpt_options=None, ): sender = sender or ext.config.sender if isinstance(sender, tuple): @@ -145,7 +160,7 @@ def __init__( self.mail_options = mail_options or [] self.rcpt_options = rcpt_options or [] self.attachments = attachments or [] - for attr in ['recipients', 'cc', 'bcc']: + for attr in ["recipients", "cc", "bcc"]: if not isinstance(getattr(self, attr), list): setattr(self, attr, [getattr(self, attr)]) @@ -155,14 +170,14 @@ def all_recipients(self): @property def html(self): - return self.alts.get('html') + return self.alts.get("html") @html.setter def html(self, value): if value is None: - self.alts.pop('html', None) + self.alts.pop("html", None) else: - self.alts['html'] = value + self.alts["html"] = value def has_bad_headers(self): headers = [self.sender, self.reply_to] + self.recipients @@ -171,10 +186,10 @@ def has_bad_headers(self): return True if self.subject: if _has_newline(self.subject): - for linenum, line in enumerate(self.subject.split('\r\n')): + for linenum, line in enumerate(self.subject.split("\r\n")): if not line: return True - if linenum > 0 and line[0] not in '\t ': + if linenum > 0 and line[0] not in "\t ": return True if _has_newline(line): return True @@ -182,7 +197,7 @@ def has_bad_headers(self): return True return False - def _mimetext(self, text, subtype='plain'): + def _mimetext(self, text, subtype="plain"): return MIMEText(text, _subtype=subtype, _charset=self.charset) @cachedprop @@ -198,38 +213,34 @@ def message(self): else: # Anything else msg = MIMEMultipart() - alternative = MIMEMultipart('alternative') - alternative.attach(self._mimetext(self.body, 'plain')) + alternative = MIMEMultipart("alternative") + alternative.attach(self._mimetext(self.body, "plain")) for mimetype, content in self.alts.items(): alternative.attach(self._mimetext(content, mimetype)) msg.attach(alternative) if self.subject: - msg['Subject'] = sanitize_subject(self.subject, self.charset) - msg['From'] = sanitize_address(self.sender, self.charset) - msg['To'] = ', '.join( - list(set(sanitize_addresses(self.recipients, self.charset)))) - msg['Date'] = formatdate(self.date, localtime=True) - msg['Message-ID'] = self.msgId + msg["Subject"] = sanitize_subject(self.subject, self.charset) + msg["From"] = sanitize_address(self.sender, self.charset) + msg["To"] = ", ".join(list(set(sanitize_addresses(self.recipients, self.charset)))) + msg["Date"] = formatdate(self.date, localtime=True) + msg["Message-ID"] = self.msgId if self.cc: - msg['Cc'] = ', '.join( - list(set(sanitize_addresses(self.cc, self.charset)))) + msg["Cc"] = ", ".join(list(set(sanitize_addresses(self.cc, self.charset)))) if self.reply_to: - msg['Reply-To'] = sanitize_address(self.reply_to, self.charset) + msg["Reply-To"] = sanitize_address(self.reply_to, self.charset) if self.extra_headers: for k, v in self.extra_headers.items(): msg[k] = v for attachment in attachments: - f = MIMEBase(*attachment.content_type.split('/')) + f = MIMEBase(*attachment.content_type.split("/")) f.set_payload(attachment.data) encode_base64(f) filename = attachment.filename try: - filename and filename.encode('ascii') + filename and filename.encode("ascii") except UnicodeEncodeError: - filename = ('UTF8', '', filename) - f.add_header( - 'Content-Disposition', attachment.disposition, - filename=filename) + filename = ("UTF8", "", filename) + f.add_header("Content-Disposition", attachment.disposition, filename=filename) for key, value in attachment.headers.items(): f.add_header(key, value) msg.attach(f) @@ -246,27 +257,23 @@ def send(self): def add_recipient(self, recipient): self.recipients.append(recipient) - def attach( - self, filename=None, data=None, content_type=None, disposition=None, - headers=None - ): - self.attachments.append( - Attachment(filename, data, content_type, disposition, headers)) + def attach(self, filename=None, data=None, content_type=None, disposition=None, headers=None): + self.attachments.append(Attachment(filename, data, content_type, disposition, headers)) class MailExtension(Extension): - namespace = 'mailer' + namespace = "mailer" default_config = { - 'server': '127.0.0.1', - 'username': None, - 'password': None, - 'port': 25, - 'use_tls': False, - 'use_ssl': False, - 'sender': None, - 'debug': False, - 'suppress': False + "server": "127.0.0.1", + "username": None, + "password": None, + "port": 25, + "use_tls": False, + "use_ssl": False, + "sender": None, + "debug": False, + "suppress": False, } def on_load(self): @@ -345,4 +352,5 @@ def _wrap_dispatcher(dispatcher): @wraps(dispatcher) def wrapped(ext, *args, **kwargs): return dispatcher(*args, **kwargs) + return wrapped diff --git a/emmett/tools/service.py b/emmett/tools/service.py index de9de955..16090efe 100644 --- a/emmett/tools/service.py +++ b/emmett/tools/service.py @@ -1,48 +1,35 @@ # -*- coding: utf-8 -*- """ - emmett.tools.service - -------------------- +emmett.tools.service +-------------------- - Provides the services handler. +Provides the services handler. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ +from emmett_core.pipeline.extras import JSONPipe + from ..ctx import current -from ..parsers import Parsers from ..pipeline import Pipe -from ..serializers import Serializers, _json_type - - -class JSONServicePipe(Pipe): - __slots__ = ['decoder', 'encoder'] - output = _json_type - - def __init__(self): - self.decoder = Parsers.get_for('json') - self.encoder = Serializers.get_for('json') +from ..serializers import Serializers - async def pipe_request(self, next_pipe, **kwargs): - current.response.headers._data['content-type'] = 'application/json' - return self.encoder(await next_pipe(**kwargs)) - - def on_receive(self, data): - return self.decoder(data) - def on_send(self, data): - return self.encoder(data) +class JSONServicePipe(JSONPipe): + __slots__ = [] + _current = current class XMLServicePipe(Pipe): - __slots__ = ['encoder'] - output = 'str' + __slots__ = ["encoder"] + output = "str" def __init__(self): - self.encoder = Serializers.get_for('xml') + self.encoder = Serializers.get_for("xml") async def pipe_request(self, next_pipe, **kwargs): - current.response.headers._data['content-type'] = 'text/xml' + current.response.headers._data["content-type"] = "text/xml" return self.encoder(await next_pipe(**kwargs)) def on_send(self, data): @@ -50,13 +37,7 @@ def on_send(self, data): def ServicePipe(procedure: str) -> Pipe: - pipe_cls = { - 'json': JSONServicePipe, - 'xml': XMLServicePipe - }.get(procedure) + pipe_cls = {"json": JSONServicePipe, "xml": XMLServicePipe}.get(procedure) if not pipe_cls: - raise RuntimeError( - 'Emmett cannot handle the service you requested: %s' % - procedure - ) + raise RuntimeError("Emmett cannot handle the service you requested: %s" % procedure) return pipe_cls() diff --git a/emmett/typing.py b/emmett/typing.py deleted file mode 100644 index 53803da5..00000000 --- a/emmett/typing.py +++ /dev/null @@ -1,18 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.typing - ------------- - - Provides typing helpers. - - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause -""" - -from typing import Awaitable, Callable, TypeVar - -T = TypeVar("T") -KT = TypeVar("KT") -VT = TypeVar("VT") - -ErrorHandlerType = TypeVar("ErrorHandlerType", bound=Callable[[], Awaitable[str]]) diff --git a/emmett/utils.py b/emmett/utils.py index 0ae4bde3..739d9d7a 100644 --- a/emmett/utils.py +++ b/emmett/utils.py @@ -1,159 +1,66 @@ # -*- coding: utf-8 -*- """ - emmett.utils - ------------ +emmett.utils +------------ - Provides some utilities for Emmett. +Provides some utilities for Emmett. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ from __future__ import annotations -import asyncio import re import socket - -from datetime import datetime, date, time -from typing import Any, Callable, Generic, Optional, Union, overload +from datetime import date, datetime, time import pendulum - +from emmett_core.utils import cachedprop as cachedprop from pendulum.parsing import _parse as _pendulum_parse from .datastructures import sdict -from .typing import T - - -class _cached_prop(Generic[T]): - def __init__( - self, - fget: Callable[..., T], - name: str, - doc: Optional[str] = None - ): - self.fget = fget - self.__doc__ = doc - self.__name__ = name - - def __get__(self, obj: Optional[object], cls: Any) -> T: - raise NotImplementedError - - -def cachedprop( - fget: Callable[..., T], - doc: Optional[str] = None, - name: Optional[str] = None -) -> _cached_prop[T]: - doc = doc or fget.__doc__ - name = name or fget.__name__ - if asyncio.iscoroutinefunction(fget): - return _cached_prop_loop[T](fget, name, doc) - return _cached_prop_sync[T](fget, name, doc) - - - -class _cached_prop_sync(_cached_prop[T]): - @overload - def __get__(self, obj: None, cls: Any) -> _cached_prop_sync: - ... - - @overload - def __get__(self, obj: object, cls: Any) -> T: - ... - - def __get__(self, obj: Optional[object], cls: Any) -> Union[_cached_prop_sync, T]: - if obj is None: - return self - obj.__dict__[self.__name__] = rv = self.fget(obj) - return rv - - -class _cached_awaitable_coro(Generic[T]): - slots = ['coro_f', 'obj', '_result', '_awaitable'] - - def __init__(self, coro_f: Callable[..., T], obj: object): - self.coro_f = coro_f - self.obj = obj - self._awaitable = self.__fetcher - - async def __fetcher(self) -> T: - self._result = rv = await self.coro_f(self.obj) # type: ignore - self._awaitable = self.__cached - return rv - - async def __cached(self) -> T: - return self._result - - def __await__(self): - return self._awaitable().__await__() - - -class _cached_prop_loop(_cached_prop[T]): - @overload - def __get__(self, obj: None, cls: Any) -> _cached_prop_loop: - ... - - @overload - def __get__(self, obj: object, cls: Any) -> T: - ... - - def __get__(self, obj: Optional[object], cls: Any) -> Union[_cached_prop_loop, T]: - if obj is None: - return self - obj.__dict__[self.__name__] = rv = _cached_awaitable_coro[T]( - self.fget, obj - ) - return rv # type: ignore -_pendulum_parsing_opts = { - 'day_first': False, - 'year_first': True, - 'strict': True, - 'exact': False, - 'now': None -} +_pendulum_parsing_opts = {"day_first": False, "year_first": True, "strict": True, "exact": False, "now": None} def _pendulum_normalize(obj): if isinstance(obj, time): now = datetime.utcnow() - obj = datetime( - now.year, now.month, now.day, - obj.hour, obj.minute, obj.second, obj.microsecond - ) + obj = datetime(now.year, now.month, now.day, obj.hour, obj.minute, obj.second, obj.microsecond) elif isinstance(obj, date) and not isinstance(obj, datetime): obj = datetime(obj.year, obj.month, obj.day) return obj def parse_datetime(text): - parsed = _pendulum_normalize( - _pendulum_parse(text, **_pendulum_parsing_opts)) + parsed = _pendulum_normalize(_pendulum_parse(text, **_pendulum_parsing_opts)) return pendulum.datetime( - parsed.year, parsed.month, parsed.day, - parsed.hour, parsed.minute, parsed.second, parsed.microsecond, - tz=parsed.tzinfo or pendulum.UTC + parsed.year, + parsed.month, + parsed.day, + parsed.hour, + parsed.minute, + parsed.second, + parsed.microsecond, + tz=parsed.tzinfo or pendulum.UTC, ) -_re_ipv4 = re.compile(r'(\d+)\.(\d+)\.(\d+)\.(\d+)') +_re_ipv4 = re.compile(r"(\d+)\.(\d+)\.(\d+)\.(\d+)") def is_valid_ip_address(address): # deal with special cases - if address.lower() in [ - '127.0.0.1', 'localhost', '::1', '::ffff:127.0.0.1' - ]: + if address.lower() in ["127.0.0.1", "localhost", "::1", "::ffff:127.0.0.1"]: return True - elif address.lower() in ('unknown', ''): + elif address.lower() in ("unknown", ""): return False - elif address.count('.') == 3: # assume IPv4 - if address.startswith('::ffff:'): + elif address.count(".") == 3: # assume IPv4 + if address.startswith("::ffff:"): address = address[7:] - if hasattr(socket, 'inet_aton'): # try validate using the OS + if hasattr(socket, "inet_aton"): # try validate using the OS try: socket.inet_aton(address) return True @@ -161,12 +68,10 @@ def is_valid_ip_address(address): return False else: # try validate using Regex match = _re_ipv4.match(address) - if match and all( - 0 <= int(match.group(i)) < 256 for i in (1, 2, 3, 4) - ): + if match and all(0 <= int(match.group(i)) < 256 for i in (1, 2, 3, 4)): return True return False - elif hasattr(socket, 'inet_pton'): # assume IPv6, try using the OS + elif hasattr(socket, "inet_pton"): # assume IPv6, try using the OS try: socket.inet_pton(socket.AF_INET6, address) return True @@ -176,7 +81,7 @@ def is_valid_ip_address(address): return True -def read_file(filename, mode='r'): +def read_file(filename, mode="r"): # returns content from filename, making sure to close the file on exit. f = open(filename, mode) try: @@ -185,7 +90,7 @@ def read_file(filename, mode='r'): f.close() -def write_file(filename, value, mode='w'): +def write_file(filename, value, mode="w"): # writes to filename, making sure to close the file on exit. f = open(filename, mode) try: diff --git a/emmett/validators/__init__.py b/emmett/validators/__init__.py index 2e6eca14..95574ae4 100644 --- a/emmett/validators/__init__.py +++ b/emmett/validators/__init__.py @@ -1,26 +1,15 @@ # -*- coding: utf-8 -*- """ - emmett.validators - ----------------- +emmett.validators +----------------- - Implements validators for pyDAL. +Implements validators for pyDAL. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from .basic import ( - Allow, - Any, - Equals, - hasLength, - isEmpty, - isEmptyOr, - isntEmpty, - Matches, - Not, - Validator -) +from .basic import Allow, Any, Equals, Matches, Not, Validator, hasLength, isEmpty, isEmptyOr, isntEmpty from .consist import ( isAlphanumeric, isDate, @@ -34,10 +23,10 @@ isJSON, isList, isTime, - isUrl + isUrl, ) -from .inside import inRange, inSet, inDB, notInDB -from .process import Cleanup, Crypt, Lower, Urlify, Upper +from .inside import inDB, inRange, inSet, notInDB +from .process import Cleanup, Crypt, Lower, Upper, Urlify class ValidateFromDict(object): @@ -56,15 +45,9 @@ def __init__(self): "ip": isIP, "json": isJSON, "time": isTime, - "url": isUrl - } - self.proc_validators = { - "clean": Cleanup, - "crypt": Crypt, - "lower": Lower, - "upper": Upper, - "urlify": Urlify + "url": isUrl, } + self.proc_validators = {"clean": Cleanup, "crypt": Crypt, "lower": Lower, "upper": Upper, "urlify": Urlify} def parse_num_comparisons(self, data, minv=None, maxv=None): inclusions = [True, False] @@ -96,10 +79,10 @@ def parse_is(self, data, message=None): def parse_is_list(self, data, message=None): #: map types with 'list' fields - key = '' + key = "" options = {} suboptions = {} - lopts = ['splitter'] + lopts = ["splitter"] if isinstance(data, str): #: map {'is': 'list:int'} key = data @@ -114,26 +97,22 @@ def parse_is_list(self, data, message=None): else: raise SyntaxError("'is' validator accepts only string or dict") try: - keyspecs = key.split(':') + keyspecs = key.split(":") subkey = keyspecs[1].strip() - assert keyspecs[0].strip() == 'list' + assert keyspecs[0].strip() == "list" except Exception: - subkey = '_missing_' + subkey = "_missing_" validator = self.is_validators.get(subkey) - return isList( - [validator(message=message, **suboptions)], - message=message, - **options - ) if validator else None + return isList([validator(message=message, **suboptions)], message=message, **options) if validator else None def parse_reference(self, field): ref_table, ref_field, multiple = None, None, None - if field.type.startswith('reference'): + if field.type.startswith("reference"): multiple = False - elif field.type.startswith('list:reference'): + elif field.type.startswith("list:reference"): multiple = True if multiple is not None: - ref_table = field.type.split(' ')[1] + ref_table = field.type.split(" ")[1] model = field.table._model_ #: can't support (yet?) multi pks if model._belongs_ref_[field.name].compound: @@ -146,40 +125,36 @@ def __call__(self, field, data): validators = [] message = data.pop("message", None) #: parse 'presence' and 'empty' - presence = data.get('presence') - empty = data.get('empty') + presence = data.get("presence") + empty = data.get("empty") if presence is None and empty is not None: presence = not empty #: parse 'is' - _is = data.get('is') + _is = data.get("is") if _is is not None: validator = self.parse_is(_is, message) or self.parse_is_list(_is, message) if validator is None: - raise SyntaxError( - "Unknown type %s for 'is' validator" % data - ) + raise SyntaxError("Unknown type %s for 'is' validator" % data) validators.append(validator) #: parse 'len' - _len = data.get('len') + _len = data.get("len") if _len is not None: if isinstance(_len, int): #: allows {'len': 2} - validators.append( - hasLength(_len + 1, _len, message='Enter {min} characters') - ) + validators.append(hasLength(_len + 1, _len, message="Enter {min} characters")) else: #: allows # {'len': {'gt': 1, 'gte': 2, 'lt': 5, 'lte' 6}} # {'len': {'range': (2, 6)}} - if _len.get('range') is not None: - minv, maxv = _len['range'] + if _len.get("range") is not None: + minv, maxv = _len["range"] inc = (True, False) else: minv, maxv, inc = self.parse_num_comparisons(_len, 0, 256) validators.append(hasLength(maxv, minv, inc, message=message)) #: parse 'in' _dbset = None - _in = data.get('in', []) + _in = data.get("in", []) if _in: if isinstance(_in, (list, tuple, set)): #: allows {'in': [1, 2]} @@ -187,22 +162,22 @@ def __call__(self, field, data): elif isinstance(_in, dict): options = {} #: allows {'in': {'range': (1, 5)}} - _range = _in.get('range') + _range = _in.get("range") if isinstance(_range, (tuple, list)): validators.append(inRange(_range[0], _range[1], message=message)) #: allows {'in': {'set': [1, 5]}} with options - _set = _in.get('set') + _set = _in.get("set") if isinstance(_set, (list, tuple, set)): - opt_keys = [key for key in list(_in) if key != 'set'] + opt_keys = [key for key in list(_in) if key != "set"] for key in opt_keys: options[key] = _in[key] validators.append(inSet(_set, message=message, **options)) #: allows {'in': {'dbset': lambda db: db.where(query)}} - _dbset = _in.get('dbset') + _dbset = _in.get("dbset") if callable(_dbset): ref_table, ref_field, multiple = self.parse_reference(field) if ref_table: - opt_keys = [key for key in list(_in) if key != 'dbset'] + opt_keys = [key for key in list(_in) if key != "dbset"] for key in opt_keys: options[key] = _in[key] validators.append( @@ -213,108 +188,86 @@ def __call__(self, field, data): dbset=_dbset, multiple=multiple, message=message, - **options + **options, ) ) else: - raise SyntaxError( - "'in:dbset' validator needs a reference field" - ) + raise SyntaxError("'in:dbset' validator needs a reference field") else: - raise SyntaxError( - "'in' validator accepts only a set or a dict" - ) + raise SyntaxError("'in' validator accepts only a set or a dict") #: parse 'gt', 'gte', 'lt', 'lte' minv, maxv, inc = self.parse_num_comparisons(data) if minv is not None or maxv is not None: validators.append(inRange(minv, maxv, inc, message=message)) #: parse 'equals' - if 'equals' in data: - validators.append(Equals(data['equals'], message=message)) + if "equals" in data: + validators.append(Equals(data["equals"], message=message)) #: parse 'match' - if 'match' in data: - if isinstance(data['match'], dict): - validators.append(Matches(**data['match'], message=message)) + if "match" in data: + if isinstance(data["match"], dict): + validators.append(Matches(**data["match"], message=message)) else: - validators.append(Matches(data['match'], message=message)) + validators.append(Matches(data["match"], message=message)) #: parse transforming validators for key, vclass in self.proc_validators.items(): if key in data: options = {} if isinstance(data[key], dict): options = data[key] - elif data[key] != True: - if key == 'crypt' and isinstance(data[key], str): - options = {'algorithm': data[key]} + elif data[key] is not True: + if key == "crypt" and isinstance(data[key], str): + options = {"algorithm": data[key]} else: - raise SyntaxError( - key + " validator accepts only dict or True" - ) + raise SyntaxError(key + " validator accepts only dict or True") validators.append(vclass(message=message, **options)) #: parse 'unique' - _unique = data.get('unique', False) + _unique = data.get("unique", False) if _unique: _udbset = None if isinstance(_unique, dict): - whr = _unique.get('where', None) + whr = _unique.get("where", None) if callable(whr): _dbset = whr - validators.append( - notInDB( - field.db, - field.table, - field.name, - dbset=_udbset, - message=message - ) - ) + validators.append(notInDB(field.db, field.table, field.name, dbset=_udbset, message=message)) table = field.db[field._tablename] table._unique_fields_validation_[field.name] = 1 #: apply 'format' option - if 'format' in data: + if "format" in data: for validator in validators: children = [validator] - if hasattr(validator, 'children'): + if hasattr(validator, "children"): children += validator.children for child in children: - if hasattr(child, 'format'): - child.format = data['format'] + if hasattr(child, "format"): + child.format = data["format"] break #: parse 'custom' - if 'custom' in data: - if isinstance(data['custom'], list): - for element in data['custom']: + if "custom" in data: + if isinstance(data["custom"], list): + for element in data["custom"]: validators.append(element) else: - validators.append(data['custom']) + validators.append(data["custom"]) #: parse 'any' - if 'any' in data: - validators.append(Any(self(field, data['any']), message=message)) + if "any" in data: + validators.append(Any(self(field, data["any"]), message=message)) #: parse 'not' - if 'not' in data: - validators.append(Not(self(field, data['not']), message=message)) + if "not" in data: + validators.append(Not(self(field, data["not"]), message=message)) #: insert presence/empty validation if needed if presence: ref_table, ref_field, multiple = self.parse_reference(field) if ref_table: if not _dbset: - validators.append( - inDB( - field.db, - ref_table, - ref_field, - multiple=multiple, - message=message - ) - ) + validators.append(inDB(field.db, ref_table, ref_field, multiple=multiple, message=message)) else: validators.insert(0, isntEmpty(message=message)) if empty: validators.insert(0, isEmpty(message=message)) #: parse 'allow' - if 'allow' in data: - if data['allow'] in ['empty', 'blank']: + if "allow" in data: + if data["allow"] in ["empty", "blank"]: validators = [isEmptyOr(validators, message=message)] else: - validators = [Allow(data['allow'], validators, message=message)] + validators = [Allow(data["allow"], validators, message=message)] return validators diff --git a/emmett/validators/basic.py b/emmett/validators/basic.py index d0665d42..2fbb58b9 100644 --- a/emmett/validators/basic.py +++ b/emmett/validators/basic.py @@ -1,27 +1,25 @@ # -*- coding: utf-8 -*- """ - emmett.validators.basic - ----------------------- +emmett.validators.basic +----------------------- - Provide basic validators. +Provide basic validators. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the web2py's validators (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro +Based on the web2py's validators (http://www.web2py.com) +:copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ import re - -from cgi import FieldStorage from functools import reduce from os import SEEK_END, SEEK_SET # TODO: check unicode conversions from .._shortcuts import to_unicode -from .helpers import translate, is_empty +from .helpers import is_empty, translate class Validator: @@ -45,11 +43,7 @@ def __init__(self, children, message=None): self.children = children def formatter(self, value): - return reduce( - lambda formatted_val, child: child.formatter(formatted_val), - self.children, - value - ) + return reduce(lambda formatted_val, child: child.formatter(formatted_val), self.children, value) def __call__(self, value): raise NotImplementedError @@ -60,12 +54,7 @@ class _is(Validator): rule = None def __call__(self, value): - if ( - self.rule is None or ( - self.rule is not None and - self.rule.match(to_unicode(value) or '') - ) - ): + if self.rule is None or (self.rule is not None and self.rule.match(to_unicode(value) or "")): return self.check(value) return value, translate(self.message) @@ -139,19 +128,19 @@ def __init__(self, children, empty_regex=None, message=None): super().__init__(children, message=message) self.empty_regex = re.compile(empty_regex) if empty_regex is not None else None for child in self.children: - if hasattr(child, 'multiple'): + if hasattr(child, "multiple"): self.multiple = child.multiple break for child in self.children: - if hasattr(child, 'options'): + if hasattr(child, "options"): self._options_ = child.options self.options = self._get_options_ break def _get_options_(self): options = self._options_() - if (not options or options[0][0] != '') and not self.multiple: - options.insert(0, ('', '')) + if (not options or options[0][0] != "") and not self.multiple: + options.insert(0, ("", "")) return options def __call__(self, value): @@ -182,21 +171,19 @@ def __call__(self, value): class Matches(Validator): message = "Invalid expression" - def __init__( - self, expression, strict=False, search=False, extract=False, message=None - ): + def __init__(self, expression, strict=False, search=False, extract=False, message=None): super().__init__(message=message) if strict or not search: - if not expression.startswith('^'): - expression = '^(%s)' % expression + if not expression.startswith("^"): + expression = "^(%s)" % expression if strict: - if not expression.endswith('$'): - expression = '(%s)$' % expression + if not expression.endswith("$"): + expression = "(%s)$" % expression self.regex = re.compile(expression) self.extract = extract def __call__(self, value): - match = self.regex.search(to_unicode(value) or '') + match = self.regex.search(to_unicode(value) or "") if match is not None: return self.extract and match.group() or value, None return value, translate(self.message) @@ -205,9 +192,7 @@ def __call__(self, value): class hasLength(Validator): message = "Enter from {min} to {max} characters" - def __init__( - self, maxsize=256, minsize=0, include=(True, False), message=None - ): + def __init__(self, maxsize=256, minsize=0, include=(True, False), message=None): super().__init__(message=message) self.maxsize = maxsize self.minsize = minsize @@ -229,24 +214,26 @@ def __call__(self, value): length = 0 if self._between(length): return value, None - elif getattr(value, '_emt_field_hashed_contents_', False): + elif getattr(value, "_emt_field_hashed_contents_", False): return value, None - elif isinstance(value, FieldStorage): + elif hasattr(value, "file"): if value.file: value.file.seek(0, SEEK_END) length = value.file.tell() value.file.seek(0, SEEK_SET) - elif hasattr(value, 'value'): - val = value.value - if val: - length = len(val) - else: - length = 0 + if self._between(length): + return value, None + elif hasattr(value, "value"): + val = value.value + if val: + length = len(val) + else: + length = 0 if self._between(length): return value, None elif isinstance(value, bytes): try: - lvalue = len(value.decode('utf8')) + lvalue = len(value.decode("utf8")) except Exception: lvalue = len(value) if self._between(lvalue): diff --git a/emmett/validators/consist.py b/emmett/validators/consist.py index 6b278857..db13c8b2 100644 --- a/emmett/validators/consist.py +++ b/emmett/validators/consist.py @@ -1,30 +1,31 @@ # -*- coding: utf-8 -*- """ - emmett.validators.consist - ------------------------- +emmett.validators.consist +------------------------- - Validators that check the value is of a certain type. +Validators that check the value is of a certain type. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the web2py's validators (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro +Based on the web2py's validators (http://www.web2py.com) +:copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ import decimal import re import struct - -from datetime import date, time, datetime, timedelta +from datetime import date, datetime, time, timedelta from time import strptime from urllib.parse import unquote as url_unquote +from emmett_core.utils import cachedprop + from ..parsers import Parsers from ..serializers import Serializers -from ..utils import cachedprop, parse_datetime -from .basic import Validator, ParentValidator, _is, Matches +from ..utils import parse_datetime +from .basic import Matches, ParentValidator, Validator, _is from .helpers import ( _DEFAULT, _UTC, @@ -32,9 +33,10 @@ official_url_schemes, translate, unofficial_url_schemes, - url_split_regex + url_split_regex, ) + try: import ipaddress except ImportError: @@ -57,7 +59,7 @@ def __init__(self, dot=".", message=None): def check(self, value): try: - v = float(str(value).replace(self.dot, '.')) + v = float(str(value).replace(self.dot, ".")) return v, None except (ValueError, TypeError): pass @@ -67,11 +69,11 @@ def formatter(self, value): if value is None: return None val = str(value) - if '.' not in val: - val += '.00' + if "." not in val: + val += ".00" else: - val += '0' * (2 - len(val.split('.')[1])) - return val.replace('.', self.dot) + val += "0" * (2 - len(val.split(".")[1])) + return val.replace(".", self.dot) class isDecimal(isFloat): @@ -80,17 +82,14 @@ def check(self, value): if isinstance(value, decimal.Decimal): v = value else: - v = decimal.Decimal(str(value).replace(self.dot, '.')) + v = decimal.Decimal(str(value).replace(self.dot, ".")) return v, None except (ValueError, TypeError, decimal.InvalidOperation): return value, translate(self.message) class isTime(_is): - rule = re.compile( - r"((?P[0-9]+))([^0-9 ]+(?P[0-9 ]+))?" - r"([^0-9ap ]+(?P[0-9]*))?((?P[ap]m))?" - ) + rule = re.compile(r"((?P[0-9]+))([^0-9 ]+(?P[0-9 ]+))?" r"([^0-9ap ]+(?P[0-9]*))?((?P[ap]m))?") def __call__(self, value): return super().__call__(value.lower() if value else value) @@ -100,20 +99,17 @@ def check(self, value): return value, None val = self.rule.match(value) try: - (h, m, s) = (int(val.group('h')), 0, 0) - if not val.group('m') is None: - m = int(val.group('m')) - if not val.group('s') is None: - s = int(val.group('s')) - if val.group('d') == 'pm' and 0 < h < 12: + (h, m, s) = (int(val.group("h")), 0, 0) + if val.group("m") is not None: + m = int(val.group("m")) + if val.group("s") is not None: + s = int(val.group("s")) + if val.group("d") == "pm" and 0 < h < 12: h = h + 12 - if val.group('d') == 'am' and h == 12: + if val.group("d") == "am" and h == 12: h = 0 - if not (h in range(24) and m in range(60) and s - in range(60)): - raise ValueError( - 'Hours or minutes or seconds are outside of allowed range' - ) + if not (h in range(24) and m in range(60) and s in range(60)): + raise ValueError("Hours or minutes or seconds are outside of allowed range") val = time(h, m, s) return val, None except AttributeError: @@ -124,7 +120,7 @@ def check(self, value): class isDate(_is): - def __init__(self, format='%Y-%m-%d', timezone=None, message=None): + def __init__(self, format="%Y-%m-%d", timezone=None, message=None): super().__init__(message=message) self.format = translate(format) self.timezone = timezone @@ -165,13 +161,21 @@ def formatter(self, value): @staticmethod def nice(format): codes = ( - ('%Y', '1963'), ('%y', '63'), ('%d', '28'), ('%m', '08'), - ('%b', 'Aug'), ('%B', 'August'), ('%H', '14'), ('%I', '02'), - ('%p', 'PM'), ('%M', '30'), ('%S', '59') + ("%Y", "1963"), + ("%y", "63"), + ("%d", "28"), + ("%m", "08"), + ("%b", "Aug"), + ("%B", "August"), + ("%H", "14"), + ("%I", "02"), + ("%p", "PM"), + ("%M", "30"), + ("%S", "59"), ) - for (a, b) in codes: + for a, b in codes: format = format.replace(a, b) - return dict(format=format) + return {"format": format} class isDatetime(isDate): @@ -181,7 +185,7 @@ def __init__(self, format=_DEFAULT, **kwargs): def _get_parser(self, format): if format is _DEFAULT: - return self._parse_pendulum, '%Y-%m-%dT%H:%M:%S' + return self._parse_pendulum, "%Y-%m-%dT%H:%M:%S" return self._parse_strptime, format def _parse_strptime(self, value): @@ -189,27 +193,20 @@ def _parse_strptime(self, value): return datetime(y, m, d, hh, mm, ss) def _parse_pendulum(self, value): - return parse_datetime(value).in_timezone('UTC') + return parse_datetime(value).in_timezone("UTC") def _check_instance(self, value): return isinstance(value, datetime) def _formatter_obj(self, value): - return datetime( - value.year, - value.month, - value.day, - value.hour, - value.minute, - value.second - ) + return datetime(value.year, value.month, value.day, value.hour, value.minute, value.second) class isEmail(_is): rule = re.compile( r"^(?!\.)([-a-z0-9!\#$%&'*+/=?^_`{|}~]|(?= 0 - extension = value.filename[extension + 1:].lower() - if extension == 'jpg': - extension = 'jpeg' + extension = value.filename[extension + 1 :].lower() + if extension == "jpg": + extension = "jpeg" assert extension in self.extensions - if extension == 'bmp': + if extension == "bmp": width, height = self.__bmp(value.file) - elif extension == 'gif': + elif extension == "gif": width, height = self.__gif(value.file) - elif extension == 'jpeg': + elif extension == "jpeg": width, height = self.__jpeg(value.file) - elif extension == 'png': + elif extension == "png": width, height = self.__png(value.file) else: width = -1 height = -1 - assert self.minsize[0] <= width <= self.maxsize[0] \ - and self.minsize[1] <= height <= self.maxsize[1] + assert self.minsize[0] <= width <= self.maxsize[0] and self.minsize[1] <= height <= self.maxsize[1] value.file.seek(0) return value, None - except: + except Exception: return value, translate(self.message) def __bmp(self, stream): - if stream.read(2) == 'BM': + if stream.read(2) == "BM": stream.read(16) return struct.unpack("= 0xC0 and code <= 0xC3: - return tuple(reversed( - struct.unpack("!xHH", stream.read(5)))) + return tuple(reversed(struct.unpack("!xHH", stream.read(5)))) else: stream.read(length - 2) return -1, -1 def __png(self, stream): - if stream.read(8) == '\211PNG\r\n\032\n': + if stream.read(8) == "\211PNG\r\n\032\n": stream.read(4) if stream.read(4) == "IHDR": return struct.unpack("!LL", stream.read(8)) @@ -379,6 +371,7 @@ class _isGenericUrl(Validator): Based on RFC 2396: http://www.faqs.org/rfcs/rfc2396.html @author: Jonathan Benn """ + message = "Invalid URL" all_url_schemes = [None] + official_url_schemes + unofficial_url_schemes @@ -388,14 +381,11 @@ def __init__(self, schemes=None, prepend_scheme=None, message=None): self.prepend_scheme = prepend_scheme if self.prepend_scheme not in self.allowed_schemes: raise SyntaxError( - "prepend_scheme='{}' is not in allowed_schemes={}".format( - self.prepend_scheme, self.allowed_schemes - ) + "prepend_scheme='{}' is not in allowed_schemes={}".format(self.prepend_scheme, self.allowed_schemes) ) GENERIC_URL = re.compile( - r"%[^0-9A-Fa-f]{2}|%[^0-9A-Fa-f][0-9A-Fa-f]|%[0-9A-Fa-f][^0-9A-Fa-f]|" - r"%$|%[0-9A-Fa-f]$|%[^0-9A-Fa-f]$" + r"%[^0-9A-Fa-f]{2}|%[^0-9A-Fa-f][0-9A-Fa-f]|%[0-9A-Fa-f][^0-9A-Fa-f]|" r"%$|%[0-9A-Fa-f]$|%[^0-9A-Fa-f]$" ) GENERIC_URL_VALID = re.compile(r"[A-Za-z0-9;/?:@&=+$,\-_\.!~*'\(\)%#]+$") @@ -420,11 +410,9 @@ def __call__(self, value): # ports, check to see if adding a valid scheme fixes # the problem (but only do this if it doesn't have # one already!) - if value.find('://') < 0 and \ - None in self.allowed_schemes: - schemeToUse = self.prepend_scheme or 'http' - prependTest = self.__call__( - schemeToUse + '://' + value) + if value.find("://") < 0 and None in self.allowed_schemes: + schemeToUse = self.prepend_scheme or "http" + prependTest = self.__call__(schemeToUse + "://" + value) # if the prepend test succeeded if prependTest[1] is None: # if prepending in the output is enabled @@ -455,42 +443,30 @@ class _isHTTPUrl(Validator): """ message = "Invalid URL" - http_schemes = [None, 'http', 'https'] - GENERIC_VALID_IP = re.compile( - r"([\w.!~*'|;:&=+$,-]+@)?\d+\.\d+\.\d+\.\d+(:\d*)*$" - ) + http_schemes = [None, "http", "https"] + GENERIC_VALID_IP = re.compile(r"([\w.!~*'|;:&=+$,-]+@)?\d+\.\d+\.\d+\.\d+(:\d*)*$") GENERIC_VALID_DOMAIN = re.compile( r"([\w.!~*'|;:&=+$,-]+@)?(([A-Za-z0-9]+[A-Za-z0-9\-]*[A-Za-z0-9]+\.)" r"*([A-Za-z0-9]+\.)*)*([A-Za-z]+[A-Za-z0-9\-]*[A-Za-z0-9]+)\.?(:\d*)*$" ) - def __init__(self, schemes=None, prepend_scheme='http', tlds=None, message=None): + def __init__(self, schemes=None, prepend_scheme="http", tlds=None, message=None): super().__init__(message=message) self.allowed_schemes = schemes or self.http_schemes self.allowed_tlds = tlds or official_top_level_domains self.prepend_scheme = prepend_scheme for i in self.allowed_schemes: if i not in self.http_schemes: - raise SyntaxError( - "allowed_scheme value '{}' is not in {}".format( - i, self.http_schemes - ) - ) + raise SyntaxError("allowed_scheme value '{}' is not in {}".format(i, self.http_schemes)) if self.prepend_scheme not in self.allowed_schemes: raise SyntaxError( - "prepend_scheme='{}' is not in allowed_schemes={}".format( - self.prepend_scheme, self.allowed_schemes - ) + "prepend_scheme='{}' is not in allowed_schemes={}".format(self.prepend_scheme, self.allowed_schemes) ) def __call__(self, value): try: # if the URL passes generic validation - x = _isGenericUrl( - schemes=self.allowed_schemes, - prepend_scheme=self.prepend_scheme, - message=self.message - ) + x = _isGenericUrl(schemes=self.allowed_schemes, prepend_scheme=self.prepend_scheme, message=self.message) if x(value)[1] is None: componentsMatch = url_split_regex.match(value) authority = componentsMatch.group(4) @@ -502,12 +478,10 @@ def __call__(self, value): return value, None else: # else if authority is a valid domain name - domainMatch = self.GENERIC_VALID_DOMAIN.match( - authority) + domainMatch = self.GENERIC_VALID_DOMAIN.match(authority) if domainMatch: # if the top-level domain really exists - if domainMatch.group(5).lower()\ - in self.allowed_tlds: + if domainMatch.group(5).lower() in self.allowed_tlds: # Then this HTTP URL is valid return value, None else: @@ -516,15 +490,15 @@ def __call__(self, value): path = componentsMatch.group(5) # relative case: if this is a valid path (if it starts with # a slash) - if path.startswith('/'): + if path.startswith("/"): # Then this HTTP URL is valid return value, None else: # abbreviated case: if we haven't already, prepend a # scheme and see if it fixes the problem - if value.find('://') < 0: - schemeToUse = self.prepend_scheme or 'http' - prependTest = self(schemeToUse + '://' + value) + if value.find("://") < 0: + schemeToUse = self.prepend_scheme or "http" + prependTest = self(schemeToUse + "://" + value) # if the prepend test succeeded if prependTest[1] is None: # if prepending in the output is enabled @@ -534,7 +508,7 @@ def __call__(self, value): # else return the original, non-prepended # value return value, None - except: + except Exception: pass # else the HTTP URL is not valid return value, translate(self.message) @@ -544,39 +518,33 @@ class isUrl(Validator): #: use `_isGenericUrl` and `_isHTTPUrl` depending on `mode` parameter message = "Invalid URL" - def __init__( - self, mode='http', schemes=None, prepend_scheme='http', tlds=None, message=None - ): + def __init__(self, mode="http", schemes=None, prepend_scheme="http", tlds=None, message=None): super().__init__(message=message) self.mode = mode.lower() - if self.mode not in ['generic', 'http']: + if self.mode not in ["generic", "http"]: raise SyntaxError("invalid mode '{}' in isUrl".format(self.mode)) self.allowed_tlds = tlds self.allowed_schemes = schemes if self.allowed_schemes: if prepend_scheme not in self.allowed_schemes: raise SyntaxError( - "prepend_scheme='{}' is not in allowed_schemes={}".format( - prepend_scheme, self.allowed_schemes - ) + "prepend_scheme='{}' is not in allowed_schemes={}".format(prepend_scheme, self.allowed_schemes) ) # if allowed_schemes is None, then we will defer testing # prepend_scheme's validity to a sub-method self.prepend_scheme = prepend_scheme def __call__(self, value): - if self.mode == 'generic': + if self.mode == "generic": subValidator = _isGenericUrl( - schemes=self.allowed_schemes, - prepend_scheme=self.prepend_scheme, - message=self.message + schemes=self.allowed_schemes, prepend_scheme=self.prepend_scheme, message=self.message ) - elif self.mode == 'http': + elif self.mode == "http": subValidator = _isHTTPUrl( schemes=self.allowed_schemes, prepend_scheme=self.prepend_scheme, tlds=self.allowed_tlds, - message=self.message + message=self.message, ) else: raise SyntaxError("invalid mode '{}' in isUrl".format(self.mode)) @@ -619,38 +587,27 @@ class isIPv4(Validator): """ message = "Invalid IPv4 address" - regex = re.compile( - r"^(([1-9]?\d|1\d\d|2[0-4]\d|25[0-5])\.){3}([1-9]?\d|1\d\d|2[0-4]\d|25[0-5])$" - ) + regex = re.compile(r"^(([1-9]?\d|1\d\d|2[0-4]\d|25[0-5])\.){3}([1-9]?\d|1\d\d|2[0-4]\d|25[0-5])$") numbers = (16777216, 65536, 256, 1) localhost = 2130706433 private = ((2886729728, 2886795263), (3232235520, 3232301055)) automatic = (2851995648, 2852061183) def __init__( - self, - min='0.0.0.0', - max='255.255.255.255', - invert=False, - localhost=None, - private=None, - auto=None, - message=None + self, min="0.0.0.0", max="255.255.255.255", invert=False, localhost=None, private=None, auto=None, message=None ): super().__init__(message=message) for n, value in enumerate((min, max)): temp = [] if isinstance(value, str): - temp.append(value.split('.')) + temp.append(value.split(".")) elif isinstance(value, (list, tuple)): - if len(value) == len( - list(filter(lambda item: isinstance(item, int), value)) - ) == 4: + if len(value) == len(list(filter(lambda item: isinstance(item, int), value))) == 4: temp.append(value) else: for item in value: if isinstance(item, str): - temp.append(item.split('.')) + temp.append(item.split(".")) elif isinstance(item, (list, tuple)): temp.append(item) numbers = [] @@ -671,31 +628,21 @@ def __init__( def __call__(self, value): if self.regex.match(value): number = 0 - for i, j in zip(self.numbers, value.split('.')): + for i, j in zip(self.numbers, value.split(".")): number += i * int(j) ok = False for bottom, top in zip(self.minip, self.maxip): if self.invert != (bottom <= number <= top): ok = True - if not ( - self.is_localhost is None or self.is_localhost == ( - number == self.localhost - ) - ): + if not (self.is_localhost is None or self.is_localhost == (number == self.localhost)): ok = False if not ( - self.is_private is None or self.is_private == ( - sum([ - el[0] <= number <= el[1] - for el in self.private - ]) > 0 - ) + self.is_private is None + or self.is_private == (sum([el[0] <= number <= el[1] for el in self.private]) > 0) ): ok = False if not ( - self.is_automatic is None or self.is_automatic == ( - self.automatic[0] <= number <= self.automatic[1] - ) + self.is_automatic is None or self.is_automatic == (self.automatic[0] <= number <= self.automatic[1]) ): ok = False if ok: @@ -734,7 +681,7 @@ def __init__( to4=None, teredo=None, subnets=None, - message=None + message=None, ): super().__init__(message=message) self.is_private = private @@ -747,9 +694,7 @@ def __init__( self.subnets = subnets if ipaddress is None: - raise RuntimeError( - "You need 'ipaddress' python module to use isIPv6 validator." - ) + raise RuntimeError("You need 'ipaddress' python module to use isIPv6 validator.") def __call__(self, value): try: @@ -766,9 +711,8 @@ def __call__(self, value): for network in self.subnets: try: ipnet = ipaddress.IPv6Network(network) - except (ipaddress.NetmaskValueError, - ipaddress.AddressValueError): - return value, translate('invalid subnet provided') + except (ipaddress.NetmaskValueError, ipaddress.AddressValueError): + return value, translate("invalid subnet provided") if ip in ipnet: ok = True @@ -778,22 +722,13 @@ def __call__(self, value): self.is_reserved = False self.is_multicast = False - if not ( - self.is_private is None or self.is_private == ip.is_private - ): + if not (self.is_private is None or self.is_private == ip.is_private): ok = False - if not ( - self.is_link_local is None or - self.is_link_local == ip.is_link_local - ): + if not (self.is_link_local is None or self.is_link_local == ip.is_link_local): ok = False - if not ( - self.is_reserved is None or self.is_reserved == ip.is_reserved - ): + if not (self.is_reserved is None or self.is_reserved == ip.is_reserved): ok = False - if not ( - self.is_multicast is None or self.is_multicast == ip.is_multicast - ): + if not (self.is_multicast is None or self.is_multicast == ip.is_multicast): ok = False if not (self.is_6to4 is None or self.is_6to4 == ip.is_6to4): ok = False @@ -819,8 +754,8 @@ class isIP(Validator): def __init__( self, - min='0.0.0.0', - max='255.255.255.255', + min="0.0.0.0", + max="255.255.255.255", invert=False, localhost=None, private=None, @@ -834,11 +769,11 @@ def __init__( teredo=None, subnets=None, ipv6=None, - message=None + message=None, ): super().__init__(message=message) - self.minip = min, - self.maxip = max, + self.minip = (min,) + self.maxip = (max,) self.invert = invert self.is_localhost = localhost self.is_private = private @@ -855,9 +790,7 @@ def __init__( self.is_ipv6 = ipv6 if ipaddress is None: - raise RuntimeError( - "You need 'ipaddress' python module to use isIP validator." - ) + raise RuntimeError("You need 'ipaddress' python module to use isIP validator.") def __call__(self, value): try: @@ -877,7 +810,7 @@ def __call__(self, value): localhost=self.is_localhost, private=self.is_private, auto=self.is_automatic, - message=self.message + message=self.message, )(value) elif self.is_ipv6 or isinstance(ip, ipaddress.IPv6Address): rv = isIPv6( @@ -889,7 +822,7 @@ def __call__(self, value): to4=self.is_6to4, teredo=self.is_teredo, subnets=self.subnets, - message=self.message + message=self.message, )(value) else: rv = (value, translate(self.message)) diff --git a/emmett/validators/helpers.py b/emmett/validators/helpers.py index 1f249ea6..8f7c8314 100644 --- a/emmett/validators/helpers.py +++ b/emmett/validators/helpers.py @@ -1,25 +1,25 @@ # -*- coding: utf-8 -*- """ - emmett.validators.helpers - ------------------------- +emmett.validators.helpers +------------------------- - Provides utilities for validators. +Provides utilities for validators. - Ported from the original validators of web2py (http://www.web2py.com) +Ported from the original validators of web2py (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:copyright: (c) by Massimo Di Pierro +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ import re - -from datetime import tzinfo, timedelta +from datetime import timedelta, tzinfo +from functools import reduce from io import StringIO # TODO: check unicode conversions -from .._shortcuts import to_unicode +from .._shortcuts import to_unicode, uuid from ..ctx import current -from ..security import simple_hash, uuid, DIGEST_ALG_BY_SIZE +from ..security import DIGEST_ALG_BY_SIZE, simple_hash _DEFAULT = lambda: None @@ -43,7 +43,7 @@ def is_empty(value, empty_regex=None): if isinstance(value, (str, bytes)): vclean = value.strip() if empty_regex is not None and empty_regex.match(vclean): - vclean = '' + vclean = "" return vclean, len(vclean) == 0 if isinstance(value, (list, dict)): return value, len(value) == 0 @@ -95,21 +95,21 @@ def __str__(self): if self.crypted: return self.crypted if self.crypt.key: - if ':' in self.crypt.key: - digest_alg, key = self.crypt.key.split(':', 1) + if ":" in self.crypt.key: + digest_alg, key = self.crypt.key.split(":", 1) else: digest_alg, key = self.crypt.digest_alg, self.crypt.key else: - digest_alg, key = self.crypt.digest_alg, '' + digest_alg, key = self.crypt.digest_alg, "" if self.crypt.salt: - if self.crypt.salt == True: - salt = str(uuid()).replace('-', '')[-16:] + if self.crypt.salt is True: + salt = str(uuid()).replace("-", "")[-16:] else: salt = self.crypt.salt else: - salt = '' + salt = "" hashed = simple_hash(self.password, key, salt, digest_alg) - self.crypted = '%s$%s$%s' % (digest_alg, salt, hashed) + self.crypted = "%s$%s$%s" % (digest_alg, salt, hashed) return self.crypted def __eq__(self, stored_password): @@ -117,30 +117,30 @@ def __eq__(self, stored_password): compares the current lazy crypted password with a stored password """ if isinstance(stored_password, self.__class__): - return ((self is stored_password) or - ((self.crypt.key == stored_password.crypt.key) and - (self.password == stored_password.password))) + return (self is stored_password) or ( + (self.crypt.key == stored_password.crypt.key) and (self.password == stored_password.password) + ) if self.crypt.key: - if ':' in self.crypt.key: - key = self.crypt.key.split(':')[1] + if ":" in self.crypt.key: + key = self.crypt.key.split(":")[1] else: key = self.crypt.key else: - key = '' + key = "" if stored_password is None: return False - elif stored_password.count('$') == 2: - (digest_alg, salt, hash) = stored_password.split('$') + elif stored_password.count("$") == 2: + (digest_alg, salt, hash) = stored_password.split("$") h = simple_hash(self.password, key, salt, digest_alg) - temp_pass = '%s$%s$%s' % (digest_alg, salt, h) + temp_pass = "%s$%s$%s" % (digest_alg, salt, h) else: # no salting # guess digest_alg digest_alg = DIGEST_ALG_BY_SIZE.get(len(stored_password), None) if not digest_alg: return False else: - temp_pass = simple_hash(self.password, key, '', digest_alg) + temp_pass = simple_hash(self.password, key, "", digest_alg) return temp_pass == stored_password def __ne__(self, other): @@ -148,7 +148,7 @@ def __ne__(self, other): def _escape_unicode(string): - ''' + """ Converts a unicode string into US-ASCII, using a simple conversion scheme. Each unicode character that does not have a US-ASCII equivalent is converted into a URL escaped form based on its hexadecimal value. @@ -162,14 +162,14 @@ def _escape_unicode(string): string: the US-ASCII escaped form of the inputted string @author: Jonathan Benn - ''' + """ returnValue = StringIO() for character in string: code = ord(character) if code > 0x7F: hexCode = hex(code) - returnValue.write('%' + hexCode[2:4] + '%' + hexCode[4:6]) + returnValue.write("%" + hexCode[2:4] + "%" + hexCode[4:6]) else: returnValue.write(character) @@ -177,7 +177,7 @@ def _escape_unicode(string): def _unicode_to_ascii_authority(authority): - ''' + """ Follows the steps in RFC 3490, Section 4 to convert a unicode authority string into its ASCII equivalent. For example, u'www.Alliancefran\xe7aise.nu' will be converted into @@ -196,40 +196,41 @@ def _unicode_to_ascii_authority(authority): authority @author: Jonathan Benn - ''' - label_split_regex = re.compile(u'[\u002e\u3002\uff0e\uff61]') + """ + label_split_regex = re.compile("[\u002e\u3002\uff0e\uff61]") - #RFC 3490, Section 4, Step 1 - #The encodings.idna Python module assumes that AllowUnassigned == True + # RFC 3490, Section 4, Step 1 + # The encodings.idna Python module assumes that AllowUnassigned == True - #RFC 3490, Section 4, Step 2 + # RFC 3490, Section 4, Step 2 labels = label_split_regex.split(authority) - #RFC 3490, Section 4, Step 3 - #The encodings.idna Python module assumes that UseSTD3ASCIIRules == False + # RFC 3490, Section 4, Step 3 + # The encodings.idna Python module assumes that UseSTD3ASCIIRules == False - #RFC 3490, Section 4, Step 4 - #We use the ToASCII operation because we are about to put the authority - #into an IDN-unaware slot + # RFC 3490, Section 4, Step 4 + # We use the ToASCII operation because we are about to put the authority + # into an IDN-unaware slot asciiLabels = [] try: import encodings.idna + for label in labels: if label: asciiLabels.append(encodings.idna.ToASCII(label)) else: - #encodings.idna.ToASCII does not accept an empty string, but - #it is necessary for us to allow for empty labels so that we - #don't modify the URL - asciiLabels.append('') - except: + # encodings.idna.ToASCII does not accept an empty string, but + # it is necessary for us to allow for empty labels so that we + # don't modify the URL + asciiLabels.append("") + except Exception: asciiLabels = [str(label) for label in labels] - #RFC 3490, Section 4, Step 5 - return str(reduce(lambda x, y: x + unichr(0x002E) + y, asciiLabels)) + # RFC 3490, Section 4, Step 5 + return str(reduce(lambda x, y: x + chr(0x002E) + y, asciiLabels)) def unicode_to_ascii_url(url, prepend_scheme): - ''' + """ Converts the inputed unicode url into a US-ASCII equivalent. This function goes a little beyond RFC 3490, which is limited in scope to the domain name (authority) only. Here, the functionality is expanded to what was observed @@ -259,223 +260,1016 @@ def unicode_to_ascii_url(url, prepend_scheme): string: a US-ASCII equivalent of the inputed url @author: Jonathan Benn - ''' - #convert the authority component of the URL into an ASCII punycode string, - #but encode the rest using the regular URI character encoding + """ + # convert the authority component of the URL into an ASCII punycode string, + # but encode the rest using the regular URI character encoding groups = url_split_regex.match(url).groups() - #If no authority was found + # If no authority was found if not groups[3]: - #Try appending a scheme to see if that fixes the problem - scheme_to_prepend = prepend_scheme or 'http' - groups = url_split_regex.match( - to_unicode(scheme_to_prepend) + u'://' + url).groups() - #if we still can't find the authority + # Try appending a scheme to see if that fixes the problem + scheme_to_prepend = prepend_scheme or "http" + groups = url_split_regex.match(to_unicode(scheme_to_prepend) + "://" + url).groups() + # if we still can't find the authority if not groups[3]: - raise Exception('No authority component found, ' + - 'could not decode unicode to US-ASCII') + raise Exception("No authority component found, " + "could not decode unicode to US-ASCII") - #We're here if we found an authority, let's rebuild the URL + # We're here if we found an authority, let's rebuild the URL scheme = groups[1] authority = groups[3] - path = groups[4] or '' - query = groups[5] or '' - fragment = groups[7] or '' + path = groups[4] or "" + query = groups[5] or "" + fragment = groups[7] or "" if prepend_scheme: - scheme = str(scheme) + '://' + scheme = str(scheme) + "://" else: - scheme = '' - return scheme + _unicode_to_ascii_authority(authority) +\ - _escape_unicode(path) + _escape_unicode(query) + str(fragment) + scheme = "" + return ( + scheme + _unicode_to_ascii_authority(authority) + _escape_unicode(path) + _escape_unicode(query) + str(fragment) + ) -url_split_regex = \ - re.compile(r'^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?') +url_split_regex = re.compile(r"^(([^:/?#]+):)?(//([^/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?") official_url_schemes = [ - 'aaa', 'aaas', 'acap', 'cap', 'cid', 'crid', 'data', 'dav', 'dict', - 'dns', 'fax', 'file', 'ftp', 'go', 'gopher', 'h323', 'http', 'https', - 'icap', 'im', 'imap', 'info', 'ipp', 'iris', 'iris.beep', 'iris.xpc', - 'iris.xpcs', 'iris.lws', 'ldap', 'mailto', 'mid', 'modem', 'msrp', - 'msrps', 'mtqp', 'mupdate', 'news', 'nfs', 'nntp', 'opaquelocktoken', - 'pop', 'pres', 'prospero', 'rtsp', 'service', 'shttp', 'sip', 'sips', - 'snmp', 'soap.beep', 'soap.beeps', 'tag', 'tel', 'telnet', 'tftp', - 'thismessage', 'tip', 'tv', 'urn', 'vemmi', 'wais', 'xmlrpc.beep', - 'xmlrpc.beep', 'xmpp', 'z39.50r', 'z39.50s' + "aaa", + "aaas", + "acap", + "cap", + "cid", + "crid", + "data", + "dav", + "dict", + "dns", + "fax", + "file", + "ftp", + "go", + "gopher", + "h323", + "http", + "https", + "icap", + "im", + "imap", + "info", + "ipp", + "iris", + "iris.beep", + "iris.xpc", + "iris.xpcs", + "iris.lws", + "ldap", + "mailto", + "mid", + "modem", + "msrp", + "msrps", + "mtqp", + "mupdate", + "news", + "nfs", + "nntp", + "opaquelocktoken", + "pop", + "pres", + "prospero", + "rtsp", + "service", + "shttp", + "sip", + "sips", + "snmp", + "soap.beep", + "soap.beeps", + "tag", + "tel", + "telnet", + "tftp", + "thismessage", + "tip", + "tv", + "urn", + "vemmi", + "wais", + "xmlrpc.beep", + "xmlrpc.beep", + "xmpp", + "z39.50r", + "z39.50s", ] unofficial_url_schemes = [ - 'about', 'adiumxtra', 'aim', 'afp', 'aw', 'callto', 'chrome', 'cvs', - 'ed2k', 'feed', 'fish', 'gg', 'gizmoproject', 'iax2', 'irc', 'ircs', - 'itms', 'jar', 'javascript', 'keyparc', 'lastfm', 'ldaps', 'magnet', - 'mms', 'msnim', 'mvn', 'notes', 'nsfw', 'psyc', 'paparazzi:http', - 'rmi', 'rsync', 'secondlife', 'sgn', 'skype', 'ssh', 'sftp', 'smb', - 'sms', 'soldat', 'steam', 'svn', 'teamspeak', 'unreal', 'ut2004', - 'ventrilo', 'view-source', 'webcal', 'wyciwyg', 'xfire', 'xri', 'ymsgr' + "about", + "adiumxtra", + "aim", + "afp", + "aw", + "callto", + "chrome", + "cvs", + "ed2k", + "feed", + "fish", + "gg", + "gizmoproject", + "iax2", + "irc", + "ircs", + "itms", + "jar", + "javascript", + "keyparc", + "lastfm", + "ldaps", + "magnet", + "mms", + "msnim", + "mvn", + "notes", + "nsfw", + "psyc", + "paparazzi:http", + "rmi", + "rsync", + "secondlife", + "sgn", + "skype", + "ssh", + "sftp", + "smb", + "sms", + "soldat", + "steam", + "svn", + "teamspeak", + "unreal", + "ut2004", + "ventrilo", + "view-source", + "webcal", + "wyciwyg", + "xfire", + "xri", + "ymsgr", ] official_top_level_domains = [ # a - 'abogado', 'ac', 'academy', 'accountants', 'active', 'actor', - 'ad', 'adult', 'ae', 'aero', 'af', 'ag', 'agency', 'ai', - 'airforce', 'al', 'allfinanz', 'alsace', 'am', 'amsterdam', 'an', - 'android', 'ao', 'apartments', 'aq', 'aquarelle', 'ar', 'archi', - 'army', 'arpa', 'as', 'asia', 'associates', 'at', 'attorney', - 'au', 'auction', 'audio', 'autos', 'aw', 'ax', 'axa', 'az', + "abogado", + "ac", + "academy", + "accountants", + "active", + "actor", + "ad", + "adult", + "ae", + "aero", + "af", + "ag", + "agency", + "ai", + "airforce", + "al", + "allfinanz", + "alsace", + "am", + "amsterdam", + "an", + "android", + "ao", + "apartments", + "aq", + "aquarelle", + "ar", + "archi", + "army", + "arpa", + "as", + "asia", + "associates", + "at", + "attorney", + "au", + "auction", + "audio", + "autos", + "aw", + "ax", + "axa", + "az", # b - 'ba', 'band', 'bank', 'bar', 'barclaycard', 'barclays', - 'bargains', 'bayern', 'bb', 'bd', 'be', 'beer', 'berlin', 'best', - 'bf', 'bg', 'bh', 'bi', 'bid', 'bike', 'bingo', 'bio', 'biz', - 'bj', 'black', 'blackfriday', 'bloomberg', 'blue', 'bm', 'bmw', - 'bn', 'bnpparibas', 'bo', 'boo', 'boutique', 'br', 'brussels', - 'bs', 'bt', 'budapest', 'build', 'builders', 'business', 'buzz', - 'bv', 'bw', 'by', 'bz', 'bzh', + "ba", + "band", + "bank", + "bar", + "barclaycard", + "barclays", + "bargains", + "bayern", + "bb", + "bd", + "be", + "beer", + "berlin", + "best", + "bf", + "bg", + "bh", + "bi", + "bid", + "bike", + "bingo", + "bio", + "biz", + "bj", + "black", + "blackfriday", + "bloomberg", + "blue", + "bm", + "bmw", + "bn", + "bnpparibas", + "bo", + "boo", + "boutique", + "br", + "brussels", + "bs", + "bt", + "budapest", + "build", + "builders", + "business", + "buzz", + "bv", + "bw", + "by", + "bz", + "bzh", # c - 'ca', 'cab', 'cal', 'camera', 'camp', 'cancerresearch', 'canon', - 'capetown', 'capital', 'caravan', 'cards', 'care', 'career', - 'careers', 'cartier', 'casa', 'cash', 'casino', 'cat', - 'catering', 'cbn', 'cc', 'cd', 'center', 'ceo', 'cern', 'cf', - 'cg', 'ch', 'channel', 'chat', 'cheap', 'christmas', 'chrome', - 'church', 'ci', 'citic', 'city', 'ck', 'cl', 'claims', - 'cleaning', 'click', 'clinic', 'clothing', 'club', 'cm', 'cn', - 'co', 'coach', 'codes', 'coffee', 'college', 'cologne', 'com', - 'community', 'company', 'computer', 'condos', 'construction', - 'consulting', 'contractors', 'cooking', 'cool', 'coop', - 'country', 'cr', 'credit', 'creditcard', 'cricket', 'crs', - 'cruises', 'cu', 'cuisinella', 'cv', 'cw', 'cx', 'cy', 'cymru', - 'cz', + "ca", + "cab", + "cal", + "camera", + "camp", + "cancerresearch", + "canon", + "capetown", + "capital", + "caravan", + "cards", + "care", + "career", + "careers", + "cartier", + "casa", + "cash", + "casino", + "cat", + "catering", + "cbn", + "cc", + "cd", + "center", + "ceo", + "cern", + "cf", + "cg", + "ch", + "channel", + "chat", + "cheap", + "christmas", + "chrome", + "church", + "ci", + "citic", + "city", + "ck", + "cl", + "claims", + "cleaning", + "click", + "clinic", + "clothing", + "club", + "cm", + "cn", + "co", + "coach", + "codes", + "coffee", + "college", + "cologne", + "com", + "community", + "company", + "computer", + "condos", + "construction", + "consulting", + "contractors", + "cooking", + "cool", + "coop", + "country", + "cr", + "credit", + "creditcard", + "cricket", + "crs", + "cruises", + "cu", + "cuisinella", + "cv", + "cw", + "cx", + "cy", + "cymru", + "cz", # d - 'dabur', 'dad', 'dance', 'dating', 'day', 'dclk', 'de', 'deals', - 'degree', 'delivery', 'democrat', 'dental', 'dentist', 'desi', - 'design', 'dev', 'diamonds', 'diet', 'digital', 'direct', - 'directory', 'discount', 'dj', 'dk', 'dm', 'dnp', 'do', 'docs', - 'domains', 'doosan', 'durban', 'dvag', 'dz', + "dabur", + "dad", + "dance", + "dating", + "day", + "dclk", + "de", + "deals", + "degree", + "delivery", + "democrat", + "dental", + "dentist", + "desi", + "design", + "dev", + "diamonds", + "diet", + "digital", + "direct", + "directory", + "discount", + "dj", + "dk", + "dm", + "dnp", + "do", + "docs", + "domains", + "doosan", + "durban", + "dvag", + "dz", # e - 'eat', 'ec', 'edu', 'education', 'ee', 'eg', 'email', 'emerck', - 'energy', 'engineer', 'engineering', 'enterprises', 'equipment', - 'er', 'es', 'esq', 'estate', 'et', 'eu', 'eurovision', 'eus', - 'events', 'everbank', 'exchange', 'expert', 'exposed', + "eat", + "ec", + "edu", + "education", + "ee", + "eg", + "email", + "emerck", + "energy", + "engineer", + "engineering", + "enterprises", + "equipment", + "er", + "es", + "esq", + "estate", + "et", + "eu", + "eurovision", + "eus", + "events", + "everbank", + "exchange", + "expert", + "exposed", # f - 'fail', 'fans', 'farm', 'fashion', 'feedback', 'fi', 'finance', - 'financial', 'firmdale', 'fish', 'fishing', 'fit', 'fitness', - 'fj', 'fk', 'flights', 'florist', 'flowers', 'flsmidth', 'fly', - 'fm', 'fo', 'foo', 'football', 'forsale', 'foundation', 'fr', - 'frl', 'frogans', 'fund', 'furniture', 'futbol', + "fail", + "fans", + "farm", + "fashion", + "feedback", + "fi", + "finance", + "financial", + "firmdale", + "fish", + "fishing", + "fit", + "fitness", + "fj", + "fk", + "flights", + "florist", + "flowers", + "flsmidth", + "fly", + "fm", + "fo", + "foo", + "football", + "forsale", + "foundation", + "fr", + "frl", + "frogans", + "fund", + "furniture", + "futbol", # g - 'ga', 'gal', 'gallery', 'garden', 'gb', 'gbiz', 'gd', 'gdn', - 'ge', 'gent', 'gf', 'gg', 'ggee', 'gh', 'gi', 'gift', 'gifts', - 'gives', 'gl', 'glass', 'gle', 'global', 'globo', 'gm', 'gmail', - 'gmo', 'gmx', 'gn', 'goldpoint', 'goog', 'google', 'gop', 'gov', - 'gp', 'gq', 'gr', 'graphics', 'gratis', 'green', 'gripe', 'gs', - 'gt', 'gu', 'guide', 'guitars', 'guru', 'gw', 'gy', + "ga", + "gal", + "gallery", + "garden", + "gb", + "gbiz", + "gd", + "gdn", + "ge", + "gent", + "gf", + "gg", + "ggee", + "gh", + "gi", + "gift", + "gifts", + "gives", + "gl", + "glass", + "gle", + "global", + "globo", + "gm", + "gmail", + "gmo", + "gmx", + "gn", + "goldpoint", + "goog", + "google", + "gop", + "gov", + "gp", + "gq", + "gr", + "graphics", + "gratis", + "green", + "gripe", + "gs", + "gt", + "gu", + "guide", + "guitars", + "guru", + "gw", + "gy", # h - 'hamburg', 'hangout', 'haus', 'healthcare', 'help', 'here', - 'hermes', 'hiphop', 'hiv', 'hk', 'hm', 'hn', 'holdings', - 'holiday', 'homes', 'horse', 'host', 'hosting', 'house', 'how', - 'hr', 'ht', 'hu', + "hamburg", + "hangout", + "haus", + "healthcare", + "help", + "here", + "hermes", + "hiphop", + "hiv", + "hk", + "hm", + "hn", + "holdings", + "holiday", + "homes", + "horse", + "host", + "hosting", + "house", + "how", + "hr", + "ht", + "hu", # i - 'ibm', 'id', 'ie', 'ifm', 'il', 'im', 'immo', 'immobilien', 'in', - 'industries', 'info', 'ing', 'ink', 'institute', 'insure', 'int', - 'international', 'investments', 'io', 'iq', 'ir', 'irish', 'is', - 'it', 'iwc', + "ibm", + "id", + "ie", + "ifm", + "il", + "im", + "immo", + "immobilien", + "in", + "industries", + "info", + "ing", + "ink", + "institute", + "insure", + "int", + "international", + "investments", + "io", + "iq", + "ir", + "irish", + "is", + "it", + "iwc", # j - 'jcb', 'je', 'jetzt', 'jm', 'jo', 'jobs', 'joburg', 'jp', - 'juegos', + "jcb", + "je", + "jetzt", + "jm", + "jo", + "jobs", + "joburg", + "jp", + "juegos", # k - 'kaufen', 'kddi', 'ke', 'kg', 'kh', 'ki', 'kim', 'kitchen', - 'kiwi', 'km', 'kn', 'koeln', 'kp', 'kr', 'krd', 'kred', 'kw', - 'ky', 'kyoto', 'kz', + "kaufen", + "kddi", + "ke", + "kg", + "kh", + "ki", + "kim", + "kitchen", + "kiwi", + "km", + "kn", + "koeln", + "kp", + "kr", + "krd", + "kred", + "kw", + "ky", + "kyoto", + "kz", # l - 'la', 'lacaixa', 'land', 'lat', 'latrobe', 'lawyer', 'lb', 'lc', - 'lds', 'lease', 'legal', 'lgbt', 'li', 'lidl', 'life', - 'lighting', 'limited', 'limo', 'link', 'lk', 'loans', - 'localhost', 'london', 'lotte', 'lotto', 'lr', 'ls', 'lt', - 'ltda', 'lu', 'luxe', 'luxury', 'lv', 'ly', + "la", + "lacaixa", + "land", + "lat", + "latrobe", + "lawyer", + "lb", + "lc", + "lds", + "lease", + "legal", + "lgbt", + "li", + "lidl", + "life", + "lighting", + "limited", + "limo", + "link", + "lk", + "loans", + "localhost", + "london", + "lotte", + "lotto", + "lr", + "ls", + "lt", + "ltda", + "lu", + "luxe", + "luxury", + "lv", + "ly", # m - 'ma', 'madrid', 'maison', 'management', 'mango', 'market', - 'marketing', 'marriott', 'mc', 'md', 'me', 'media', 'meet', - 'melbourne', 'meme', 'memorial', 'menu', 'mg', 'mh', 'miami', - 'mil', 'mini', 'mk', 'ml', 'mm', 'mn', 'mo', 'mobi', 'moda', - 'moe', 'monash', 'money', 'mormon', 'mortgage', 'moscow', - 'motorcycles', 'mov', 'mp', 'mq', 'mr', 'ms', 'mt', 'mu', - 'museum', 'mv', 'mw', 'mx', 'my', 'mz', + "ma", + "madrid", + "maison", + "management", + "mango", + "market", + "marketing", + "marriott", + "mc", + "md", + "me", + "media", + "meet", + "melbourne", + "meme", + "memorial", + "menu", + "mg", + "mh", + "miami", + "mil", + "mini", + "mk", + "ml", + "mm", + "mn", + "mo", + "mobi", + "moda", + "moe", + "monash", + "money", + "mormon", + "mortgage", + "moscow", + "motorcycles", + "mov", + "mp", + "mq", + "mr", + "ms", + "mt", + "mu", + "museum", + "mv", + "mw", + "mx", + "my", + "mz", # n - 'na', 'nagoya', 'name', 'navy', 'nc', 'ne', 'net', 'network', - 'neustar', 'new', 'nexus', 'nf', 'ng', 'ngo', 'nhk', 'ni', - 'nico', 'ninja', 'nl', 'no', 'np', 'nr', 'nra', 'nrw', 'ntt', - 'nu', 'nyc', 'nz', + "na", + "nagoya", + "name", + "navy", + "nc", + "ne", + "net", + "network", + "neustar", + "new", + "nexus", + "nf", + "ng", + "ngo", + "nhk", + "ni", + "nico", + "ninja", + "nl", + "no", + "np", + "nr", + "nra", + "nrw", + "ntt", + "nu", + "nyc", + "nz", # o - 'okinawa', 'om', 'one', 'ong', 'onl', 'ooo', 'org', 'organic', - 'osaka', 'otsuka', 'ovh', + "okinawa", + "om", + "one", + "ong", + "onl", + "ooo", + "org", + "organic", + "osaka", + "otsuka", + "ovh", # p - 'pa', 'paris', 'partners', 'parts', 'party', 'pe', 'pf', 'pg', - 'ph', 'pharmacy', 'photo', 'photography', 'photos', 'physio', - 'pics', 'pictures', 'pink', 'pizza', 'pk', 'pl', 'place', - 'plumbing', 'pm', 'pn', 'pohl', 'poker', 'porn', 'post', 'pr', - 'praxi', 'press', 'pro', 'prod', 'productions', 'prof', - 'properties', 'property', 'ps', 'pt', 'pub', 'pw', 'py', + "pa", + "paris", + "partners", + "parts", + "party", + "pe", + "pf", + "pg", + "ph", + "pharmacy", + "photo", + "photography", + "photos", + "physio", + "pics", + "pictures", + "pink", + "pizza", + "pk", + "pl", + "place", + "plumbing", + "pm", + "pn", + "pohl", + "poker", + "porn", + "post", + "pr", + "praxi", + "press", + "pro", + "prod", + "productions", + "prof", + "properties", + "property", + "ps", + "pt", + "pub", + "pw", + "py", # q - 'qa', 'qpon', 'quebec', + "qa", + "qpon", + "quebec", # r - 're', 'realtor', 'recipes', 'red', 'rehab', 'reise', 'reisen', - 'reit', 'ren', 'rentals', 'repair', 'report', 'republican', - 'rest', 'restaurant', 'reviews', 'rich', 'rio', 'rip', 'ro', - 'rocks', 'rodeo', 'rs', 'rsvp', 'ru', 'ruhr', 'rw', 'ryukyu', + "re", + "realtor", + "recipes", + "red", + "rehab", + "reise", + "reisen", + "reit", + "ren", + "rentals", + "repair", + "report", + "republican", + "rest", + "restaurant", + "reviews", + "rich", + "rio", + "rip", + "ro", + "rocks", + "rodeo", + "rs", + "rsvp", + "ru", + "ruhr", + "rw", + "ryukyu", # s - 'sa', 'saarland', 'sale', 'samsung', 'sarl', 'saxo', 'sb', 'sc', - 'sca', 'scb', 'schmidt', 'school', 'schule', 'schwarz', - 'science', 'scot', 'sd', 'se', 'services', 'sew', 'sexy', 'sg', - 'sh', 'shiksha', 'shoes', 'shriram', 'si', 'singles', 'sj', 'sk', - 'sky', 'sl', 'sm', 'sn', 'so', 'social', 'software', 'sohu', - 'solar', 'solutions', 'soy', 'space', 'spiegel', 'sr', 'st', - 'style', 'su', 'supplies', 'supply', 'support', 'surf', - 'surgery', 'suzuki', 'sv', 'sx', 'sy', 'sydney', 'systems', 'sz', + "sa", + "saarland", + "sale", + "samsung", + "sarl", + "saxo", + "sb", + "sc", + "sca", + "scb", + "schmidt", + "school", + "schule", + "schwarz", + "science", + "scot", + "sd", + "se", + "services", + "sew", + "sexy", + "sg", + "sh", + "shiksha", + "shoes", + "shriram", + "si", + "singles", + "sj", + "sk", + "sky", + "sl", + "sm", + "sn", + "so", + "social", + "software", + "sohu", + "solar", + "solutions", + "soy", + "space", + "spiegel", + "sr", + "st", + "style", + "su", + "supplies", + "supply", + "support", + "surf", + "surgery", + "suzuki", + "sv", + "sx", + "sy", + "sydney", + "systems", + "sz", # t - 'taipei', 'tatar', 'tattoo', 'tax', 'tc', 'td', 'technology', - 'tel', 'temasek', 'tennis', 'tf', 'tg', 'th', 'tienda', 'tips', - 'tires', 'tirol', 'tj', 'tk', 'tl', 'tm', 'tn', 'to', 'today', - 'tokyo', 'tools', 'top', 'toshiba', 'town', 'toys', 'tp', 'tr', - 'trade', 'training', 'travel', 'trust', 'tt', 'tui', 'tv', 'tw', - 'tz', + "taipei", + "tatar", + "tattoo", + "tax", + "tc", + "td", + "technology", + "tel", + "temasek", + "tennis", + "tf", + "tg", + "th", + "tienda", + "tips", + "tires", + "tirol", + "tj", + "tk", + "tl", + "tm", + "tn", + "to", + "today", + "tokyo", + "tools", + "top", + "toshiba", + "town", + "toys", + "tp", + "tr", + "trade", + "training", + "travel", + "trust", + "tt", + "tui", + "tv", + "tw", + "tz", # u - 'ua', 'ug', 'uk', 'university', 'uno', 'uol', 'us', 'uy', 'uz', + "ua", + "ug", + "uk", + "university", + "uno", + "uol", + "us", + "uy", + "uz", # v - 'va', 'vacations', 'vc', 've', 'vegas', 'ventures', - 'versicherung', 'vet', 'vg', 'vi', 'viajes', 'video', 'villas', - 'vision', 'vlaanderen', 'vn', 'vodka', 'vote', 'voting', 'voto', - 'voyage', 'vu', + "va", + "vacations", + "vc", + "ve", + "vegas", + "ventures", + "versicherung", + "vet", + "vg", + "vi", + "viajes", + "video", + "villas", + "vision", + "vlaanderen", + "vn", + "vodka", + "vote", + "voting", + "voto", + "voyage", + "vu", # w - 'wales', 'wang', 'watch', 'webcam', 'website', 'wed', 'wedding', - 'wf', 'whoswho', 'wien', 'wiki', 'williamhill', 'wme', 'work', - 'works', 'world', 'ws', 'wtc', 'wtf', + "wales", + "wang", + "watch", + "webcam", + "website", + "wed", + "wedding", + "wf", + "whoswho", + "wien", + "wiki", + "williamhill", + "wme", + "work", + "works", + "world", + "ws", + "wtc", + "wtf", # x - 'xn--1qqw23a', 'xn--3bst00m', 'xn--3ds443g', 'xn--3e0b707e', - 'xn--45brj9c', 'xn--45q11c', 'xn--4gbrim', 'xn--55qw42g', - 'xn--55qx5d', 'xn--6frz82g', 'xn--6qq986b3xl', 'xn--80adxhks', - 'xn--80ao21a', 'xn--80asehdb', 'xn--80aswg', 'xn--90a3ac', - 'xn--90ais', 'xn--b4w605ferd', 'xn--c1avg', 'xn--cg4bki', - 'xn--clchc0ea0b2g2a9gcd', 'xn--czr694b', 'xn--czrs0t', - 'xn--czru2d', 'xn--d1acj3b', 'xn--d1alf', 'xn--fiq228c5hs', - 'xn--fiq64b', 'xn--fiqs8s', 'xn--fiqz9s', 'xn--flw351e', - 'xn--fpcrj9c3d', 'xn--fzc2c9e2c', 'xn--gecrj9c', 'xn--h2brj9c', - 'xn--hxt814e', 'xn--i1b6b1a6a2e', 'xn--io0a7i', 'xn--j1amh', - 'xn--j6w193g', 'xn--kprw13d', 'xn--kpry57d', 'xn--kput3i', - 'xn--l1acc', 'xn--lgbbat1ad8j', 'xn--mgb9awbf', - 'xn--mgba3a4f16a', 'xn--mgbaam7a8h', 'xn--mgbab2bd', - 'xn--mgbayh7gpa', 'xn--mgbbh1a71e', 'xn--mgbc0a9azcg', - 'xn--mgberp4a5d4ar', 'xn--mgbx4cd0ab', 'xn--ngbc5azd', - 'xn--node', 'xn--nqv7f', 'xn--nqv7fs00ema', 'xn--o3cw4h', - 'xn--ogbpf8fl', 'xn--p1acf', 'xn--p1ai', 'xn--pgbs0dh', - 'xn--q9jyb4c', 'xn--qcka1pmc', 'xn--rhqv96g', 'xn--s9brj9c', - 'xn--ses554g', 'xn--unup4y', 'xn--vermgensberater-ctb', - 'xn--vermgensberatung-pwb', 'xn--vhquv', 'xn--wgbh1c', - 'xn--wgbl6a', 'xn--xhq521b', 'xn--xkc2al3hye2a', - 'xn--xkc2dl3a5ee0h', 'xn--yfro4i67o', 'xn--ygbi2ammx', - 'xn--zfr164b', 'xxx', 'xyz', + "xn--1qqw23a", + "xn--3bst00m", + "xn--3ds443g", + "xn--3e0b707e", + "xn--45brj9c", + "xn--45q11c", + "xn--4gbrim", + "xn--55qw42g", + "xn--55qx5d", + "xn--6frz82g", + "xn--6qq986b3xl", + "xn--80adxhks", + "xn--80ao21a", + "xn--80asehdb", + "xn--80aswg", + "xn--90a3ac", + "xn--90ais", + "xn--b4w605ferd", + "xn--c1avg", + "xn--cg4bki", + "xn--clchc0ea0b2g2a9gcd", + "xn--czr694b", + "xn--czrs0t", + "xn--czru2d", + "xn--d1acj3b", + "xn--d1alf", + "xn--fiq228c5hs", + "xn--fiq64b", + "xn--fiqs8s", + "xn--fiqz9s", + "xn--flw351e", + "xn--fpcrj9c3d", + "xn--fzc2c9e2c", + "xn--gecrj9c", + "xn--h2brj9c", + "xn--hxt814e", + "xn--i1b6b1a6a2e", + "xn--io0a7i", + "xn--j1amh", + "xn--j6w193g", + "xn--kprw13d", + "xn--kpry57d", + "xn--kput3i", + "xn--l1acc", + "xn--lgbbat1ad8j", + "xn--mgb9awbf", + "xn--mgba3a4f16a", + "xn--mgbaam7a8h", + "xn--mgbab2bd", + "xn--mgbayh7gpa", + "xn--mgbbh1a71e", + "xn--mgbc0a9azcg", + "xn--mgberp4a5d4ar", + "xn--mgbx4cd0ab", + "xn--ngbc5azd", + "xn--node", + "xn--nqv7f", + "xn--nqv7fs00ema", + "xn--o3cw4h", + "xn--ogbpf8fl", + "xn--p1acf", + "xn--p1ai", + "xn--pgbs0dh", + "xn--q9jyb4c", + "xn--qcka1pmc", + "xn--rhqv96g", + "xn--s9brj9c", + "xn--ses554g", + "xn--unup4y", + "xn--vermgensberater-ctb", + "xn--vermgensberatung-pwb", + "xn--vhquv", + "xn--wgbh1c", + "xn--wgbl6a", + "xn--xhq521b", + "xn--xkc2al3hye2a", + "xn--xkc2dl3a5ee0h", + "xn--yfro4i67o", + "xn--ygbi2ammx", + "xn--zfr164b", + "xxx", + "xyz", # y - 'yachts', 'yandex', 'ye', 'yodobashi', 'yoga', 'yokohama', - 'youtube', 'yt', + "yachts", + "yandex", + "ye", + "yodobashi", + "yoga", + "yokohama", + "youtube", + "yt", # z - 'za', 'zip', 'zm', 'zone', 'zuerich', 'zw' + "za", + "zip", + "zm", + "zone", + "zuerich", + "zw", ] diff --git a/emmett/validators/inside.py b/emmett/validators/inside.py index 6d3eaec2..c391909d 100644 --- a/emmett/validators/inside.py +++ b/emmett/validators/inside.py @@ -1,22 +1,23 @@ # -*- coding: utf-8 -*- """ - emmett.validators.inside - ------------------------ +emmett.validators.inside +------------------------ - Validators that check presence/absence of given value in a set. +Validators that check presence/absence of given value in a set. - :copyright: 2014 Giovanni Barillari +:copyright: 2014 Giovanni Barillari - Based on the web2py's validators (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro +Based on the web2py's validators (http://www.web2py.com) +:copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ +from emmett_core.utils import cachedprop + # TODO: check unicode conversions from .._shortcuts import to_unicode from ..ctx import current -from ..utils import cachedprop from .basic import Validator from .helpers import options_sorter, translate @@ -41,14 +42,11 @@ def _lt(self, val1, val2, eq=False): def __call__(self, value): minimum = self.minimum() if callable(self.minimum) else self.minimum maximum = self.maximum() if callable(self.maximum) else self.maximum - if ( - (minimum is None or self._gt(value, minimum, self.inc[0])) and - (maximum is None or self._lt(value, maximum, self.inc[1])) + if (minimum is None or self._gt(value, minimum, self.inc[0])) and ( + maximum is None or self._lt(value, maximum, self.inc[1]) ): return value, None - return value, translate( - self._range_error(self.message, minimum, maximum) - ) + return value, translate(self._range_error(self.message, minimum, maximum)) def _range_error(self, message, minimum, maximum): if message is None: @@ -65,22 +63,14 @@ def _range_error(self, message, minimum, maximum): class inSet(Validator): - def __init__( - self, - theset, - labels=None, - multiple=False, - zero=None, - sort=False, - message=None - ): + def __init__(self, theset, labels=None, multiple=False, zero=None, sort=False, message=None): super().__init__(message=message) self.multiple = multiple if ( - theset and - isinstance(theset, (tuple, list)) and - isinstance(theset[0], (tuple, list)) and - len(theset[0]) == 2 + theset + and isinstance(theset, (tuple, list)) + and isinstance(theset[0], (tuple, list)) + and len(theset[0]) == 2 ): lset, llabels = [], [] for item, label in theset: @@ -88,10 +78,7 @@ def __init__( llabels.append(str(label)) self.theset = lset self.labels = llabels - elif ( - theset and - isinstance(theset, dict) - ): + elif theset and isinstance(theset, dict): lset, llabels = [], [] for item, label in theset.items(): lset.append(str(item)) @@ -112,7 +99,7 @@ def options(self, zero=True): if self.sort: items.sort(options_sorter) if zero and self.zero is not None and not self.multiple: - items.insert(0, ('', self.zero)) + items.insert(0, ("", self.zero)) return items def __call__(self, value): @@ -125,25 +112,20 @@ def __call__(self, value): values = [value] else: values = [value] - failures = [ - x for x in values - if (to_unicode(x) or '') not in self.theset] + failures = [x for x in values if (to_unicode(x) or "") not in self.theset] if failures and self.theset: - if self.multiple and (value is None or value == ''): + if self.multiple and (value is None or value == ""): return ([], None) return value, translate(self.message) if self.multiple: - if ( - isinstance(self.multiple, (tuple, list)) and - not self.multiple[0] <= len(values) < self.multiple[1] - ): + if isinstance(self.multiple, (tuple, list)) and not self.multiple[0] <= len(values) < self.multiple[1]: return values, translate(self.message) return values, None return value, None class DBValidator(Validator): - def __init__(self, db, tablename, fieldname='id', dbset=None, message=None): + def __init__(self, db, tablename, fieldname="id", dbset=None, message=None): super().__init__(message=message) self.db = db self.tablename = tablename @@ -170,23 +152,9 @@ def field(self): class inDB(DBValidator): def __init__( - self, - db, - tablename, - fieldname='id', - dbset=None, - label_field=None, - multiple=False, - orderby=None, - message=None + self, db, tablename, fieldname="id", dbset=None, label_field=None, multiple=False, orderby=None, message=None ): - super().__init__( - db, - tablename, - fieldname=fieldname, - dbset=dbset, - message=message - ) + super().__init__(db, tablename, fieldname=fieldname, dbset=dbset, message=message) self.label_field = label_field self.multiple = multiple self.orderby = orderby @@ -208,18 +176,16 @@ def options(self, zero=True): items = [(r.id, self.db[self.tablename]._format % r) for r in records] else: items = [(r.id, r.id) for r in records] - #if self.sort: + # if self.sort: # items.sort(options_sorter) - #if zero and self.zero is not None and not self.multiple: + # if zero and self.zero is not None and not self.multiple: # items.insert(0, ('', self.zero)) return items def __call__(self, value): if self.multiple: values = value if isinstance(value, list) else [value] - records = self.dbset.where( - self.field.belongs(values) - ).select(self.field, distinct=True).column(self.field) + records = self.dbset.where(self.field.belongs(values)).select(self.field, distinct=True).column(self.field) if set(values).issubset(set(records)): return values, None else: @@ -230,11 +196,9 @@ def __call__(self, value): class notInDB(DBValidator): def __call__(self, value): - row = self.dbset.where( - self.field == value - ).select(limitby=(0, 1)).first() + row = self.dbset.where(self.field == value).select(limitby=(0, 1)).first() if row: - record_id = getattr(current, '_dbvalidation_record_id_', None) + record_id = getattr(current, "_dbvalidation_record_id_", None) if row.id != record_id: return value, translate(self.message) return value, None diff --git a/emmett/validators/process.py b/emmett/validators/process.py index 73eb4341..8e3b852b 100644 --- a/emmett/validators/process.py +++ b/emmett/validators/process.py @@ -1,14 +1,14 @@ # -*- coding: utf-8 -*- """ - emmett.validators.process - ------------------------- +emmett.validators.process +------------------------- - Validators that transform values. +Validators that transform values. - Ported from the original validators of web2py (http://www.web2py.com) +Ported from the original validators of web2py (http://www.web2py.com) - :copyright: (c) by Massimo Di Pierro - :license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) +:copyright: (c) by Massimo Di Pierro +:license: LGPLv3 (http://www.gnu.org/licenses/lgpl.html) """ import re @@ -17,7 +17,7 @@ # TODO: check unicode conversions from .._shortcuts import to_unicode from .basic import Validator -from .helpers import translate, LazyCrypt +from .helpers import LazyCrypt, translate class Cleanup(Validator): @@ -28,7 +28,7 @@ def __init__(self, regex=None, message=None): self.regex = self.rule if regex is None else re.compile(regex) def __call__(self, value): - v = self.regex.sub('', (to_unicode(value) or '').strip()) + v = self.regex.sub("", (to_unicode(value) or "").strip()) return v, None @@ -91,7 +91,7 @@ def _urlify(self, s): # remove leading and trailing hyphens s = s.strip(r"-") # enforce maximum length - return s[:self.maxlen] + return s[: self.maxlen] class Crypt(Validator): @@ -126,16 +126,14 @@ class Crypt(Validator): an existing salted password """ - def __init__( - self, key=None, algorithm='pbkdf2(1000,20,sha512)', salt=True, message=None - ): + def __init__(self, key=None, algorithm="pbkdf2(1000,20,sha512)", salt=True, message=None): super().__init__(message=message) self.key = key self.digest_alg = algorithm self.salt = salt def __call__(self, value): - if getattr(value, '_emt_field_hashed_contents_', False): + if getattr(value, "_emt_field_hashed_contents_", False): return value, None crypt = LazyCrypt(self, value) if isinstance(value, LazyCrypt) and value == crypt: diff --git a/emmett/wrappers/__init__.py b/emmett/wrappers/__init__.py index 6d6abd35..e69de29b 100644 --- a/emmett/wrappers/__init__.py +++ b/emmett/wrappers/__init__.py @@ -1,86 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.wrappers - --------------- - - Provides request and response wrappers. - - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause -""" - -from __future__ import annotations - -import re - -from abc import ABCMeta, abstractmethod -from http.cookies import SimpleCookie -from typing import Any, List, Mapping, Type, TypeVar, Union - -from ..datastructures import Accept, sdict -from ..language.helpers import LanguageAccept -from ..typing import T -from ..utils import cachedprop - -AcceptType = TypeVar("AcceptType", bound=Accept) - -_regex_accept = re.compile(r''' - ([^\s;,]+(?:[ \t]*;[ \t]*(?:[^\s;,q][^\s;,]*|q[^\s;,=][^\s;,]*))*) - (?:[ \t]*;[ \t]*q=(\d*(?:\.\d+)?)[^,]*)?''', re.VERBOSE) - - -class Wrapper: - def __getitem__(self, name: str) -> Any: - return getattr(self, name, None) - - def __setitem__(self, name: str, value: Any): - setattr(self, name, value) - - -class IngressWrapper(Wrapper, metaclass=ABCMeta): - __slots__ = ['scheme', 'path'] - - scheme: str - path: str - - @property - @abstractmethod - def headers(self) -> Mapping[str, str]: ... - - @cachedprop - def host(self) -> str: - return self.headers.get('host') - - def __parse_accept_header( - self, - value: str, - cls: Type[AcceptType] - ) -> AcceptType: - if not value: - return cls(None) - result = [] - for match in _regex_accept.finditer(value): - mq = match.group(2) - if not mq: - quality = 1.0 - else: - quality = max(min(float(mq), 1), 0) - result.append((match.group(1), quality)) - return cls(result) - - @cachedprop - def accept_language(self) -> LanguageAccept: - return self.__parse_accept_header( - self.headers.get('accept-language'), LanguageAccept - ) - - @cachedprop - def cookies(self) -> SimpleCookie: - cookies: SimpleCookie = SimpleCookie() - for cookie in self.headers.get('cookie', '').split(';'): - cookies.load(cookie) - return cookies - - @property - @abstractmethod - def query_params(self) -> sdict[str, Union[str, List[str]]]: ... diff --git a/emmett/wrappers/helpers.py b/emmett/wrappers/helpers.py deleted file mode 100644 index c3ae3a9b..00000000 --- a/emmett/wrappers/helpers.py +++ /dev/null @@ -1,115 +0,0 @@ -# -*- coding: utf-8 -*- -""" - emmett.wrappers.helpers - ----------------------- - - Provides wrappers helpers. - - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause -""" - -import re - -from typing import ( - BinaryIO, - Dict, - Iterable, - Iterator, - MutableMapping, - Optional, - Tuple, - Union -) - -from .._internal import loop_copyfileobj - -regex_client = re.compile(r'[\w\-:]+(\.[\w\-]+)*\.?') - - -class ResponseHeaders(MutableMapping[str, str]): - __slots__ = ['_data'] - - def __init__(self, data: Optional[Dict[str, str]] = None): - self._data = data or {} - - __hash__ = None # type: ignore - - def __getitem__(self, key: str) -> str: - return self._data[key.lower()] - - def __setitem__(self, key: str, value: str): - self._data[key.lower()] = value - - def __delitem__(self, key: str): - del self._data[key.lower()] - - def __contains__(self, key: str) -> bool: # type: ignore - return key.lower() in self._data - - def __iter__(self) -> Iterator[str]: - for key in self._data.keys(): - yield key - - def __len__(self) -> int: - return len(self._data) - - def items(self) -> Iterator[Tuple[str, str]]: # type: ignore - for key, value in self._data.items(): - yield key, value - - def keys(self) -> Iterator[str]: # type: ignore - for key in self._data.keys(): - yield key - - def values(self) -> Iterator[str]: # type: ignore - for value in self._data.values(): - yield value - - def update(self, data: Dict[str, str]): # type: ignore - self._data.update(data) - - -class FileStorage: - __slots__ = ('stream', 'filename', 'name', 'headers', 'content_type') - - def __init__( - self, - stream: BinaryIO, - filename: str, - name: str = None, - content_type: str = None, - headers: Dict = None - ): - self.stream = stream - self.filename = filename - self.name = name - self.headers = headers or {} - self.content_type = content_type or self.headers.get('content-type') - - @property - def content_length(self) -> int: - return int(self.headers.get('content-length', 0)) - - async def save( - self, - destination: Union[BinaryIO, str], - buffer_size: int = 16384 - ): - close_destination = False - if isinstance(destination, str): - destination = open(destination, 'wb') - close_destination = True - try: - await loop_copyfileobj(self.stream, destination, buffer_size) - finally: - if close_destination: - destination.close() - - def __iter__(self) -> Iterable[bytes]: - return iter(self.stream) - - def __repr__(self) -> str: - return ( - f'<{self.__class__.__name__}: ' - f'{self.filename} ({self.content_type})') diff --git a/emmett/wrappers/request.py b/emmett/wrappers/request.py index 578219f7..707aa9a9 100644 --- a/emmett/wrappers/request.py +++ b/emmett/wrappers/request.py @@ -1,38 +1,24 @@ # -*- coding: utf-8 -*- """ - emmett.wrappers.request - ----------------------- +emmett.wrappers.request +----------------------- - Provides http request wrappers. +Provides http request wrappers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from abc import abstractmethod -from cgi import FieldStorage, parse_header -from io import BytesIO -from urllib.parse import parse_qs -from typing import Any - import pendulum - -from ..datastructures import sdict -from ..parsers import Parsers -from ..utils import cachedprop -from . import IngressWrapper -from .helpers import FileStorage +from emmett_core.http.wrappers.request import Request as _Request +from emmett_core.utils import cachedprop -class Request(IngressWrapper): - __slots__ = ['_now', 'method'] +class Request(_Request): + __slots__ = [] method: str - @property - @abstractmethod - async def body(self) -> bytes: ... - @cachedprop def now(self) -> pendulum.DateTime: return pendulum.instance(self._now) @@ -40,109 +26,3 @@ def now(self) -> pendulum.DateTime: @cachedprop def now_local(self) -> pendulum.DateTime: return self.now.in_timezone(pendulum.local_timezone()) # type: ignore - - @cachedprop - def content_type(self) -> str: - return parse_header(self.headers.get('content-type', ''))[0] - - @cachedprop - def content_length(self) -> int: - return self.headers.get('content_length', 0, cast=int) - - _empty_body_methods = {v: v for v in ['GET', 'HEAD', 'OPTIONS']} - - @cachedprop - async def _input_params(self): - if self._empty_body_methods.get(self.method) or not self.content_type: - return sdict(), sdict() - return await self._load_params() - - @cachedprop - async def body_params(self) -> sdict[str, Any]: - rv, _ = await self._input_params - return rv - - @cachedprop - async def files(self) -> sdict[str, FileStorage]: - _, rv = await self._input_params - return rv - - def _load_params_missing(self, data): - return sdict(), sdict() - - def _load_params_json(self, data): - try: - params = Parsers.get_for('json')(data) - except Exception: - params = {} - return sdict(params), sdict() - - def _load_params_form_urlencoded(self, data): - rv = sdict() - for key, values in parse_qs( - data.decode('latin-1'), keep_blank_values=True - ).items(): - if len(values) == 1: - rv[key] = values[0] - continue - rv[key] = values - return rv, sdict() - - @property - def _multipart_headers(self): - return self.headers - - @staticmethod - def _file_param_from_field(field): - return FileStorage( - BytesIO(field.file.read()), - field.filename, - field.name, - field.type, - field.headers - ) - - def _load_params_form_multipart(self, data): - params, files = sdict(), sdict() - field_storage = FieldStorage( - BytesIO(data), - headers=self._multipart_headers, - environ={'REQUEST_METHOD': self.method}, - keep_blank_values=True - ) - for key in field_storage: - field = field_storage[key] - if isinstance(field, list): - if len(field) > 1: - pvalues, fvalues = [], [] - for item in field: - if item.filename is not None: - fvalues.append(self._file_param_from_field(item)) - else: - pvalues.append(item.value) - if pvalues: - params[key] = pvalues - if fvalues: - files[key] = fvalues - continue - else: - field = field[0] - if field.filename is not None: - files[key] = self._file_param_from_field(field) - else: - params[key] = field.value - return params, files - - _params_loaders = { - 'application/json': _load_params_json, - 'application/x-www-form-urlencoded': _load_params_form_urlencoded, - 'multipart/form-data': _load_params_form_multipart - } - - async def _load_params(self): - loader = self._params_loaders.get( - self.content_type, self._load_params_missing) - return loader(self, await self.body) - - @abstractmethod - async def push_promise(self, path: str): ... diff --git a/emmett/wrappers/response.py b/emmett/wrappers/response.py index 7609b350..efc86cd4 100644 --- a/emmett/wrappers/response.py +++ b/emmett/wrappers/response.py @@ -1,35 +1,34 @@ # -*- coding: utf-8 -*- """ - emmett.wrappers.response - ------------------------ +emmett.wrappers.response +------------------------ - Provides response wrappers. +Provides response wrappers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from http.cookies import Morsel, SimpleCookie +import os +import re from typing import Any +from emmett_core.http.response import HTTPFileResponse, HTTPResponse +from emmett_core.http.wrappers.response import Response as _Response +from emmett_core.utils import cachedprop +from pydal.exceptions import NotAuthorizedException, NotFoundException + +from ..ctx import current from ..datastructures import sdict -from ..helpers import get_flashed_messages +from ..helpers import abort, get_flashed_messages from ..html import htmlescape -from ..utils import cachedprop -from . import Wrapper -from .helpers import ResponseHeaders -# Workaround for adding samesite support to pre 3.8 python -Morsel._reserved["samesite"] = "SameSite" # type: ignore # noqa +_re_dbstream = re.compile(r"(?P
.*?)\.(?P.*?)\..*") -class Response(Wrapper): - __slots__ = ('status', 'headers', 'cookies') - def __init__(self): - self.status = 200 - self.headers = ResponseHeaders({'content-type': 'text/plain'}) - self.cookies = SimpleCookie() +class Response(_Response): + __slots__ = () @cachedprop def meta(self) -> sdict[str, Any]: @@ -39,23 +38,36 @@ def meta(self) -> sdict[str, Any]: def meta_prop(self) -> sdict[str, Any]: return sdict() - @property - def content_type(self) -> str: - return self.headers['content-type'] - - @content_type.setter - def content_type(self, value: str): - self.headers['content-type'] = value - def alerts(self, **kwargs): return get_flashed_messages(**kwargs) def _meta_tmpl(self): - return [ - (key, htmlescape(val)) for key, val in self.meta.items() - ] + return [(key, htmlescape(val)) for key, val in self.meta.items()] def _meta_tmpl_prop(self): - return [ - (key, htmlescape(val)) for key, val in self.meta_prop.items() - ] + return [(key, htmlescape(val)) for key, val in self.meta_prop.items()] + + def wrap_file(self, path) -> HTTPFileResponse: + path = os.path.join(current.app.root_path, path) + return super().wrap_file(path) + + def wrap_dbfile(self, db, name: str) -> HTTPResponse: + items = _re_dbstream.match(name) + if not items: + abort(404) + table_name, field_name = items.group("table"), items.group("field") + try: + field = db[table_name][field_name] + except AttributeError: + abort(404) + try: + filename, path_or_stream = field.retrieve(name, nameonly=True) + except NotAuthorizedException: + abort(403) + except NotFoundException: + abort(404) + except IOError: + abort(404) + if isinstance(path_or_stream, str): + return self.wrap_file(path_or_stream) + return self.wrap_io(path_or_stream) diff --git a/emmett/wrappers/websocket.py b/emmett/wrappers/websocket.py index 75ebf6de..f1fc10f1 100644 --- a/emmett/wrappers/websocket.py +++ b/emmett/wrappers/websocket.py @@ -1,45 +1,12 @@ # -*- coding: utf-8 -*- """ - emmett.wrappers.websocket - ------------------------- +emmett.wrappers.websocket +------------------------- - Provides http websocket wrappers. +Provides http websocket wrappers. - :copyright: 2014 Giovanni Barillari - :license: BSD-3-Clause +:copyright: 2014 Giovanni Barillari +:license: BSD-3-Clause """ -from abc import abstractmethod -from typing import Any, Dict, Optional - -from . import IngressWrapper - - -class Websocket(IngressWrapper): - __slots__ = ['_flow_receive', '_flow_send', 'receive', 'send'] - - def _bind_flow(self, flow_receive, flow_send): - self._flow_receive = flow_receive - self._flow_send = flow_send - - @abstractmethod - async def accept( - self, - headers: Optional[Dict[str, str]] = None, - subprotocol: Optional[str] = None - ): - ... - - async def _accept_and_receive(self) -> Any: - await self.accept() - return await self.receive() - - async def _accept_and_send(self, data: Any): - await self.accept() - await self.send(data) - - @abstractmethod - async def _wrapped_receive(self) -> Any: ... - - @abstractmethod - async def _wrapped_send(self, data: Any): ... +from emmett_core.http.wrappers.websocket import Websocket as Websocket diff --git a/pyproject.toml b/pyproject.toml index 527c7bb9..89b895a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,17 +1,18 @@ -[project] -name = "emmett" +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" -[tool.poetry] +[project] name = "emmett" -version = "2.5.13" +version = "2.6.0" description = "The web framework for inventors" -authors = ["Giovanni Barillari "] +readme = "README.md" license = "BSD-3-Clause" +requires-python = ">=3.8" -readme = "README.md" -homepage = "https://emmett.sh" -repository = "https://github.com/emmett-framework/emmett" -documentation = "https://emmett.sh/docs" +authors = [ + { name = "Giovanni Barillari", email = "g@baro.dev" } +] keywords = ["web", "asyncio"] classifiers = [ @@ -27,55 +28,122 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Internet :: WWW/HTTP :: Dynamic Content", "Topic :: Software Development :: Libraries :: Python Modules" ] -packages = [ - {include = "emmett/**/*.*", format = "sdist" }, - {include = "tests", format = "sdist"} +dependencies = [ + "click>=6.0", + "emmett-core[granian,rapidjson]~=1.0.1", + "pendulum~=3.0.0", + "pydal@git+https://github.com/gi0baro/pydal@emmett", + "pyyaml~=6.0", + "renoir~=1.6", + "severus~=1.1", ] -include = [ - "CHANGES.md", - "LICENSE", - "docs/**/*" + +[project.optional-dependencies] +orjson = ["orjson~=3.10"] +uvicorn = [ + "uvicorn~=0.19", + "h11>=0.12", + "websockets~=10.0", + "httptools~=0.6; sys_platform != 'win32'" ] -[tool.poetry.scripts] +[project.urls] +Homepage = 'https://emmett.sh' +Documentation = 'https://emmett.sh/docs' +Funding = 'https://github.com/sponsors/gi0baro' +Source = 'https://github.com/emmett-framework/emmett' +Issues = 'https://github.com/emmett-framework/emmett/issues' + +[project.scripts] emmett = "emmett.cli:main" -[tool.poetry.dependencies] -python = "^3.8" -click = ">=6.0" -granian = "~1.5.0" -emmett-crypto = "^0.6" -pendulum = "~3.0.0" -pyDAL = "17.3" -python-rapidjson = "^1.14" -pyyaml = "^6.0" -renoir = "^1.6" -severus = "^1.1" - -orjson = { version = "~3.10", optional = true } - -uvicorn = { version = "^0.19.0", optional = true } -h11 = { version = ">= 0.12.0", optional = true } -websockets = { version = "^10.0", optional = true } -httptools = { version = "~0.6.0", optional = true, markers = "sys_platform != 'win32'" } - -[tool.poetry.dev-dependencies] -ipaddress = "^1.0" -pytest = "^7.1" -pytest-asyncio = "^0.15" -psycopg2-binary = "^2.9.3" - -[tool.poetry.extras] -orjson = ["orjson"] -uvicorn = ["uvicorn", "h11", "httptools", "websockets"] - -[tool.poetry.urls] -"Issue Tracker" = "https://github.com/emmett-framework/emmett/issues" +[tool.hatch.build.targets.sdist] +include = [ + '/README.md', + '/CHANGES.md', + '/LICENSE', + '/docs', + '/emmett', + '/tests', +] -[build-system] -requires = ["poetry-core>=1.0.0"] -build-backend = "poetry.core.masonry.api" +[tool.hatch.metadata] +allow-direct-references = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.format] +quote-style = 'double' + +[tool.ruff.lint] +extend-select = [ + # E and F are enabled by default + 'B', # flake8-bugbear + 'C4', # flake8-comprehensions + 'C90', # mccabe + 'I', # isort + 'N', # pep8-naming + 'Q', # flake8-quotes + 'RUF100', # ruff (unused noqa) + 'S', # flake8-bandit + 'W', # pycodestyle +] +extend-ignore = [ + 'B006', # mutable function args are fine + 'B008', # function calls in args defaults are fine + 'B009', # getattr with constants is fine + 'B034', # re.split won't confuse us + 'B904', # rising without from is fine + 'E731', # assigning lambdas is fine + 'F403', # import * is fine + 'N801', # leave to us class naming + 'N802', # leave to us method naming + 'N806', # leave to us var naming + 'N811', # leave to us var naming + 'N814', # leave to us var naming + 'N818', # leave to us exceptions naming + 'S101', # assert is fine + 'S104', # leave to us security + 'S105', # leave to us security + 'S106', # leave to us security + 'S107', # leave to us security + 'S110', # pass on exceptions is fine + 'S301', # leave to us security + 'S324', # leave to us security +] +mccabe = { max-complexity = 44 } + +[tool.ruff.lint.isort] +combine-as-imports = true +lines-after-imports = 2 +known-first-party = ['emmett', 'tests'] + +[tool.ruff.lint.per-file-ignores] +'emmett/__init__.py' = ['F401'] +'emmett/http.py' = ['F401'] +'emmett/orm/__init__.py' = ['F401'] +'emmett/orm/engines/__init__.py' = ['F401'] +'emmett/orm/migrations/__init__.py' = ['F401'] +'emmett/orm/migrations/revisions.py' = ['B018'] +'emmett/tools/__init__.py' = ['F401'] +'emmett/tools/auth/__init__.py' = ['F401'] +'emmett/validators/__init__.py' = ['F401'] +'tests/**' = ['B017', 'B018', 'E711', 'E712', 'E741', 'F841', 'S110', 'S501'] + +[tool.pytest.ini_options] +asyncio_mode = 'auto' + +[tool.uv] +dev-dependencies = [ + "ipaddress>=1.0", + "pytest>=7.1", + "pytest-asyncio>=0.15", + "psycopg2-binary~=2.9", + "ruff~=0.5.0", +] diff --git a/tests/helpers.py b/tests/helpers.py index 0be3146d..cedc5ccf 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,47 +1,72 @@ # -*- coding: utf-8 -*- """ - tests.helpers - ------------- +tests.helpers +------------- - Tests helpers +Tests helpers """ from contextlib import contextmanager -from emmett.asgi.wrappers import Request, Websocket +from emmett_core.protocols.rsgi.test_client.scope import ScopeBuilder + from emmett.ctx import RequestContext, WSContext, current +from emmett.datastructures import sdict +from emmett.rsgi.wrappers import Request, Websocket from emmett.serializers import Serializers -from emmett.testing.env import ScopeBuilder from emmett.wrappers.response import Response -json_dump = Serializers.get_for('json') + +json_dump = Serializers.get_for("json") class FakeRequestContext(RequestContext): def __init__(self, app, scope): self.app = app - self.request = Request(scope, None, None) + self.request = Request(scope, scope.path, None, None) self.response = Response() self.session = None -class FakeWSContext(WSContext): - def __init__(self, app, scope): - self.app = app - self.websocket = Websocket( - scope, - self.receive, - self.send - ) - self._receive_storage = [] +class FakeWSTransport: + def __init__(self): self._send_storage = [] async def receive(self): - return json_dump({'foo': 'bar'}) + return json_dump({"foo": "bar"}) - async def send(self, data): + async def send_str(self, data): self._send_storage.append(data) + async def send_bytes(self, data): + self._send_storage.append(data) + + +class FakeWsProto: + def __init__(self): + self.transport = None + + async def init(self): + self.transport = FakeWSTransport() + + async def receive(self): + return sdict(data=await self.transport.receive()) + + def close(self): + pass + + +class FakeWSContext(WSContext): + def __init__(self, app, scope): + self.app = app + self._proto = FakeWsProto() + self.websocket = Websocket(scope, scope.path, self._proto) + self._receive_storage = [] + + @property + def _send_storage(self): + return self._proto.transport._send_storage + @contextmanager def current_ctx(path, app=None): @@ -55,7 +80,7 @@ def current_ctx(path, app=None): def ws_ctx(path, app=None): builder = ScopeBuilder(path) scope_data = builder.get_data()[0] - scope_data.update(type='websocket', scheme='wss') + scope_data.proto = "ws" token = current._init_(FakeWSContext(app, scope_data)) yield current current._close_(token) diff --git a/tests/test_auth.py b/tests/test_auth.py index b4debf9b..a37e9197 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,63 +1,65 @@ # -*- coding: utf-8 -*- """ - tests.auth - ---------- +tests.auth +---------- - Test Emmett Auth module +Test Emmett Auth module """ import os -import pytest import shutil + +import pytest + from emmett import App -from emmett.orm import Database, Field, Model, has_many, belongs_to +from emmett.orm import Database, Field, Model, belongs_to, has_many from emmett.sessions import SessionManager from emmett.tools import Auth, Mailer from emmett.tools.auth.models import AuthUser class User(AuthUser): - has_many('things') + has_many("things") gender = Field() class Thing(Model): - belongs_to('user') + belongs_to("user") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): rv = App(__name__) - rv.config.mailer.sender = 'nina@massivedynamics.com' + rv.config.mailer.sender = "nina@massivedynamics.com" rv.config.mailer.suppress = True rv.config.auth.single_template = True rv.config.auth.hmac_key = "foobar" - rv.pipeline = [SessionManager.cookies('foobar')] + rv.pipeline = [SessionManager.cookies("foobar")] return rv -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def client(app): return app.test_client() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def mailer(app): return Mailer(app) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(app): try: - shutil.rmtree(os.path.join(app.root_path, 'databases')) - except: + shutil.rmtree(os.path.join(app.root_path, "databases")) + except Exception: pass db = Database(app, auto_migrate=True) app.pipeline.append(db.pipe) return db -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def auth(app, _db, mailer): auth = Auth(app, _db, user_model=User) app.pipeline.append(auth.pipe) @@ -65,7 +67,7 @@ def auth(app, _db, mailer): return auth -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def db(_db, auth): with _db.connection(): _db.define_models(Thing) @@ -75,24 +77,12 @@ def db(_db, auth): def test_models(db): with db.connection(): user = User.create( - email="walter@massivedynamics.com", - password="pocketuniverse", - first_name="Walter", - last_name="Bishop" - ) - group = db.auth_groups.insert( - role="admin" - ) - group2 = db.auth_groups.insert( - role="moderator" - ) - db.auth_memberships.insert( - user=user.id, - auth_group=group - ) - db.auth_permissions.insert( - auth_group=group + email="walter@massivedynamics.com", password="pocketuniverse", first_name="Walter", last_name="Bishop" ) + group = db.auth_groups.insert(role="admin") + group2 = db.auth_groups.insert(role="moderator") + db.auth_memberships.insert(user=user.id, auth_group=group) + db.auth_permissions.insert(auth_group=group) user = db.users[1] assert len(user.auth_memberships()) == 1 assert user.auth_memberships()[0].user == 1 @@ -112,119 +102,128 @@ def test_models(db): def test_registration(mailer, db, client): - page = client.get('/auth/registration').data - assert 'E-mail' in page - assert 'First name' in page - assert 'Last name' in page - assert 'Password' in page - assert 'Confirm password' in page - assert 'Register' in page + page = client.get("/auth/registration").data + assert "E-mail" in page + assert "First name" in page + assert "Last name" in page + assert "Password" in page + assert "Confirm password" in page + assert "Register" in page with mailer.store_mails() as mailbox: - with client.get('/auth/registration').context as ctx: - req = client.post('/auth/registration', data={ - 'email': 'william@massivedynamics.com', - 'first_name': 'William', - 'last_name': 'Bell', - 'password': 'imtheceo', - 'password2': 'imtheceo', - '_csrf_token': list(ctx.session._csrf)[-1] - }, follow_redirects=True) + with client.get("/auth/registration").context as ctx: + req = client.post( + "/auth/registration", + data={ + "email": "william@massivedynamics.com", + "first_name": "William", + "last_name": "Bell", + "password": "imtheceo", + "password2": "imtheceo", + "_csrf_token": list(ctx.session._csrf)[-1], + }, + follow_redirects=True, + ) assert "We sent you an email, check your inbox" in req.data assert len(mailbox) == 1 mail = mailbox[0] assert mail.recipients == ["william@massivedynamics.com"] - assert mail.subject == 'Email verification' + assert mail.subject == "Email verification" mail_as_str = str(mail) - assert 'Hello william@massivedynamics.com!' in mail_as_str - assert 'verify your email' in mail_as_str - verification_code = mail_as_str.split( - "http://localhost/auth/email_verification/")[1].split(" ")[0] - req = client.get( - '/auth/email_verification/{}'.format(verification_code), - follow_redirects=True) + assert "Hello william@massivedynamics.com!" in mail_as_str + assert "verify your email" in mail_as_str + verification_code = mail_as_str.split("http://localhost/auth/email_verification/")[1].split(" ")[0] + req = client.get("/auth/email_verification/{}".format(verification_code), follow_redirects=True) assert "Account verification completed" in req.data def test_login(db, client): - page = client.get('/auth/login').data - assert 'E-mail' in page - assert 'Password' in page - assert 'Sign in' in page - with client.get('/auth/login').context as ctx: - req = client.post('/auth/login', data={ - 'email': 'william@massivedynamics.com', - 'password': 'imtheceo', - '_csrf_token': list(ctx.session._csrf)[-1] - }, follow_redirects=True) - assert 'William' in req.data - assert 'Bell' in req.data - assert 'Save' in req.data + page = client.get("/auth/login").data + assert "E-mail" in page + assert "Password" in page + assert "Sign in" in page + with client.get("/auth/login").context as ctx: + req = client.post( + "/auth/login", + data={ + "email": "william@massivedynamics.com", + "password": "imtheceo", + "_csrf_token": list(ctx.session._csrf)[-1], + }, + follow_redirects=True, + ) + assert "William" in req.data + assert "Bell" in req.data + assert "Save" in req.data def test_password_change(db, client): - with client.get('/auth/login').context as ctx: - with client.post('/auth/login', data={ - 'email': 'william@massivedynamics.com', - 'password': 'imtheceo', - '_csrf_token': list(ctx.session._csrf)[-1] - }, follow_redirects=True): - page = client.get('/auth/password_change').data - assert 'Current password' in page - assert 'New password' in page - assert 'Confirm password' in page - with client.get('/auth/password_change').context as ctx2: - with client.post('/auth/password_change', data={ - 'old_password': 'imtheceo', - 'new_password': 'imthebigceo', - 'new_password2': 'imthebigceo', - '_csrf_token': list(ctx2.session._csrf)[-1] - }, follow_redirects=True) as req: - assert 'Password changed successfully' in req.data - assert 'William' in req.data - assert 'Save' in req.data + with client.get("/auth/login").context as ctx: + with client.post( + "/auth/login", + data={ + "email": "william@massivedynamics.com", + "password": "imtheceo", + "_csrf_token": list(ctx.session._csrf)[-1], + }, + follow_redirects=True, + ): + page = client.get("/auth/password_change").data + assert "Current password" in page + assert "New password" in page + assert "Confirm password" in page + with client.get("/auth/password_change").context as ctx2: + with client.post( + "/auth/password_change", + data={ + "old_password": "imtheceo", + "new_password": "imthebigceo", + "new_password2": "imthebigceo", + "_csrf_token": list(ctx2.session._csrf)[-1], + }, + follow_redirects=True, + ) as req: + assert "Password changed successfully" in req.data + assert "William" in req.data + assert "Save" in req.data with db.connection(): assert ( - db.users(email='william@massivedynamics.com').password == - db.users.password.requires[-1]('imthebigceo')[0]) + db.users(email="william@massivedynamics.com").password == db.users.password.requires[-1]("imthebigceo")[0] + ) def test_password_retrieval(mailer, db, client): - page = client.get('/auth/password_retrieval').data - assert 'Email' in page - assert 'Retrieve password' in page + page = client.get("/auth/password_retrieval").data + assert "Email" in page + assert "Retrieve password" in page with mailer.store_mails() as mailbox: - with client.get('/auth/password_retrieval').context as ctx: - with client.post('/auth/password_retrieval', data={ - 'email': 'william@massivedynamics.com', - '_csrf_token': list(ctx.session._csrf)[-1] - }, follow_redirects=True) as req: - assert 'We sent you an email, check your inbox' in req.data + with client.get("/auth/password_retrieval").context as ctx: + with client.post( + "/auth/password_retrieval", + data={"email": "william@massivedynamics.com", "_csrf_token": list(ctx.session._csrf)[-1]}, + follow_redirects=True, + ) as req: + assert "We sent you an email, check your inbox" in req.data assert len(mailbox) == 1 mail = mailbox[0] assert mail.recipients == ["william@massivedynamics.com"] - assert mail.subject == 'Password reset requested' + assert mail.subject == "Password reset requested" mail_as_str = str(mail) - assert 'A password reset was requested for your account' in mail_as_str - reset_code = mail_as_str.split( - "http://localhost/auth/password_reset/")[1].split(" ")[0] - with client.get( - '/auth/password_reset/{}'.format(reset_code), - follow_redirects=True - ) as req: - assert 'New password' in req.data - assert 'Confirm password' in req.data - assert 'Reset password' in req.data + assert "A password reset was requested for your account" in mail_as_str + reset_code = mail_as_str.split("http://localhost/auth/password_reset/")[1].split(" ")[0] + with client.get("/auth/password_reset/{}".format(reset_code), follow_redirects=True) as req: + assert "New password" in req.data + assert "Confirm password" in req.data + assert "Reset password" in req.data with client.post( - '/auth/password_reset/{}'.format(reset_code), + "/auth/password_reset/{}".format(reset_code), data={ - 'password': 'imtheceo', - 'password2': 'imtheceo', - '_csrf_token': list(req.context.session._csrf)[-1]}, - follow_redirects=True + "password": "imtheceo", + "password2": "imtheceo", + "_csrf_token": list(req.context.session._csrf)[-1], + }, + follow_redirects=True, ) as req2: - assert 'Password changed successfully' in req2.data - assert 'Sign in' in req2.data + assert "Password changed successfully" in req2.data + assert "Sign in" in req2.data with db.connection(): - assert ( - db.users(email='william@massivedynamics.com').password == - db.users.password.requires[-1]('imtheceo')[0]) + assert db.users(email="william@massivedynamics.com").password == db.users.password.requires[-1]("imtheceo")[0] diff --git a/tests/test_cache.py b/tests/test_cache.py index 140dbe42..15ffd259 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,27 +1,15 @@ # -*- coding: utf-8 -*- """ - tests.cache - ----------- +tests.cache +----------- - Test Emmett cache module +Test Emmett cache module """ import pytest -from collections import defaultdict - from emmett import App -from emmett.cache import CacheHandler, RamCache, DiskCache, Cache - - -def test_basecache(): - base_cache = CacheHandler() - assert base_cache._default_expire == 300 - - assert base_cache('key', lambda: 'value') is 'value' - assert base_cache.get('key') is None - assert base_cache.set('key', 'value', 300) is None - assert base_cache.clear() is None +from emmett.cache import DiskCache async def _await_2(): @@ -32,28 +20,6 @@ async def _await_3(): return 3 -@pytest.mark.asyncio -async def test_ramcache(): - ram_cache = RamCache() - assert ram_cache._prefix == '' - assert ram_cache._threshold == 500 - - assert ram_cache('test', lambda: 2) == 2 - assert ram_cache('test', lambda: 3, 300) == 2 - - assert await ram_cache('test_loop', _await_2) == 2 - assert await ram_cache('test_loop', _await_3, 300) == 2 - - ram_cache.set('test', 3) - assert ram_cache.get('test') == 3 - - ram_cache.set('test', 4, 300) - assert ram_cache.get('test') == 4 - - ram_cache.clear() - assert ram_cache.get('test') is None - - @pytest.mark.asyncio async def test_diskcache(): App(__name__) @@ -61,123 +27,17 @@ async def test_diskcache(): disk_cache = DiskCache() assert disk_cache._threshold == 500 - assert disk_cache('test', lambda: 2) == 2 - assert disk_cache('test', lambda: 3, 300) == 2 + assert disk_cache("test", lambda: 2) == 2 + assert disk_cache("test", lambda: 3, 300) == 2 - assert await disk_cache('test_loop', _await_2) == 2 - assert await disk_cache('test_loop', _await_3, 300) == 2 + assert await disk_cache("test_loop", _await_2) == 2 + assert await disk_cache("test_loop", _await_3, 300) == 2 - disk_cache.set('test', 3) - assert disk_cache.get('test') == 3 + disk_cache.set("test", 3) + assert disk_cache.get("test") == 3 - disk_cache.set('test', 4, 300) - assert disk_cache.get('test') == 4 + disk_cache.set("test", 4, 300) + assert disk_cache.get("test") == 4 disk_cache.clear() - assert disk_cache.get('test') is None - - -@pytest.mark.asyncio -async def test_cache(): - default_cache = Cache() - assert isinstance(default_cache._default_handler, RamCache) - - assert default_cache('test', lambda: 2) == 2 - assert default_cache('test', lambda: 3) == 2 - - assert await default_cache('test_loop', _await_2) == 2 - assert await default_cache('test_loop', _await_3, 300) == 2 - - default_cache.set('test', 3) - assert default_cache('test', lambda: 2) == 3 - - default_cache.set('test', 4, 300) - assert default_cache('test', lambda: 2, 300) == 4 - - default_cache.clear() - - disk_cache = DiskCache() - ram_cache = RamCache() - cache = Cache(default='disc', ram=ram_cache, disc=disk_cache) - assert isinstance(cache._default_handler, DiskCache) - assert cache.disc == disk_cache - assert cache.ram == ram_cache - - -def test_cache_decorator_sync(): - cache = Cache(ram=RamCache(prefix='test:')) - calls = defaultdict(lambda: 0) - - @cache('foo') - def foo(*args, **kwargs): - calls['foo'] += 1 - return 'foo' - - #: no arguments - for _ in range(0, 2): - foo() - assert len(cache._default_handler.data.keys()) == 1 - assert calls['foo'] == 1 - - #: args change the cache key - for _ in range(0, 2): - foo(1) - assert len(cache._default_handler.data.keys()) == 2 - assert calls['foo'] == 2 - - for _ in range(0, 2): - foo(1, 2) - assert len(cache._default_handler.data.keys()) == 3 - assert calls['foo'] == 3 - - #: kwargs change the cache key - for _ in range(0, 2): - foo(1, a='foo', b='bar') - assert len(cache._default_handler.data.keys()) == 4 - assert calls['foo'] == 4 - - #: kwargs order won't change the cache key - for _ in range(0, 2): - foo(1, b='bar', a='foo') - assert len(cache._default_handler.data.keys()) == 4 - assert calls['foo'] == 4 - - -@pytest.mark.asyncio -async def test_cache_decorator_loop(): - cache = Cache(ram=RamCache(prefix='bar:')) - calls = defaultdict(lambda: 0) - - @cache('bar') - async def bar(*args, **kwargs): - calls['bar'] += 1 - return 'bar' - - #: no arguments - for _ in range(0, 2): - await bar() - assert len(cache._default_handler.data.keys()) == 1 - assert calls['bar'] == 1 - - #: args change the cache key - for _ in range(0, 2): - await bar(1) - assert len(cache._default_handler.data.keys()) == 2 - assert calls['bar'] == 2 - - for _ in range(0, 2): - await bar(1, 2) - assert len(cache._default_handler.data.keys()) == 3 - assert calls['bar'] == 3 - - #: kwargs change the cache key - for _ in range(0, 2): - await bar(1, a='foo', b='bar') - assert len(cache._default_handler.data.keys()) == 4 - assert calls['bar'] == 4 - - #: kwargs order won't change the cache key - for _ in range(0, 2): - await bar(1, b='bar', a='foo') - assert len(cache._default_handler.data.keys()) == 4 - assert calls['bar'] == 4 + assert disk_cache.get("test") is None diff --git a/tests/test_http.py b/tests/test_http.py deleted file mode 100644 index 07fdc8ac..00000000 --- a/tests/test_http.py +++ /dev/null @@ -1,58 +0,0 @@ -# -*- coding: utf-8 -*- -""" - tests.http - ---------- - - Test Emmett http module -""" - -from helpers import current_ctx -from emmett.http import HTTP, HTTPBytes, HTTPResponse, redirect - - -def test_http_default(): - http = HTTP(200) - - assert http.encoded_body is b'' - assert http.status_code == 200 - assert list(http.headers) == [(b'content-type', b'text/plain')] - - -def test_http_bytes(): - http = HTTPBytes(200) - - assert http.body == b'' - assert http.status_code == 200 - assert list(http.headers) == [(b'content-type', b'text/plain')] - - -def test_http(): - response = [] - buffer = [] - - def start_response(status, headers): - response[:] = [status, headers] - return buffer.append - - http = HTTP( - 200, - 'Hello World', - headers={'x-test': 'Hello Header'}, - cookies={'cookie_test': 'Set-Cookie: hello cookie'} - ) - - assert http.encoded_body == b'Hello World' - assert http.status_code == 200 - assert list(http.headers) == [ - (b'x-test', b'Hello Header'), (b'set-cookie', b'hello cookie') - ] - - -def test_redirect(): - with current_ctx('/') as ctx: - try: - redirect('/redirect', 302) - except HTTPResponse as http_redirect: - assert ctx.response.status == 302 - assert http_redirect.status_code == 302 - assert list(http_redirect.headers) == [(b'location', b'/redirect')] diff --git a/tests/test_logger.py b/tests/test_logger.py index bf376c7b..49638381 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,14 +1,16 @@ # -*- coding: utf-8 -*- """ - tests.logger - ------------ +tests.logger +------------ - Test Emmett logging module +Test Emmett logging module """ import logging -from emmett import App, logger, sdict +from emmett_core import log as logger + +from emmett import App, sdict def _call_create_logger(app): @@ -17,18 +19,14 @@ def _call_create_logger(app): def test_user_assign_valid_level(): app = App(__name__) - app.config.logging.pytest = sdict( - level='info' - ) + app.config.logging.pytest = sdict(level="info") result = _call_create_logger(app) assert result.handlers[-1].level == logging.INFO def test_user_assign_invaild_level(): app = App(__name__) - app.config.logging.pytest = sdict( - level='invalid' - ) + app.config.logging.pytest = sdict(level="invalid") result = _call_create_logger(app) assert result.handlers[-1].level == logging.WARNING diff --git a/tests/test_mailer.py b/tests/test_mailer.py index e978028e..2c598e77 100644 --- a/tests/test_mailer.py +++ b/tests/test_mailer.py @@ -1,30 +1,31 @@ # -*- coding: utf-8 -*- """ - tests.mailer - ------------ +tests.mailer +------------ - Test Emmett mailer +Test Emmett mailer """ import base64 import email -import pytest import re import time - from email.header import Header + +import pytest + from emmett import App from emmett.tools.mailer import Mailer, sanitize_address -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): rv = App(__name__) - rv.config.mailer.sender = 'support@example.com' + rv.config.mailer.sender = "support@example.com" return rv -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def mailer(app): rv = Mailer(app) return rv @@ -32,7 +33,7 @@ def mailer(app): def test_message_init(mailer): msg = mailer.mail(subject="subject", recipients="to@example.com") - assert msg.sender == 'support@example.com' + assert msg.sender == "support@example.com" assert msg.recipients == ["to@example.com"] @@ -45,32 +46,23 @@ def test_recipients(mailer): def test_all_recipients(mailer): msg = mailer.mail( - subject="subject", recipients=["somebody@here.com"], - cc=["cc@example.com"], bcc=["bcc@example.com"]) + subject="subject", recipients=["somebody@here.com"], cc=["cc@example.com"], bcc=["bcc@example.com"] + ) assert len(msg.all_recipients) == 3 msg.add_recipient("cc@example.com") assert len(msg.all_recipients) == 3 def test_sender(mailer): - msg = mailer.mail( - subject="subject", - sender=u'ÄÜÖ → ✓ >', - recipients=["to@example.com"]) - assert 'From: =?utf-8?b?w4TDnMOWIOKGkiDinJM=?= ' in \ - str(msg) + msg = mailer.mail(subject="subject", sender="ÄÜÖ → ✓ >", recipients=["to@example.com"]) + assert "From: =?utf-8?b?w4TDnMOWIOKGkiDinJM=?= " in str(msg) def test_sender_as_tuple(mailer): - msg = mailer.mail( - subject="testing", sender=("tester", "tester@example.com")) + msg = mailer.mail(subject="testing", sender=("tester", "tester@example.com")) assert msg.sender == "tester " - msg = mailer.mail( - subject="subject", - sender=(u"ÄÜÖ → ✓", 'from@example.com>'), - recipients=["to@example.com"]) - assert 'From: =?utf-8?b?w4TDnMOWIOKGkiDinJM=?= ' in \ - str(msg) + msg = mailer.mail(subject="subject", sender=("ÄÜÖ → ✓", "from@example.com>"), recipients=["to@example.com"]) + assert "From: =?utf-8?b?w4TDnMOWIOKGkiDinJM=?= " in str(msg) def test_reply_to(mailer): @@ -79,15 +71,14 @@ def test_reply_to(mailer): recipients=["to@example.com"], sender="spammer ", reply_to="somebody ", - body="testing") - h = Header( - "Reply-To: %s" % sanitize_address('somebody ')) + body="testing", + ) + h = Header("Reply-To: %s" % sanitize_address("somebody ")) assert h.encode() in str(msg) def test_missing_sender(mailer): - msg = mailer.mail( - subject="testing", recipients=["to@example.com"], body="testing") + msg = mailer.mail(subject="testing", recipients=["to@example.com"], body="testing") msg.sender = None with pytest.raises(AssertionError): msg.send() @@ -105,7 +96,8 @@ def test_bcc(mailer): subject="testing", recipients=["to@example.com"], body="testing", - bcc=["tosomeoneelse@example.com"]) + bcc=["tosomeoneelse@example.com"], + ) assert "tosomeoneelse@example.com" not in str(msg) @@ -115,43 +107,36 @@ def test_cc(mailer): subject="testing", recipients=["to@example.com"], body="testing", - cc=["tosomeoneelse@example.com"]) + cc=["tosomeoneelse@example.com"], + ) assert "Cc: tosomeoneelse@example.com" in str(msg) def test_attach(mailer): - msg = mailer.mail( - subject="testing", recipients=["to@example.com"], body="testing") + msg = mailer.mail(subject="testing", recipients=["to@example.com"], body="testing") msg.attach(data=b"this is a test", content_type="text/plain") a = msg.attachments[0] assert a.filename is None - assert a.disposition == 'attachment' - assert a.content_type == 'text/plain' + assert a.disposition == "attachment" + assert a.content_type == "text/plain" assert a.data == b"this is a test" def test_bad_subject(mailer): - msg = mailer.mail( - subject="testing\r\n", - sender="from@example.com", - body="testing", - recipients=["to@example.com"]) + msg = mailer.mail(subject="testing\r\n", sender="from@example.com", body="testing", recipients=["to@example.com"]) with pytest.raises(RuntimeError): msg.send() def test_subject(mailer): - msg = mailer.mail( - subject=u"sübject", - sender='from@example.com', - recipients=["to@example.com"]) - assert '=?utf-8?q?s=C3=BCbject?=' in str(msg) + msg = mailer.mail(subject="sübject", sender="from@example.com", recipients=["to@example.com"]) + assert "=?utf-8?q?s=C3=BCbject?=" in str(msg) def test_empty_subject(mailer): msg = mailer.mail(sender="from@example.com", recipients=["foo@bar.com"]) msg.body = "normal ascii text" - assert 'Subject:' not in str(msg) + assert "Subject:" not in str(msg) def test_multiline_subject(mailer): @@ -159,40 +144,31 @@ def test_multiline_subject(mailer): subject="testing\r\n testing\r\n testing \r\n \ttesting", sender="from@example.com", body="testing", - recipients=["to@example.com"]) + recipients=["to@example.com"], + ) msg_as_string = str(msg) assert "From: from@example.com" in msg_as_string assert "testing\r\n testing\r\n testing \r\n \ttesting" in msg_as_string msg = mailer.mail( - subject="testing\r\n testing\r\n ", - sender="from@example.com", - body="testing", - recipients=["to@example.com"]) + subject="testing\r\n testing\r\n ", sender="from@example.com", body="testing", recipients=["to@example.com"] + ) with pytest.raises(RuntimeError): msg.send() msg = mailer.mail( - subject="testing\r\n testing\r\n\t", - sender="from@example.com", - body="testing", - recipients=["to@example.com"]) + subject="testing\r\n testing\r\n\t", sender="from@example.com", body="testing", recipients=["to@example.com"] + ) with pytest.raises(RuntimeError): msg.send() msg = mailer.mail( - subject="testing\r\n testing\r\n\n", - sender="from@example.com", - body="testing", - recipients=["to@example.com"]) + subject="testing\r\n testing\r\n\n", sender="from@example.com", body="testing", recipients=["to@example.com"] + ) with pytest.raises(RuntimeError): msg.send() def test_bad_sender(mailer): - msg = mailer.mail( - subject="testing", - sender="from@example.com\r\n", - recipients=["to@example.com"], - body="testing") - assert 'From: from@example.com' in str(msg) + msg = mailer.mail(subject="testing", sender="from@example.com\r\n", recipients=["to@example.com"], body="testing") + assert "From: from@example.com" in str(msg) def test_bad_reply_to(mailer): @@ -201,11 +177,12 @@ def test_bad_reply_to(mailer): sender="from@example.com", reply_to="evil@example.com\r", recipients=["to@example.com"], - body="testing") + body="testing", + ) msg_as_string = str(msg) - assert 'From: from@example.com' in msg_as_string - assert 'To: to@example.com' in msg_as_string - assert 'Reply-To: evil@example.com' in msg_as_string + assert "From: from@example.com" in msg_as_string + assert "To: to@example.com" in msg_as_string + assert "Reply-To: evil@example.com" in msg_as_string def test_bad_recipient(mailer): @@ -213,8 +190,9 @@ def test_bad_recipient(mailer): subject="testing", sender="from@example.com", recipients=["to@example.com", "to\r\n@example.com"], - body="testing") - assert 'To: to@example.com' in str(msg) + body="testing", + ) + assert "To: to@example.com" in str(msg) def test_address_sanitize(mailer): @@ -222,86 +200,62 @@ def test_address_sanitize(mailer): subject="testing", sender="sender\r\n@example.com", reply_to="reply_to\r\n@example.com", - recipients=["recipient\r\n@example.com"]) + recipients=["recipient\r\n@example.com"], + ) msg_as_string = str(msg) - assert 'sender@example.com' in msg_as_string - assert 'reply_to@example.com' in msg_as_string - assert 'recipient@example.com' in msg_as_string + assert "sender@example.com" in msg_as_string + assert "reply_to@example.com" in msg_as_string + assert "recipient@example.com" in msg_as_string def test_plain_message(mailer): plain_text = "Hello Joe,\nHow are you?" - msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body=plain_text) + msg = mailer.mail(sender="from@example.com", subject="subject", recipients=["to@example.com"], body=plain_text) assert plain_text == msg.body - assert 'Content-Type: text/plain' in str(msg) + assert "Content-Type: text/plain" in str(msg) def test_plain_message_with_attachments(mailer): - msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body="hello") + msg = mailer.mail(sender="from@example.com", subject="subject", recipients=["to@example.com"], body="hello") msg.attach(data=b"this is a test", content_type="text/plain") - assert 'Content-Type: multipart/mixed' in str(msg) + assert "Content-Type: multipart/mixed" in str(msg) def test_plain_message_with_unicode_attachments(mailer): - msg = mailer.mail( - subject="subject", recipients=["to@example.com"], body="hello") - msg.attach( - data=b"this is a test", - content_type="text/plain", - filename=u'ünicöde ←→ ✓.txt') + msg = mailer.mail(subject="subject", recipients=["to@example.com"], body="hello") + msg.attach(data=b"this is a test", content_type="text/plain", filename="ünicöde ←→ ✓.txt") parsed = email.message_from_string(str(msg)) - assert ( - re.sub(r'\s+', ' ', parsed.get_payload()[1].get('Content-Disposition')) - in [ - 'attachment; filename*="UTF8\'\'' - '%C3%BCnic%C3%B6de%20%E2%86%90%E2%86%92%20%E2%9C%93.txt"', - 'attachment; filename*=UTF8\'\'' - '%C3%BCnic%C3%B6de%20%E2%86%90%E2%86%92%20%E2%9C%93.txt' - ]) + assert re.sub(r"\s+", " ", parsed.get_payload()[1].get("Content-Disposition")) in [ + "attachment; filename*=\"UTF8''" '%C3%BCnic%C3%B6de%20%E2%86%90%E2%86%92%20%E2%9C%93.txt"', + "attachment; filename*=UTF8''" "%C3%BCnic%C3%B6de%20%E2%86%90%E2%86%92%20%E2%9C%93.txt", + ] def test_html_message(mailer): html_text = "

Hello World

" - msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - html=html_text) + msg = mailer.mail(sender="from@example.com", subject="subject", recipients=["to@example.com"], html=html_text) assert html_text == msg.html - assert 'Content-Type: multipart/alternative' in str(msg) + assert "Content-Type: multipart/alternative" in str(msg) def test_json_message(mailer): json_text = '{"msg": "Hello World!}' msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - alts={'json': json_text}) - assert json_text == msg.alts['json'] - assert 'Content-Type: multipart/alternative' in str(msg) + sender="from@example.com", subject="subject", recipients=["to@example.com"], alts={"json": json_text} + ) + assert json_text == msg.alts["json"] + assert "Content-Type: multipart/alternative" in str(msg) def test_html_message_with_attachments(mailer): html_text = "

Hello World

" - plain_text = 'Hello World' + plain_text = "Hello World" msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body=plain_text, - html=html_text) + sender="from@example.com", subject="subject", recipients=["to@example.com"], body=plain_text, html=html_text + ) msg.attach(data=b"this is a test", content_type="text/plain") assert html_text == msg.html - assert 'Content-Type: multipart/alternative' in str(msg) + assert "Content-Type: multipart/alternative" in str(msg) parsed = email.message_from_string(str(msg)) assert len(parsed.get_payload()) == 2 body, attachment = parsed.get_payload() @@ -309,47 +263,41 @@ def test_html_message_with_attachments(mailer): plain, html = body.get_payload() assert plain.get_payload() == plain_text assert html.get_payload() == html_text - assert base64.b64decode(attachment.get_payload()) == b'this is a test' + assert base64.b64decode(attachment.get_payload()) == b"this is a test" def test_date(mailer): before = time.time() msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body="hello", - date=time.time()) + sender="from@example.com", subject="subject", recipients=["to@example.com"], body="hello", date=time.time() + ) after = time.time() assert before <= msg.date <= after fmt_date = email.utils.formatdate(msg.date, localtime=True) - assert 'Date: {}'.format(fmt_date) in str(msg) + assert "Date: {}".format(fmt_date) in str(msg) def test_msgid(mailer): - msg = mailer.mail( - sender="from@example.com", - subject="subject", - recipients=["to@example.com"], - body="hello") + msg = mailer.mail(sender="from@example.com", subject="subject", recipients=["to@example.com"], body="hello") r = re.compile(r"<\S+@\S+>").match(msg.msgId) assert r is not None - assert 'Message-ID: {}'.format(msg.msgId) in str(msg) + assert "Message-ID: {}".format(msg.msgId) in str(msg) def test_unicode_addresses(mailer): msg = mailer.mail( subject="subject", - sender=u'ÄÜÖ → ✓ ', - recipients=[u"Ä ", u"Ü "], - cc=[u"Ö "]) + sender="ÄÜÖ → ✓ ", + recipients=["Ä ", "Ü "], + cc=["Ö "], + ) msg_as_string = str(msg) - a1 = sanitize_address(u"Ä ") - a2 = sanitize_address(u"Ü ") + a1 = sanitize_address("Ä ") + a2 = sanitize_address("Ü ") h1_a = Header("To: %s, %s" % (a1, a2)) h1_b = Header("To: %s, %s" % (a2, a1)) - h2 = Header("From: %s" % sanitize_address(u"ÄÜÖ → ✓ ")) - h3 = Header("Cc: %s" % sanitize_address(u"Ö ")) + h2 = Header("From: %s" % sanitize_address("ÄÜÖ → ✓ ")) + h3 = Header("Cc: %s" % sanitize_address("Ö ")) try: assert h1_a.encode() in msg_as_string except AssertionError: @@ -364,28 +312,27 @@ def test_extra_headers(mailer): subject="subject", recipients=["to@example.com"], body="hello", - extra_headers={'X-Extra-Header': 'Yes'}) - assert 'X-Extra-Header: Yes' in str(msg) + extra_headers={"X-Extra-Header": "Yes"}, + ) + assert "X-Extra-Header: Yes" in str(msg) def test_send(mailer): with mailer.store_mails() as outbox: - msg = mailer.mail( - subject="testing", recipients=["tester@example.com"], body="test") + msg = mailer.mail(subject="testing", recipients=["tester@example.com"], body="test") msg.send() assert msg.date is not None assert len(outbox) == 1 sent_msg = outbox[0] - assert sent_msg.sender == 'support@example.com' + assert sent_msg.sender == "support@example.com" def test_send_mail(mailer): with mailer.store_mails() as outbox: - mailer.send_mail( - subject="testing", recipients=["tester@example.com"], body="test") + mailer.send_mail(subject="testing", recipients=["tester@example.com"], body="test") assert len(outbox) == 1 sent_msg = outbox[0] - assert sent_msg.subject == 'testing' + assert sent_msg.subject == "testing" assert sent_msg.recipients == ["tester@example.com"] - assert sent_msg.body == 'test' - assert sent_msg.sender == 'support@example.com' + assert sent_msg.body == "test" + assert sent_msg.sender == "support@example.com" diff --git a/tests/test_migrations.py b/tests/test_migrations.py index 9b774177..940f5130 100644 --- a/tests/test_migrations.py +++ b/tests/test_migrations.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.migrations - ---------------- +tests.migrations +---------------- - Test Emmett migrations engine +Test Emmett migrations engine """ import uuid @@ -11,9 +11,9 @@ import pytest from emmett import App -from emmett.orm import Database, Model, Field, belongs_to, refers_to -from emmett.orm.migrations.engine import MetaEngine, Engine -from emmett.orm.migrations.generation import MetaData, Comparator +from emmett.orm import Database, Field, Model, belongs_to, refers_to +from emmett.orm.migrations.engine import Engine, MetaEngine +from emmett.orm.migrations.generation import Comparator, MetaData class FakeEngine(Engine): @@ -23,10 +23,10 @@ def _log_and_exec(self, sql): self.sql_history.append(sql) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): rv = App(__name__) - rv.config.db.uri = 'sqlite:memory' + rv.config.db.uri = "sqlite:memory" return rv @@ -58,6 +58,7 @@ class StepOneThing(Model): name = Field() value = Field.float() + _step_one_sql = """CREATE TABLE "step_one_things"( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "name" CHAR(512), @@ -105,6 +106,7 @@ class StepTwoThing(Model): value = Field.float(default=8.8) available = Field.bool(default=True) + _step_two_sql = """CREATE TABLE "step_two_things"( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "name" CHAR(512) NOT NULL, @@ -136,6 +138,7 @@ class StepThreeThingThree(Model): tablename = "step_three_thing_ones" b = Field() + _step_three_sql = 'ALTER TABLE "step_three_thing_ones" ADD "b" CHAR(512);' _step_three_sql_drop = 'ALTER TABLE "step_three_thing_ones" DROP COLUMN "a";' @@ -178,6 +181,7 @@ class StepFourThingEdit(Model): available = Field.bool(default=True) asd = Field.int() + _step_four_sql = """ALTER TABLE "step_four_things" ALTER COLUMN "name" DROP NOT NULL; ALTER TABLE "step_four_things" ALTER COLUMN "value" DROP DEFAULT; ALTER TABLE "step_four_things" ALTER COLUMN "asd" TYPE INTEGER;""" @@ -202,32 +206,24 @@ class StepFiveThing(Model): value = Field.int() created_at = Field.datetime() - indexes = { - 'name': True, - ('name', 'value'): True - } + indexes = {"name": True, ("name", "value"): True} class StepFiveThingEdit(StepFiveThing): tablename = "step_five_things" - indexes = { - 'name': False, - 'name_created': { - 'fields': 'name', - 'expressions': lambda m: m.created_at.coalesce(None)} - } + indexes = {"name": False, "name_created": {"fields": "name", "expressions": lambda m: m.created_at.coalesce(None)}} _step_five_sql_before = [ 'CREATE UNIQUE INDEX "step_five_things_widx__code_unique" ON "step_five_things" ("code");', 'CREATE INDEX "step_five_things_widx__name" ON "step_five_things" ("name");', - 'CREATE INDEX "step_five_things_widx__name_value" ON "step_five_things" ("name","value");' + 'CREATE INDEX "step_five_things_widx__name_value" ON "step_five_things" ("name","value");', ] _step_five_sql_after = [ 'DROP INDEX "step_five_things_widx__name";', - 'CREATE INDEX "step_five_things_widx__name_created" ON "step_five_things" ("name",COALESCE("created_at",NULL));' + 'CREATE INDEX "step_five_things_widx__name_created" ON "step_five_things" ("name",COALESCE("created_at",NULL));', ] @@ -255,7 +251,7 @@ class StepSixThing(Model): class StepSixRelate(Model): - belongs_to('step_six_thing') + belongs_to("step_six_thing") name = Field() @@ -268,11 +264,13 @@ class StepSixRelate(Model): "name" CHAR(512), "step_six_thing" CHAR(512) );""" -_step_six_sql_fk = "".join([ - 'ALTER TABLE "step_six_relates" ADD CONSTRAINT ', - '"step_six_relates_ecnt__fk__stepsixthings_stepsixthing" FOREIGN KEY ', - '("step_six_thing") REFERENCES "step_six_things"("id") ON DELETE CASCADE;' -]) +_step_six_sql_fk = "".join( + [ + 'ALTER TABLE "step_six_relates" ADD CONSTRAINT ', + '"step_six_relates_ecnt__fk__stepsixthings_stepsixthing" FOREIGN KEY ', + '("step_six_thing") REFERENCES "step_six_things"("id") ON DELETE CASCADE;', + ] +) def test_step_six_id_types(app): @@ -286,22 +284,17 @@ def test_step_six_id_types(app): class StepSevenThing(Model): - primary_keys = ['foo', 'bar'] + primary_keys = ["foo", "bar"] foo = Field() bar = Field() class StepSevenRelate(Model): - refers_to({'foo': 'StepSevenThing.foo'}, {'bar': 'StepSevenThing.bar'}) + refers_to({"foo": "StepSevenThing.foo"}, {"bar": "StepSevenThing.bar"}) name = Field() - foreign_keys = { - "test": { - "fields": ["foo", "bar"], - "foreign_fields": ["foo", "bar"] - } - } + foreign_keys = {"test": {"fields": ["foo", "bar"], "foreign_fields": ["foo", "bar"]}} _step_seven_sql_t1 = """CREATE TABLE "step_seven_things"( @@ -315,11 +308,13 @@ class StepSevenRelate(Model): "foo" CHAR(512), "bar" CHAR(512) );""" -_step_seven_sql_fk = "".join([ - 'ALTER TABLE "step_seven_relates" ADD CONSTRAINT ', - '"step_seven_relates_ecnt__fk__stepseventhings_foo_bar" FOREIGN KEY ("foo","bar") ', - 'REFERENCES "step_seven_things"("foo","bar") ON DELETE CASCADE;', -]) +_step_seven_sql_fk = "".join( + [ + 'ALTER TABLE "step_seven_relates" ADD CONSTRAINT ', + '"step_seven_relates_ecnt__fk__stepseventhings_foo_bar" FOREIGN KEY ("foo","bar") ', + 'REFERENCES "step_seven_things"("foo","bar") ON DELETE CASCADE;', + ] +) def test_step_seven_composed_pks(app): diff --git a/tests/test_orm.py b/tests/test_orm.py index b99dc303..9dd8e11e 100644 --- a/tests/test_orm.py +++ b/tests/test_orm.py @@ -1,58 +1,63 @@ # -*- coding: utf-8 -*- """ - tests.orm - --------- +tests.orm +--------- - Test pyDAL implementation over Emmett. +Test pyDAL implementation over Emmett. """ -import pytest - from datetime import datetime, timedelta from uuid import uuid4 -from pydal.objects import Table +import pytest from pydal import Field as _Field -from emmett import App, sdict, now +from pydal.objects import Table + +from emmett import App, now, sdict from emmett.orm import ( - Database, Field, Model, + Database, + Field, + Model, + after_commit, + after_delete, + after_destroy, + after_insert, + after_save, + after_update, + before_commit, + before_delete, + before_destroy, + before_insert, + before_save, + before_update, + belongs_to, compute, - before_insert, after_insert, - before_update, after_update, - before_delete, after_delete, - before_save, after_save, - before_destroy, after_destroy, - before_commit, after_commit, - rowattr, rowmethod, - has_one, has_many, belongs_to, refers_to, - scope + has_many, + has_one, + refers_to, + rowattr, + rowmethod, + scope, ) +from emmett.orm.errors import MissingFieldsForCompute from emmett.orm.migrations.utils import generate_runtime_migration from emmett.orm.objects import TransactionOps -from emmett.orm.errors import MissingFieldsForCompute -from emmett.validators import isntEmpty, hasLength +from emmett.validators import hasLength, isntEmpty CALLBACK_OPS = { "before_insert": [], "before_update": [], "before_delete": [], - "before_save":[], + "before_save": [], "before_destroy": [], "after_insert": [], "after_update": [], "after_delete": [], - "after_save":[], - "after_destroy": [] -} -COMMIT_CALLBACKS = { - "all": [], - "insert": [], - "update": [], - "delete": [], - "save": [], - "destroy": [] + "after_save": [], + "after_destroy": [], } +COMMIT_CALLBACKS = {"all": [], "insert": [], "update": [], "delete": [], "save": [], "destroy": []} def _represent_f(value): @@ -72,73 +77,57 @@ class Stuff(Model): total_watch = Field.float() invisible = Field() - validation = { - "a": {'presence': True}, - "total": {"allow": "empty"}, - "total_watch": {"allow": "empty"} - } + validation = {"a": {"presence": True}, "total": {"allow": "empty"}, "total_watch": {"allow": "empty"}} - fields_rw = { - "invisible": False - } + fields_rw = {"invisible": False} - form_labels = { - "a": "A label" - } + form_labels = {"a": "A label"} - form_info = { - "a": "A comment" - } + form_info = {"a": "A comment"} - update_values = { - "a": "a_update" - } + update_values = {"a": "a_update"} - repr_values = { - "a": _represent_f - } + repr_values = {"a": _represent_f} - form_widgets = { - "a": _widget_f - } + form_widgets = {"a": _widget_f} - @compute('total') + @compute("total") def eval_total(self, row): return row.price * row.quantity - @compute('total_watch', watch=['price', 'quantity']) + @compute("total_watch", watch=["price", "quantity"]) def eval_total_watch(self, row): return row.price * row.quantity @before_insert def bi(self, fields): - CALLBACK_OPS['before_insert'].append(fields) + CALLBACK_OPS["before_insert"].append(fields) @after_insert def ai(self, fields, id): - CALLBACK_OPS['after_insert'].append((fields, id)) + CALLBACK_OPS["after_insert"].append((fields, id)) @before_update def bu(self, set, fields): - CALLBACK_OPS['before_update'].append((set, fields)) + CALLBACK_OPS["before_update"].append((set, fields)) @after_update def au(self, set, fields): - CALLBACK_OPS['after_update'].append((set, fields)) + CALLBACK_OPS["after_update"].append((set, fields)) @before_delete def bd(self, set): - CALLBACK_OPS['before_delete'].append(set) + CALLBACK_OPS["before_delete"].append(set) @after_delete def ad(self, set): - CALLBACK_OPS['after_delete'].append(set) + CALLBACK_OPS["after_delete"].append(set) - @rowattr('totalv') + @rowattr("totalv") def eval_total_v(self, row): return row.price * row.quantity - @rowmethod('totalm') + @rowmethod("totalm") def eval_total_m(self, row): return row.price * row.quantity @@ -148,57 +137,55 @@ def method_test(cls, t): class Person(Model): - has_many( - 'things', {'features': {'via': 'things'}}, {'pets': 'Dog.owner'}, - 'subscriptions') + has_many("things", {"features": {"via": "things"}}, {"pets": "Dog.owner"}, "subscriptions") name = Field() age = Field.int() class Thing(Model): - belongs_to('person') - has_many('features') + belongs_to("person") + has_many("features") name = Field() color = Field() class Feature(Model): - belongs_to('thing') - has_one('price') + belongs_to("thing") + has_one("price") name = Field() class Price(Model): - belongs_to('feature') + belongs_to("feature") value = Field.int() class Doctor(Model): - has_many('appointments', {'patients': {'via': 'appointments'}}) + has_many("appointments", {"patients": {"via": "appointments"}}) name = Field() class Patient(Model): - has_many('appointments', {'doctors': {'via': 'appointments'}}) + has_many("appointments", {"doctors": {"via": "appointments"}}) name = Field() class Appointment(Model): - belongs_to('patient', 'doctor') + belongs_to("patient", "doctor") date = Field.datetime() class User(Model): name = Field() has_many( - 'memberships', {'organizations': {'via': 'memberships'}}, - {'cover_orgs': { - 'via': 'memberships.organization', - 'where': lambda m: m.is_cover == True}}) + "memberships", + {"organizations": {"via": "memberships"}}, + {"cover_orgs": {"via": "memberships.organization", "where": lambda m: m.is_cover == True}}, + ) class Organization(Model): @@ -210,22 +197,23 @@ def admin_memberships3(self): return Membership.admins() has_many( - 'memberships', {'users': {'via': 'memberships'}}, - {'admin_memberships': {'target': 'Membership', 'scope': 'admins'}}, - {'admins': {'via': 'admin_memberships.user'}}, - {'admin_memberships2': { - 'target': 'Membership', 'where': lambda m: m.role == 'admin'}}, - {'admins2': {'via': 'admin_memberships2.user'}}, - {'admins3': {'via': 'admin_memberships3.user'}}) + "memberships", + {"users": {"via": "memberships"}}, + {"admin_memberships": {"target": "Membership", "scope": "admins"}}, + {"admins": {"via": "admin_memberships.user"}}, + {"admin_memberships2": {"target": "Membership", "where": lambda m: m.role == "admin"}}, + {"admins2": {"via": "admin_memberships2.user"}}, + {"admins3": {"via": "admin_memberships3.user"}}, + ) class Membership(Model): - belongs_to('user', 'organization') + belongs_to("user", "organization") role = Field() - @scope('admins') + @scope("admins") def filter_admins(self): - return self.role == 'admin' + return self.role == "admin" class House(Model): @@ -234,7 +222,7 @@ class House(Model): class Mouse(Model): tablename = "mice" - has_many('elephants') + has_many("elephants") name = Field() @@ -243,19 +231,19 @@ class NeedSplit(Model): class Zoo(Model): - has_many('animals', 'elephants', {'mice': {'via': 'elephants.mouse'}}) + has_many("animals", "elephants", {"mice": {"via": "elephants.mouse"}}) name = Field() class Animal(Model): - belongs_to('zoo') + belongs_to("zoo") name = Field() - @rowattr('doublename') + @rowattr("doublename") def get_double_name(self, row): return row.name * 2 - @rowattr('pretty') + @rowattr("pretty") def get_pretty(self, row): return row.name @@ -269,10 +257,10 @@ def bi2(self, *args, **kwargs): class Elephant(Animal): - belongs_to('mouse') + belongs_to("mouse") color = Field() - @rowattr('pretty') + @rowattr("pretty") def get_pretty(self, row): return row.name + " " + row.color @@ -282,24 +270,24 @@ def bi2(self, *args, **kwargs): class Dog(Model): - belongs_to({'owner': 'Person'}) + belongs_to({"owner": "Person"}) name = Field() class Subscription(Model): - belongs_to('person') + belongs_to("person") name = Field() status = Field.int() expires_at = Field.datetime() - STATUS = {'active': 1, 'suspended': 2, 'other': 3} + STATUS = {"active": 1, "suspended": 2, "other": 3} - @scope('expired') + @scope("expired") def get_expired(self): return self.expires_at < datetime.now() - @scope('of_status') + @scope("of_status") def filter_status(self, *statuses): if len(statuses) == 1: return self.status == self.STATUS[statuses[0]] @@ -362,7 +350,7 @@ def _compute_price(self, row): class SelfRef(Model): - refers_to({'parent': 'self'}) + refers_to({"parent": "self"}) name = Field.string() @@ -412,29 +400,43 @@ def _commit_watch_after_destroy(self, ctx): COMMIT_CALLBACKS["destroy"].append(("after", ctx)) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(): app = App(__name__) - db = Database( - app, config=sdict( - uri=f'sqlite://{uuid4().hex}.db', - auto_connect=True - ) - ) + db = Database(app, config=sdict(uri=f"sqlite://{uuid4().hex}.db", auto_connect=True)) db.define_models( - Stuff, Person, Thing, Feature, Price, Dog, Subscription, - Doctor, Patient, Appointment, - User, Organization, Membership, - House, Mouse, NeedSplit, Zoo, Animal, Elephant, - Product, Cart, CartElement, + Stuff, + Person, + Thing, + Feature, + Price, + Dog, + Subscription, + Doctor, + Patient, + Appointment, + User, + Organization, + Membership, + House, + Mouse, + NeedSplit, + Zoo, + Animal, + Elephant, + Product, + Cart, + CartElement, SelfRef, - CustomPKType, CustomPKName, CustomPKMulti, - CommitWatcher + CustomPKType, + CustomPKName, + CustomPKMulti, + CommitWatcher, ) return db -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(_db): migration = generate_runtime_migration(_db) migration.up() @@ -948,12 +950,13 @@ def test_relations(db): assert t[0].name == "apple" and t[0].color == "red" and t[0].person.id == p1.id f = p1.things()[0].features() assert len(f) == 1 - assert f[0].name == "tasty" and f[0].thing.id == t[0].id and \ - f[0].thing.person.id == p1.id + assert f[0].name == "tasty" and f[0].thing.id == t[0].id and f[0].thing.person.id == p1.id m = p1.things()[0].features()[0].price() assert ( - m.value == 5 and m.feature.id == f[0].id and - m.feature.thing.id == t[0].id and m.feature.thing.person.id == p1.id + m.value == 5 + and m.feature.id == f[0].id + and m.feature.thing.id == t[0].id + and m.feature.thing.person.id == p1.id ) p2.things.add(t2) assert p1.things.count() == 1 @@ -974,30 +977,30 @@ def test_relations(db): assert len(doctor.appointments()) == 1 assert len(patient.doctors()) == 1 assert len(patient.appointments()) == 1 - joe = db.User.insert(name='joe') - jim = db.User.insert(name='jim') - org = db.Organization.insert(name='') - org.users.add(joe, role='admin') - org.users.add(jim, role='manager') + joe = db.User.insert(name="joe") + jim = db.User.insert(name="jim") + org = db.Organization.insert(name="") + org.users.add(joe, role="admin") + org.users.add(jim, role="manager") assert len(org.users()) == 2 assert len(joe.organizations()) == 1 assert len(jim.organizations()) == 1 assert joe.organizations().first().id == org assert jim.organizations().first().id == org - assert joe.memberships().first().role == 'admin' - assert jim.memberships().first().role == 'manager' + assert joe.memberships().first().role == "admin" + assert jim.memberships().first().role == "manager" org.users.remove(joe) org.users.remove(jim) assert len(org.users(reload=True)) == 0 assert len(joe.organizations(reload=True)) == 0 assert len(jim.organizations(reload=True)) == 0 #: has_many with specified feld - db.Dog.insert(name='pongo', owner=p1) - assert len(p1.pets()) == 1 and p1.pets().first().name == 'pongo' + db.Dog.insert(name="pongo", owner=p1) + assert len(p1.pets()) == 1 and p1.pets().first().name == "pongo" #: has_many via with specified field - zoo = db.Zoo.insert(name='magic zoo') - mouse = db.Mouse.insert(name='jerry') - db.Elephant.insert(name='dumbo', color='pink', mouse=mouse, zoo=zoo) + zoo = db.Zoo.insert(name="magic zoo") + mouse = db.Mouse.insert(name="jerry") + db.Elephant.insert(name="dumbo", color="pink", mouse=mouse, zoo=zoo) assert len(zoo.mice()) == 1 @@ -1008,44 +1011,34 @@ def test_tablenames(db): def test_inheritance(db): - assert 'name' in db.Animal.fields - assert 'name' in db.Elephant.fields - assert 'zoo' in db.Animal.fields - assert 'zoo' in db.Elephant.fields - assert 'color' in db.Elephant.fields - assert 'color' not in db.Animal.fields - assert Elephant._all_virtuals_['get_double_name'] is \ - Animal._all_virtuals_['get_double_name'] - assert Elephant._all_virtuals_['get_pretty'] is not \ - Animal._all_virtuals_['get_pretty'] - assert Elephant._all_callbacks_['bi'] is \ - Animal._all_callbacks_['bi'] - assert Elephant._all_callbacks_['bi2'] is not \ - Animal._all_callbacks_['bi2'] + assert "name" in db.Animal.fields + assert "name" in db.Elephant.fields + assert "zoo" in db.Animal.fields + assert "zoo" in db.Elephant.fields + assert "color" in db.Elephant.fields + assert "color" not in db.Animal.fields + assert Elephant._all_virtuals_["get_double_name"] is Animal._all_virtuals_["get_double_name"] + assert Elephant._all_virtuals_["get_pretty"] is not Animal._all_virtuals_["get_pretty"] + assert Elephant._all_callbacks_["bi"] is Animal._all_callbacks_["bi"] + assert Elephant._all_callbacks_["bi2"] is not Animal._all_callbacks_["bi2"] def test_scopes(db): p = db.Person.insert(name="Walter", age=50) - s = db.Subscription.insert( - name="a", expires_at=datetime.now() - timedelta(hours=20), person=p, - status=1) - s2 = db.Subscription.insert( - name="b", expires_at=datetime.now() + timedelta(hours=20), person=p, - status=2) - db.Subscription.insert( - name="c", expires_at=datetime.now() + timedelta(hours=20), person=p, - status=3) + s = db.Subscription.insert(name="a", expires_at=datetime.now() - timedelta(hours=20), person=p, status=1) + s2 = db.Subscription.insert(name="b", expires_at=datetime.now() + timedelta(hours=20), person=p, status=2) + db.Subscription.insert(name="c", expires_at=datetime.now() + timedelta(hours=20), person=p, status=3) rows = db(db.Subscription).expired().select() assert len(rows) == 1 and rows[0].id == s rows = p.subscriptions.expired().select() assert len(rows) == 1 and rows[0].id == s rows = Subscription.expired().select() assert len(rows) == 1 and rows[0].id == s - rows = db(db.Subscription).of_status('active', 'suspended').select() + rows = db(db.Subscription).of_status("active", "suspended").select() assert len(rows) == 2 and rows[0].id == s and rows[1].id == s2 - rows = p.subscriptions.of_status('active', 'suspended').select() + rows = p.subscriptions.of_status("active", "suspended").select() assert len(rows) == 2 and rows[0].id == s and rows[1].id == s2 - rows = Subscription.of_status('active', 'suspended').select() + rows = Subscription.of_status("active", "suspended").select() assert len(rows) == 2 and rows[0].id == s and rows[1].id == s2 @@ -1054,7 +1047,7 @@ def test_relations_scopes(db): org = db.Organization.insert(name="Los pollos hermanos") org.users.add(gus, role="admin") frank = db.User.insert(name="Frank") - org.users.add(frank, role='manager') + org.users.add(frank, role="manager") assert org.admins.count() == 1 assert org.admins2.count() == 1 assert org.admins3.count() == 1 @@ -1100,24 +1093,13 @@ def test_relations_scopes(db): def test_model_where(db): - assert Subscription.where(lambda s: s.status == 1).query == \ - db(db.Subscription.status == 1).query + assert Subscription.where(lambda s: s.status == 1).query == db(db.Subscription.status == 1).query def test_model_first(db): p = db.Person.insert(name="Walter", age=50) - db.Subscription.insert( - name="a", - expires_at=datetime.now() + timedelta(hours=20), - person=p, - status=1 - ) - db.Subscription.insert( - name="b", - expires_at=datetime.now() + timedelta(hours=20), - person=p, - status=1 - ) + db.Subscription.insert(name="a", expires_at=datetime.now() + timedelta(hours=20), person=p, status=1) + db.Subscription.insert(name="b", expires_at=datetime.now() + timedelta(hours=20), person=p, status=1) db.CustomPKType.insert(id="a") db.CustomPKType.insert(id="b") db.CustomPKName.insert(name="a") @@ -1126,38 +1108,23 @@ def test_model_first(db): db.CustomPKMulti.insert(first_name="foo", last_name="baz") db.CustomPKMulti.insert(first_name="bar", last_name="baz") - assert Subscription.first().id == Subscription.all().select( - orderby=Subscription.id, - limitby=(0, 1) - ).first().id - assert CustomPKType.first().id == CustomPKType.all().select( - orderby=CustomPKType.id, - limitby=(0, 1) - ).first().id - assert CustomPKName.first().name == CustomPKName.all().select( - orderby=CustomPKName.name, - limitby=(0, 1) - ).first().name - assert CustomPKMulti.first() == CustomPKMulti.all().select( - orderby=CustomPKMulti.first_name|CustomPKMulti.last_name, - limitby=(0, 1) - ).first() + assert Subscription.first().id == Subscription.all().select(orderby=Subscription.id, limitby=(0, 1)).first().id + assert CustomPKType.first().id == CustomPKType.all().select(orderby=CustomPKType.id, limitby=(0, 1)).first().id + assert ( + CustomPKName.first().name == CustomPKName.all().select(orderby=CustomPKName.name, limitby=(0, 1)).first().name + ) + assert ( + CustomPKMulti.first() + == CustomPKMulti.all() + .select(orderby=CustomPKMulti.first_name | CustomPKMulti.last_name, limitby=(0, 1)) + .first() + ) def test_model_last(db): p = db.Person.insert(name="Walter", age=50) - db.Subscription.insert( - name="a", - expires_at=datetime.now() + timedelta(hours=20), - person=p, - status=1 - ) - db.Subscription.insert( - name="b", - expires_at=datetime.now() + timedelta(hours=20), - person=p, - status=1 - ) + db.Subscription.insert(name="a", expires_at=datetime.now() + timedelta(hours=20), person=p, status=1) + db.Subscription.insert(name="b", expires_at=datetime.now() + timedelta(hours=20), person=p, status=1) db.CustomPKType.insert(id="a") db.CustomPKType.insert(id="b") db.CustomPKName.insert(name="a") @@ -1166,19 +1133,14 @@ def test_model_last(db): db.CustomPKMulti.insert(first_name="foo", last_name="baz") db.CustomPKMulti.insert(first_name="bar", last_name="baz") - assert Subscription.last().id == Subscription.all().select( - orderby=~Subscription.id, - limitby=(0, 1) - ).first().id - assert CustomPKType.last().id == CustomPKType.all().select( - orderby=~CustomPKType.id, - limitby=(0, 1) - ).first().id - assert CustomPKName.last().name == CustomPKName.all().select( - orderby=~CustomPKName.name, - limitby=(0, 1) - ).first().name - assert CustomPKMulti.last() == CustomPKMulti.all().select( - orderby=~CustomPKMulti.first_name|~CustomPKMulti.last_name, - limitby=(0, 1) - ).first() + assert Subscription.last().id == Subscription.all().select(orderby=~Subscription.id, limitby=(0, 1)).first().id + assert CustomPKType.last().id == CustomPKType.all().select(orderby=~CustomPKType.id, limitby=(0, 1)).first().id + assert ( + CustomPKName.last().name == CustomPKName.all().select(orderby=~CustomPKName.name, limitby=(0, 1)).first().name + ) + assert ( + CustomPKMulti.last() + == CustomPKMulti.all() + .select(orderby=~CustomPKMulti.first_name | ~CustomPKMulti.last_name, limitby=(0, 1)) + .first() + ) diff --git a/tests/test_orm_connections.py b/tests/test_orm_connections.py index d4c7e048..df881763 100644 --- a/tests/test_orm_connections.py +++ b/tests/test_orm_connections.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.orm_connections - --------------------- +tests.orm_connections +--------------------- - Test pyDAL connection implementation over Emmett. +Test pyDAL connection implementation over Emmett. """ import pytest @@ -12,17 +12,10 @@ from emmett.orm import Database -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def db(): app = App(__name__) - db = Database( - app, - config=sdict( - uri='sqlite:memory', - auto_migrate=True, - auto_connect=False - ) - ) + db = Database(app, config=sdict(uri="sqlite:memory", auto_migrate=True, auto_connect=False)) return db diff --git a/tests/test_orm_gis.py b/tests/test_orm_gis.py index 0da40b86..d63542eb 100644 --- a/tests/test_orm_gis.py +++ b/tests/test_orm_gis.py @@ -1,21 +1,21 @@ # -*- coding: utf-8 -*- """ - tests.orm_gis - ------------- +tests.orm_gis +------------- - Test ORM GIS features +Test ORM GIS features """ import os + import pytest from emmett import App, sdict -from emmett.orm import Database, Model, Field, geo +from emmett.orm import Database, Field, Model, geo from emmett.orm.migrations.utils import generate_runtime_migration -require_postgres = pytest.mark.skipif( - not os.environ.get("POSTGRES_URI"), reason="No postgres database" -) + +require_postgres = pytest.mark.skipif(not os.environ.get("POSTGRES_URI"), reason="No postgres database") class Geography(Model): @@ -40,23 +40,15 @@ class Geometry(Model): multipolygon = Field.geometry("MULTIPOLYGON") -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(): app = App(__name__) - db = Database( - app, - config=sdict( - uri=f"postgres://{os.environ.get('POSTGRES_URI')}" - ) - ) - db.define_models( - Geography, - Geometry - ) + db = Database(app, config=sdict(uri=f"postgres://{os.environ.get('POSTGRES_URI')}")) + db.define_models(Geography, Geometry) return db -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(_db): migration = generate_runtime_migration(_db) with _db.connection(): @@ -93,15 +85,9 @@ def test_gis_insert(db): multipoint=geo.MultiPoint((1, 1), (2, 2)), multiline=geo.MultiLine(((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))), multipolygon=geo.MultiPolygon( - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) - ) - ) + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), + ), ) row.save() @@ -111,35 +97,17 @@ def test_gis_insert(db): assert not row.point.groups assert row.line == "LINESTRING({})".format( - ",".join([ - " ".join(f"{v}.000000" for v in tup) - for tup in [ - (0, 0), - (20, 80), - (80, 80) - ] - ]) + ",".join([" ".join(f"{v}.000000" for v in tup) for tup in [(0, 0), (20, 80), (80, 80)]]) ) assert row.line.geometry == "LINESTRING" assert row.line.coordinates == ((0, 0), (20, 80), (80, 80)) assert not row.line.groups assert row.polygon == "POLYGON(({}))".format( - ",".join([ - " ".join(f"{v}.000000" for v in tup) - for tup in [ - (0, 0), - (150, 0), - (150, 10), - (0, 10), - (0, 0) - ] - ]) + ",".join([" ".join(f"{v}.000000" for v in tup) for tup in [(0, 0), (150, 0), (150, 10), (0, 10), (0, 0)]]) ) assert row.polygon.geometry == "POLYGON" - assert row.polygon.coordinates == ( - ((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)), - ) + assert row.polygon.coordinates == (((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)),) assert not row.polygon.groups assert row.multipoint == "MULTIPOINT((1.000000 1.000000),(2.000000 2.000000))" @@ -150,69 +118,48 @@ def test_gis_insert(db): assert row.multipoint.groups[1] == geo.Point(2, 2) assert row.multiline == "MULTILINESTRING({})".format( - ",".join([ - "({})".format( - ",".join([ - " ".join(f"{v}.000000" for v in tup) - for tup in group - ]) - ) for group in [ - ((1, 1), (2, 2), (3, 3)), - ((1, 1), (4, 4), (5, 5)) + ",".join( + [ + "({})".format(",".join([" ".join(f"{v}.000000" for v in tup) for tup in group])) + for group in [((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))] ] - ]) + ) ) assert row.multiline.geometry == "MULTILINESTRING" - assert row.multiline.coordinates == ( - ((1, 1), (2, 2), (3, 3)), - ((1, 1), (4, 4), (5, 5)) - ) + assert row.multiline.coordinates == (((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))) assert len(row.multiline.groups) == 2 assert row.multiline.groups[0] == geo.Line((1, 1), (2, 2), (3, 3)) assert row.multiline.groups[1] == geo.Line((1, 1), (4, 4), (5, 5)) assert row.multipolygon == "MULTIPOLYGON({})".format( - ",".join([ - "({})".format( - ",".join([ - "({})".format( - ",".join([ - " ".join(f"{v}.000000" for v in tup) - for tup in group - ]) - ) for group in polygon - ]) - ) for polygon in [ - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) + ",".join( + [ + "({})".format( + ",".join( + [ + "({})".format(",".join([" ".join(f"{v}.000000" for v in tup) for tup in group])) + for group in polygon + ] + ) ) + for polygon in [ + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), + ] ] - ]) + ) ) assert row.multipolygon.geometry == "MULTIPOLYGON" assert row.multipolygon.coordinates == ( - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) - ) + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), ) assert len(row.multipolygon.groups) == 2 assert row.multipolygon.groups[0] == geo.Polygon( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) + ((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0)) ) assert row.multipolygon.groups[1] == geo.Polygon( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) + ((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1)) ) @@ -226,15 +173,9 @@ def test_gis_select(db): multipoint=geo.MultiPoint((1, 1), (2, 2)), multiline=geo.MultiLine(((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))), multipolygon=geo.MultiPolygon( - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) - ) - ) + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), + ), ) row.save() row = model.get(row.id) @@ -248,9 +189,7 @@ def test_gis_select(db): assert not row.line.groups assert row.polygon.geometry == "POLYGON" - assert row.polygon.coordinates == ( - ((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)), - ) + assert row.polygon.coordinates == (((0, 0), (150, 0), (150, 10), (0, 10), (0, 0)),) assert not row.polygon.groups assert row.multipoint.geometry == "MULTIPOINT" @@ -260,31 +199,20 @@ def test_gis_select(db): assert row.multipoint.groups[1] == geo.Point(2, 2) assert row.multiline.geometry == "MULTILINESTRING" - assert row.multiline.coordinates == ( - ((1, 1), (2, 2), (3, 3)), - ((1, 1), (4, 4), (5, 5)) - ) + assert row.multiline.coordinates == (((1, 1), (2, 2), (3, 3)), ((1, 1), (4, 4), (5, 5))) assert len(row.multiline.groups) == 2 assert row.multiline.groups[0] == geo.Line((1, 1), (2, 2), (3, 3)) assert row.multiline.groups[1] == geo.Line((1, 1), (4, 4), (5, 5)) assert row.multipolygon.geometry == "MULTIPOLYGON" assert row.multipolygon.coordinates == ( - ( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) - ), - ( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) - ) + (((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0))), + (((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1))), ) assert len(row.multipolygon.groups) == 2 assert row.multipolygon.groups[0] == geo.Polygon( - ((0, 0), (20, 0), (20, 20), (0, 0)), - ((0, 0), (30, 0), (30, 30), (0, 0)) + ((0, 0), (20, 0), (20, 20), (0, 0)), ((0, 0), (30, 0), (30, 30), (0, 0)) ) assert row.multipolygon.groups[1] == geo.Polygon( - ((1, 1), (21, 1), (21, 21), (1, 1)), - ((1, 1), (31, 1), (31, 31), (1, 1)) + ((1, 1), (21, 1), (21, 21), (1, 1)), ((1, 1), (31, 1), (31, 31), (1, 1)) ) diff --git a/tests/test_orm_pks.py b/tests/test_orm_pks.py index 54b12f8f..f877ff6f 100644 --- a/tests/test_orm_pks.py +++ b/tests/test_orm_pks.py @@ -1,25 +1,24 @@ # -*- coding: utf-8 -*- """ - tests.orm_pks - ------------- +tests.orm_pks +------------- - Test ORM primary keys hendling +Test ORM primary keys hendling """ import os -import pytest - from uuid import uuid4 +import pytest + from emmett import App, sdict -from emmett.orm import Database, Model, Field, belongs_to, has_many +from emmett.orm import Database, Field, Model, belongs_to, has_many from emmett.orm.errors import SaveException from emmett.orm.helpers import RowReferenceMixin from emmett.orm.migrations.utils import generate_runtime_migration -require_postgres = pytest.mark.skipif( - not os.environ.get("POSTGRES_URI"), reason="No postgres database" -) + +require_postgres = pytest.mark.skipif(not os.environ.get("POSTGRES_URI"), reason="No postgres database") class Standard(Model): @@ -101,7 +100,7 @@ class DoctorCustom(Model): has_many( {"appointments": "AppointmentCustom"}, {"patients": {"via": "appointments.patient_custom"}}, - {"symptoms_to_treat": {"via": "patients.symptoms"}} + {"symptoms_to_treat": {"via": "patients.symptoms"}}, ) id = Field.string(default=lambda: uuid4().hex) @@ -114,7 +113,7 @@ class DoctorMulti(Model): has_many( {"appointments": "AppointmentMulti"}, {"patients": {"via": "appointments.patient_multi"}}, - {"symptoms_to_treat": {"via": "patients.symptoms"}} + {"symptoms_to_treat": {"via": "patients.symptoms"}}, ) foo = Field.string(default=lambda: uuid4().hex) @@ -128,7 +127,7 @@ class PatientCustom(Model): has_many( {"appointments": "AppointmentCustom"}, {"symptoms": "SymptomCustom.patient"}, - {"doctors": {"via": "appointments.doctor_custom"}} + {"doctors": {"via": "appointments.doctor_custom"}}, ) code = Field.string(default=lambda: uuid4().hex) @@ -141,7 +140,7 @@ class PatientMulti(Model): has_many( {"appointments": "AppointmentMulti"}, {"symptoms": "SymptomMulti.patient"}, - {"doctors": {"via": "appointments.doctor_multi"}} + {"doctors": {"via": "appointments.doctor_multi"}}, ) foo = Field.string(default=lambda: uuid4().hex) @@ -185,35 +184,18 @@ class AppointmentMulti(Model): name = Field.string() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(): app = App(__name__) - db = Database( - app, - config=sdict( - uri=f'sqlite://{uuid4().hex}.db', - auto_connect=True - ) - ) - db.define_models( - Standard, - CustomType, - CustomName, - CustomMulti - ) + db = Database(app, config=sdict(uri=f"sqlite://{uuid4().hex}.db", auto_connect=True)) + db.define_models(Standard, CustomType, CustomName, CustomMulti) return db -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _pgs(): app = App(__name__) - db = Database( - app, - config=sdict( - uri=f"postgres://{os.environ.get('POSTGRES_URI')}", - auto_connect=True - ) - ) + db = Database(app, config=sdict(uri=f"postgres://{os.environ.get('POSTGRES_URI')}", auto_connect=True)) db.define_models( SourceCustom, SourceMulti, @@ -228,12 +210,12 @@ def _pgs(): PatientMulti, AppointmentMulti, SymptomCustom, - SymptomMulti + SymptomMulti, ) return db -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def db(_db): migration = generate_runtime_migration(_db) migration.up() @@ -241,7 +223,7 @@ def db(_db): migration.down() -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def pgs(_pgs): migration = generate_runtime_migration(_pgs) migration.up() @@ -282,7 +264,7 @@ def test_save_insert(db): assert done assert row._concrete assert row.id - assert type(row.id) == int + assert type(row.id) is int row = CustomType.new(id="test1", foo="test2", bar="test3") done = row.save() @@ -770,9 +752,7 @@ def test_row(pgs): assert sm2.dest_multi_multis.count() == 0 dmm1.source_multi = sm2 - assert set(dmm1._changes.keys()).issubset( - {"source_multi", "source_multi_foo", "source_multi_bar"} - ) + assert set(dmm1._changes.keys()).issubset({"source_multi", "source_multi_foo", "source_multi_bar"}) dmm1.save() assert sm1.dest_multi_multis.count() == 0 assert sm2.dest_multi_multis.count() == 1 diff --git a/tests/test_orm_row.py b/tests/test_orm_row.py index 86367d5f..645f3235 100644 --- a/tests/test_orm_row.py +++ b/tests/test_orm_row.py @@ -1,20 +1,18 @@ # -*- coding: utf-8 -*- """ - tests.orm_row - ------------- +tests.orm_row +------------- - Test ORM row objects +Test ORM row objects """ import pickle -import pytest - from uuid import uuid4 -from emmett import App, sdict, now -from emmett.orm import ( - Database, Model, Field, belongs_to, has_many, refers_to, rowattr, rowmethod -) +import pytest + +from emmett import App, now, sdict +from emmett.orm import Database, Field, Model, belongs_to, has_many, refers_to, rowattr, rowmethod from emmett.orm.errors import ValidationError from emmett.orm.helpers import RowReferenceMixin from emmett.orm.migrations.utils import generate_runtime_migration @@ -70,20 +68,15 @@ class Crypted(Model): bar = Field.password() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def _db(): app = App(__name__) - db = Database( - app, - config=sdict( - uri=f'sqlite://{uuid4().hex}.db', - auto_connect=True - ) - ) + db = Database(app, config=sdict(uri=f"sqlite://{uuid4().hex}.db", auto_connect=True)) db.define_models(One, Two, Three, Override, Crypted) return db -@pytest.fixture(scope='function') + +@pytest.fixture(scope="function") def db(_db): migration = generate_runtime_migration(_db) migration.up() @@ -96,80 +89,80 @@ def test_rowclass(db): db.Two.insert(one=ret, foo="test1", bar="test2") ret._allocate_() - assert type(ret._refrecord) == One._instance_()._rowclass_ + assert type(ret._refrecord) is One._instance_()._rowclass_ row = One.get(ret.id) - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = One.first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = db(db.One).select().first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = db(db.One).select(db.One.ALL).first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = One.all().select().first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = One.where(lambda m: m.id != None).select().first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = db(db.One).select().first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = db(db.One).select(db.One.ALL).first() - assert type(row) == One._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ row = One.all().select(One.bar).first() - assert type(row) == Row + assert type(row) is Row row = db(db.One).select(One.bar).first() - assert type(row) == Row + assert type(row) is Row row = One.all().join("twos").select().first() - assert type(row) == One._instance_()._rowclass_ - assert type(row.twos().first()) == Two._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ + assert type(row.twos().first()) is Two._instance_()._rowclass_ row = One.all().join("twos").select(One.table.ALL, Two.table.ALL).first() - assert type(row) == One._instance_()._rowclass_ - assert type(row.twos().first()) == Two._instance_()._rowclass_ + assert type(row) is One._instance_()._rowclass_ + assert type(row.twos().first()) is Two._instance_()._rowclass_ row = One.all().join("twos").select(One.foo, Two.foo).first() - assert type(row) == Row - assert type(row.ones) == Row - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is Row + assert type(row.twos) is Row row = db(Two.one == One.id).select().first() - assert type(row) == Row - assert type(row.ones) == One._instance_()._rowclass_ - assert type(row.twos) == Two._instance_()._rowclass_ + assert type(row) is Row + assert type(row.ones) is One._instance_()._rowclass_ + assert type(row.twos) is Two._instance_()._rowclass_ row = db(Two.one == One.id).select(One.table.ALL, Two.foo).first() - assert type(row) == Row - assert type(row.ones) == One._instance_()._rowclass_ - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is One._instance_()._rowclass_ + assert type(row.twos) is Row row = db(Two.one == One.id).select(One.foo, Two.foo).first() - assert type(row) == Row - assert type(row.ones) == Row - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is Row + assert type(row.twos) is Row for row in db(Two.one == One.id).iterselect(): - assert type(row) == Row - assert type(row.ones) == One._instance_()._rowclass_ - assert type(row.twos) == Two._instance_()._rowclass_ + assert type(row) is Row + assert type(row.ones) is One._instance_()._rowclass_ + assert type(row.twos) is Two._instance_()._rowclass_ for row in db(Two.one == One.id).iterselect(One.table.ALL, Two.foo): - assert type(row) == Row - assert type(row.ones) == One._instance_()._rowclass_ - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is One._instance_()._rowclass_ + assert type(row.twos) is Row for row in db(Two.one == One.id).iterselect(One.foo, Two.foo): - assert type(row) == Row - assert type(row.ones) == Row - assert type(row.twos) == Row + assert type(row) is Row + assert type(row.ones) is Row + assert type(row.twos) is Row def test_concrete(db): diff --git a/tests/test_orm_transactions.py b/tests/test_orm_transactions.py index 782de130..194d3749 100644 --- a/tests/test_orm_transactions.py +++ b/tests/test_orm_transactions.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.orm_transactions - ---------------------- +tests.orm_transactions +---------------------- - Test pyDAL transactions implementation over Emmett. +Test pyDAL transactions implementation over Emmett. """ import pytest @@ -16,17 +16,15 @@ class Register(Model): value = Field.int() -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def db(): app = App(__name__) - db = Database( - app, config=sdict( - uri='sqlite:memory', auto_migrate=True, auto_connect=True)) + db = Database(app, config=sdict(uri="sqlite:memory", auto_migrate=True, auto_connect=True)) db.define_models(Register) return db -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def cleanup(request, db): def teardown(): Register.all().delete() @@ -41,7 +39,7 @@ def _save(*vals): def _values_in_register(*vals): - db_vals = Register.all().select(orderby=Register.value).column('value') + db_vals = Register.all().select(orderby=Register.value).column("value") return db_vals == list(vals) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 559f3385..986deac6 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,26 +1,27 @@ # -*- coding: utf-8 -*- """ - tests.pipeline - -------------- +tests.pipeline +-------------- - Test Emmett pipeline +Test Emmett pipeline """ import asyncio -import pytest - from contextlib import contextmanager +import pytest from helpers import current_ctx as _current_ctx, ws_ctx as _ws_ctx -from emmett import App, request, websocket, abort + +from emmett import App, abort, request, websocket from emmett.ctx import current from emmett.http import HTTP -from emmett.pipeline import Pipe, Injector from emmett.parsers import Parsers -from emmett.serializers import Serializers, _json_type +from emmett.pipeline import Injector, Pipe +from emmett.serializers import Serializers -json_load = Parsers.get_for('json') -json_dump = Serializers.get_for('json') + +json_load = Parsers.get_for("json") +json_dump = Serializers.get_for("json") class PipeException(Exception): @@ -44,43 +45,43 @@ def store_parallel(self, status): self.parallel_storage.append(self.__class__.__name__ + "." + status) async def on_pipe_success(self): - self.store_linear('success') + self.store_linear("success") async def on_pipe_failure(self): - self.store_linear('failure') + self.store_linear("failure") class FlowStorePipeCommon(FlowStorePipe): async def open(self): - self.store_parallel('open') + self.store_parallel("open") async def close(self): - self.store_parallel('close') + self.store_parallel("close") async def pipe(self, next_pipe, **kwargs): - self.store_linear('pipe') + self.store_linear("pipe") return await next_pipe(**kwargs) class FlowStorePipeSplit(FlowStorePipe): async def open_request(self): - self.store_parallel('open_request') + self.store_parallel("open_request") async def open_ws(self): - self.store_parallel('open_ws') + self.store_parallel("open_ws") async def close_request(self): - self.store_parallel('close_request') + self.store_parallel("close_request") async def close_ws(self): - self.store_parallel('close_ws') + self.store_parallel("close_ws") async def pipe_request(self, next_pipe, **kwargs): - self.store_linear('pipe_request') + self.store_linear("pipe_request") return await next_pipe(**kwargs) async def pipe_ws(self, next_pipe, **kwargs): - self.store_linear('pipe_ws') + self.store_linear("pipe_ws") return await next_pipe(**kwargs) @@ -90,13 +91,13 @@ class Pipe1(FlowStorePipeCommon): class Pipe2(FlowStorePipeSplit): async def pipe_request(self, next_pipe, **kwargs): - self.store_linear('pipe_request') + self.store_linear("pipe_request") if request.query_params.skip: return "block" return await next_pipe(**kwargs) async def pipe_ws(self, next_pipe, **kwargs): - self.store_linear('pipe_ws') + self.store_linear("pipe_ws") if websocket.query_params.skip: return await next_pipe(**kwargs) @@ -139,18 +140,18 @@ async def close(self): class PipeSR1(FlowStorePipeSplit): def on_receive(self, data): data = json_load(data) - return dict(pipe1r='receive_inject', **data) + return dict(pipe1r="receive_inject", **data) def on_send(self, data): - return json_dump(dict(pipe1s='send_inject', **data)) + return json_dump(dict(pipe1s="send_inject", **data)) class PipeSR2(FlowStorePipeSplit): def on_receive(self, data): - return dict(pipe2r='receive_inject', **data) + return dict(pipe2r="receive_inject", **data) def on_send(self, data): - return dict(pipe2s='send_inject', **data) + return dict(pipe2s="send_inject", **data) class CTXInjector(Injector): @@ -250,7 +251,7 @@ def parallel_flows_are_equal(flow, ctx): return set(flow) == set(ctx._pipeline_parallel_storage) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): app = App(__name__) app.pipeline = [Pipe1(), Pipe2(), Pipe3()] @@ -269,11 +270,11 @@ def error(): @app.route(pipeline=[ExcPipeOpen(), Pipe4()]) def open_error(): - return '' + return "" @app.route(pipeline=[ExcPipeClose(), Pipe4()]) def close_error(): - return '' + return "" @app.route(pipeline=[Pipe4()]) def pipe4(): @@ -281,7 +282,7 @@ def pipe4(): @app.websocket() async def ws_ok(): - await websocket.send('ok') + await websocket.send("ok") @app.websocket() def ws_error(): @@ -305,7 +306,7 @@ async def ws_inject(): current._receive_storage.append(data) await websocket.send(data) - mod = app.module(__name__, 'mod', url_prefix='mod') + mod = app.module(__name__, "mod", url_prefix="mod") mod.pipeline = [Pipe5()] @mod.route() @@ -324,15 +325,15 @@ def ws_pipe5(): def ws_pipe6(): return - inj = app.module(__name__, 'inj', url_prefix='inj') + inj = app.module(__name__, "inj", url_prefix="inj") inj.pipeline = [GlobalInjector(), ScopedInjector()] - @inj.route(template='test.html') + @inj.route(template="test.html") def injpipe(): - return {'posts': []} + return {"posts": []} - mg1 = app.module(__name__, 'mg1', url_prefix='mg1') - mg2 = app.module(__name__, 'mg2', url_prefix='mg2') + mg1 = app.module(__name__, "mg1", url_prefix="mg1") + mg2 = app.module(__name__, "mg2", url_prefix="mg2") mg1.pipeline = [Pipe5()] mg2.pipeline = [Pipe6()] mg = app.module_group(mg1, mg2) @@ -341,7 +342,7 @@ def injpipe(): async def pipe_mg(): return "mg" - mgc = mg.module(__name__, 'mgc', url_prefix='mgc') + mgc = mg.module(__name__, "mgc", url_prefix="mgc") mgc.pipeline = [Pipe7()] @mgc.route() @@ -353,24 +354,30 @@ async def pipe_mgc(): @pytest.mark.asyncio async def test_ok_flow(app): - with request_ctx(app, '/ok') as ctx: + with request_ctx(app, "/ok") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', - 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_ok') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] - linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', - 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + with ws_ctx(app, "/ws_ok") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_ws", "Pipe3.open", "Pipe3.close", "Pipe2.close_ws", "Pipe1.close"] + linear_flow = ["Pipe1.pipe", "Pipe2.pipe_ws", "Pipe3.pipe", "Pipe3.success", "Pipe2.success", "Pipe1.success"] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -378,13 +385,23 @@ async def test_ok_flow(app): @pytest.mark.asyncio async def test_httperror_flow(app): - with request_ctx(app, '/http_error') as ctx: + with request_ctx(app, "/http_error") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', - 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] try: await ctx.dispatch() except HTTP: @@ -395,13 +412,23 @@ async def test_httperror_flow(app): @pytest.mark.asyncio async def test_error_flow(app): - with request_ctx(app, '/error') as ctx: + with request_ctx(app, "/error") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', - 'Pipe3.failure', 'Pipe2.failure', 'Pipe1.failure'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe3.failure", + "Pipe2.failure", + "Pipe1.failure", + ] try: await ctx.dispatch() except Exception: @@ -409,13 +436,9 @@ async def test_error_flow(app): assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_error') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] - linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', - 'Pipe3.failure', 'Pipe2.failure', 'Pipe1.failure'] + with ws_ctx(app, "/ws_error") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_ws", "Pipe3.open", "Pipe3.close", "Pipe2.close_ws", "Pipe1.close"] + linear_flow = ["Pipe1.pipe", "Pipe2.pipe_ws", "Pipe3.pipe", "Pipe3.failure", "Pipe2.failure", "Pipe1.failure"] try: await ctx.dispatch() except Exception: @@ -426,9 +449,8 @@ async def test_error_flow(app): @pytest.mark.asyncio async def test_open_error(app): - with request_ctx(app, '/open_error') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe4.open'] + with request_ctx(app, "/open_error") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_request", "Pipe3.open", "Pipe4.open"] linear_flow = [] try: await ctx.dispatch() @@ -437,9 +459,8 @@ async def test_open_error(app): assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_open_error') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', 'Pipe4.open'] + with ws_ctx(app, "/ws_open_error") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_ws", "Pipe3.open", "Pipe4.open"] linear_flow = [] try: await ctx.dispatch() @@ -451,16 +472,30 @@ async def test_open_error(app): @pytest.mark.asyncio async def test_close_error(app): - with request_ctx(app, '/close_error') as ctx: + with request_ctx(app, "/close_error") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'ExcPipeClose.open', 'Pipe4.open', - 'Pipe4.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "ExcPipeClose.open", + "Pipe4.open", + "Pipe4.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', - 'ExcPipeClose.pipe', 'Pipe4.pipe', - 'Pipe4.success', 'ExcPipeClose.success', 'Pipe3.success', - 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "ExcPipeClose.pipe", + "Pipe4.pipe", + "Pipe4.success", + "ExcPipeClose.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] try: await ctx.dispatch() except PipeException as e: @@ -468,16 +503,30 @@ async def test_close_error(app): assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_close_error') as ctx: + with ws_ctx(app, "/ws_close_error") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'ExcPipeClose.open', 'Pipe4.open', - 'Pipe4.close', 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "ExcPipeClose.open", + "Pipe4.open", + "Pipe4.close", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', - 'ExcPipeClose.pipe', 'Pipe4.pipe', - 'Pipe4.success', 'ExcPipeClose.success', 'Pipe3.success', - 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "ExcPipeClose.pipe", + "Pipe4.pipe", + "Pipe4.success", + "ExcPipeClose.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] try: await ctx.dispatch() except PipeException as e: @@ -488,24 +537,23 @@ async def test_close_error(app): @pytest.mark.asyncio async def test_flow_interrupt(app): - with request_ctx(app, '/ok?skip=yes') as ctx: + with request_ctx(app, "/ok?skip=yes") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] - linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', - 'Pipe2.success', 'Pipe1.success'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] + linear_flow = ["Pipe1.pipe", "Pipe2.pipe_request", "Pipe2.success", "Pipe1.success"] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_ok?skip=yes') as ctx: - parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] - linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', - 'Pipe2.success', 'Pipe1.success'] + with ws_ctx(app, "/ws_ok?skip=yes") as ctx: + parallel_flow = ["Pipe1.open", "Pipe2.open_ws", "Pipe3.open", "Pipe3.close", "Pipe2.close_ws", "Pipe1.close"] + linear_flow = ["Pipe1.pipe", "Pipe2.pipe_ws", "Pipe2.success", "Pipe1.success"] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -513,24 +561,52 @@ async def test_flow_interrupt(app): @pytest.mark.asyncio async def test_pipeline_composition(app): - with request_ctx(app, '/pipe4') as ctx: + with request_ctx(app, "/pipe4") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe4.open', - 'Pipe4.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe4.open", + "Pipe4.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe4.pipe', - 'Pipe4.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe4.pipe", + "Pipe4.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/ws_pipe4') as ctx: + with ws_ctx(app, "/ws_pipe4") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', 'Pipe4.open', - 'Pipe4.close', 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "Pipe4.open", + "Pipe4.close", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', 'Pipe4.pipe', - 'Pipe4.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "Pipe4.pipe", + "Pipe4.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -538,24 +614,52 @@ async def test_pipeline_composition(app): @pytest.mark.asyncio async def test_module_pipeline(app): - with request_ctx(app, '/mod/pipe5') as ctx: + with request_ctx(app, "/mod/pipe5") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe5.open', - 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe5.open", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/mod/ws_pipe5') as ctx: + with ws_ctx(app, "/mod/ws_pipe5") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', 'Pipe5.open', - 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "Pipe5.open", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -563,32 +667,60 @@ async def test_module_pipeline(app): @pytest.mark.asyncio async def test_module_pipeline_composition(app): - with request_ctx(app, '/mod/pipe6') as ctx: + with request_ctx(app, "/mod/pipe6") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe5.open', - 'Pipe6.open', - 'Pipe6.close', 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_request', - 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe5.open", + "Pipe6.open", + "Pipe6.close", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe6.pipe', - 'Pipe6.success', 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', - 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe6.pipe", + "Pipe6.success", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with ws_ctx(app, '/mod/ws_pipe6') as ctx: + with ws_ctx(app, "/mod/ws_pipe6") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', 'Pipe5.open', - 'Pipe6.open', - 'Pipe6.close', 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_ws', - 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "Pipe5.open", + "Pipe6.open", + "Pipe6.close", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe6.pipe', - 'Pipe6.success', 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', - 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe6.pipe", + "Pipe6.success", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -596,24 +728,52 @@ async def test_module_pipeline_composition(app): @pytest.mark.asyncio async def test_module_group_pipeline(app): - with request_ctx(app, '/mg1/pipe_mg') as ctx: + with request_ctx(app, "/mg1/pipe_mg") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe5.open', - 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe5.open", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with request_ctx(app, '/mg2/pipe_mg') as ctx: + with request_ctx(app, "/mg2/pipe_mg") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe6.open', - 'Pipe6.close', 'Pipe3.close', 'Pipe2.close_request', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe6.open", + "Pipe6.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe6.pipe', - 'Pipe6.success', 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe6.pipe", + "Pipe6.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -621,32 +781,60 @@ async def test_module_group_pipeline(app): @pytest.mark.asyncio async def test_module_group_pipeline_composition(app): - with request_ctx(app, '/mg1/mgc/pipe_mgc') as ctx: + with request_ctx(app, "/mg1/mgc/pipe_mgc") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe5.open', - 'Pipe7.open', - 'Pipe7.close', 'Pipe5.close', 'Pipe3.close', 'Pipe2.close_request', - 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe5.open", + "Pipe7.open", + "Pipe7.close", + "Pipe5.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe5.pipe', - 'Pipe7.pipe', - 'Pipe7.success', 'Pipe5.success', 'Pipe3.success', 'Pipe2.success', - 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe5.pipe", + "Pipe7.pipe", + "Pipe7.success", + "Pipe5.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - with request_ctx(app, '/mg2/mgc/pipe_mgc') as ctx: + with request_ctx(app, "/mg2/mgc/pipe_mgc") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_request', 'Pipe3.open', 'Pipe6.open', - 'Pipe7.open', - 'Pipe7.close', 'Pipe6.close', 'Pipe3.close', 'Pipe2.close_request', - 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_request", + "Pipe3.open", + "Pipe6.open", + "Pipe7.open", + "Pipe7.close", + "Pipe6.close", + "Pipe3.close", + "Pipe2.close_request", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_request', 'Pipe3.pipe', 'Pipe6.pipe', - 'Pipe7.pipe', - 'Pipe7.success', 'Pipe6.success', 'Pipe3.success', 'Pipe2.success', - 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_request", + "Pipe3.pipe", + "Pipe6.pipe", + "Pipe7.pipe", + "Pipe7.success", + "Pipe6.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) @@ -654,52 +842,61 @@ async def test_module_group_pipeline_composition(app): @pytest.mark.asyncio async def test_receive_send_flow(app): - send_storage_key = { - "str": "text", - "bytes": "bytes" - }[_json_type] - with ws_ctx(app, '/ws_inject') as ctx: + with ws_ctx(app, "/ws_inject") as ctx: parallel_flow = [ - 'Pipe1.open', 'Pipe2.open_ws', 'Pipe3.open', - 'PipeSR1.open_ws', 'PipeSR2.open_ws', - 'PipeSR2.close_ws', 'PipeSR1.close_ws', - 'Pipe3.close', 'Pipe2.close_ws', 'Pipe1.close'] + "Pipe1.open", + "Pipe2.open_ws", + "Pipe3.open", + "PipeSR1.open_ws", + "PipeSR2.open_ws", + "PipeSR2.close_ws", + "PipeSR1.close_ws", + "Pipe3.close", + "Pipe2.close_ws", + "Pipe1.close", + ] linear_flow = [ - 'Pipe1.pipe', 'Pipe2.pipe_ws', 'Pipe3.pipe', - 'PipeSR1.pipe_ws', 'PipeSR2.pipe_ws', - 'PipeSR2.success', 'PipeSR1.success', - 'Pipe3.success', 'Pipe2.success', 'Pipe1.success'] + "Pipe1.pipe", + "Pipe2.pipe_ws", + "Pipe3.pipe", + "PipeSR1.pipe_ws", + "PipeSR2.pipe_ws", + "PipeSR2.success", + "PipeSR1.success", + "Pipe3.success", + "Pipe2.success", + "Pipe1.success", + ] await ctx.dispatch() assert linear_flows_are_equal(linear_flow, ctx.ctx) assert parallel_flows_are_equal(parallel_flow, ctx.ctx) - assert ctx.ctx._receive_storage[-1] == { - 'foo': 'bar', - 'pipe1r': 'receive_inject', 'pipe2r': 'receive_inject' - } - assert json_load(ctx.ctx._send_storage[-1][send_storage_key]) == { - 'foo': 'bar', - 'pipe1r': 'receive_inject', 'pipe2r': 'receive_inject', - 'pipe1s': 'send_inject', 'pipe2s': 'send_inject' + assert ctx.ctx._receive_storage[-1] == {"foo": "bar", "pipe1r": "receive_inject", "pipe2r": "receive_inject"} + assert json_load(ctx.ctx._send_storage[-1]) == { + "foo": "bar", + "pipe1r": "receive_inject", + "pipe2r": "receive_inject", + "pipe1s": "send_inject", + "pipe2s": "send_inject", } @pytest.mark.asyncio async def test_injectors(app): - with request_ctx(app, '/inj/injpipe') as ctx: + with request_ctx(app, "/inj/injpipe") as ctx: current.app = app await ctx.dispatch() env = ctx.ctx._pipeline_generic_storage[0] - assert env['posts'] == [] - assert env['foo'] == "bar" - assert env['bar'] == "baz" - assert env['staticm']("test") == "test" - assert env['boundm']("test") == ("bar", "test") - assert env['prop'] == "baz" + assert env["posts"] == [] + assert env["foo"] == "bar" + assert env["bar"] == "baz" + assert env["staticm"]("test") == "test" + assert env["boundm"]("test") == ("bar", "test") + assert env["prop"] == "baz" - env = env['scoped'] + env = env["scoped"] assert env.foo == "bar" assert env.bar == "baz" assert env.staticm("test") == "test" diff --git a/tests/test_routing.py b/tests/test_routing.py index 67faf506..5005cd2f 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,21 +1,22 @@ # -*- coding: utf-8 -*- """ - tests.routing - ------------- +tests.routing +------------- - Test Emmett routing module +Test Emmett routing module """ -import pendulum -import pytest - from contextlib import contextmanager +import pendulum +import pytest +from emmett_core.protocols.rsgi.test_client.scope import ScopeBuilder from helpers import FakeRequestContext + from emmett import App, abort, url from emmett.ctx import current -from emmett.http import HTTP -from emmett.testing.env import ScopeBuilder +from emmett.datastructures import sdict +from emmett.http import HTTPResponse @contextmanager @@ -26,257 +27,257 @@ def current_ctx(app, path): current._close_(token) -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): app = App(__name__) - app.languages = ['en', 'it'] - app.language_default = 'en' + app.languages = ["en", "it"] + app.language_default = "en" app.language_force_on_url = True @app.route() def test_route(): - return 'Test Router' + return "Test Router" @app.route() def test_404(): - abort(404, 'Not found, dude') + abort(404, "Not found, dude") - @app.route('/test2//') + @app.route("/test2//") def test_route2(a, b): - return 'Test Router' + return "Test Router" - @app.route('/test3//foo(/)?(.)?') + @app.route("/test3//foo(/)?(.)?") def test_route3(a, b, c): - return 'Test Router' + return "Test Router" - @app.route('/test_int/') + @app.route("/test_int/") def test_route_int(a): - return 'Test Router' + return "Test Router" - @app.route('/test_float/') + @app.route("/test_float/") def test_route_float(a): - return 'Test Router' + return "Test Router" - @app.route('/test_date/') + @app.route("/test_date/") def test_route_date(a): - return 'Test Router' + return "Test Router" - @app.route('/test_alpha/') + @app.route("/test_alpha/") def test_route_alpha(a): - return 'Test Router' + return "Test Router" - @app.route('/test_str/') + @app.route("/test_str/") def test_route_str(a): - return 'Test Router' + return "Test Router" - @app.route('/test_any/') + @app.route("/test_any/") def test_route_any(a): - return 'Test Router' - - @app.route( - '/test_complex' - '/' - '/' - '/' - '/' - '/' - '/' - ) + return "Test Router" + + @app.route("/test_complex" "/" "/" "/" "/" "/" "/") def test_route_complex(a, b, c, d, e, f): - return 'Test Router' + current._reqargs = {"a": a, "b": b, "c": c, "d": d, "e": e, "f": f} + return "Test Router" return app def test_routing(app): - with current_ctx(app, '/test_int') as ctx: + with current_ctx(app, "/test_int") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_int/a') as ctx: + with current_ctx(app, "/test_int/a") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_int/1.1') as ctx: + with current_ctx(app, "/test_int/1.1") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_int/2000-01-01') as ctx: + with current_ctx(app, "/test_int/2000-01-01") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_int/1') as ctx: + with current_ctx(app, "/test_int/1") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_int' + assert route.name == "test_routing.test_route_int" - with current_ctx(app, '/test_float') as ctx: + with current_ctx(app, "/test_float") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_float/a.a') as ctx: + with current_ctx(app, "/test_float/a.a") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_float/1') as ctx: + with current_ctx(app, "/test_float/1") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_float/1.1') as ctx: + with current_ctx(app, "/test_float/1.1") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_float' + assert route.name == "test_routing.test_route_float" - with current_ctx(app, '/test_date') as ctx: + with current_ctx(app, "/test_date") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_date/2000-01-01') as ctx: + with current_ctx(app, "/test_date/2000-01-01") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_date' + assert route.name == "test_routing.test_route_date" - with current_ctx(app, '/test_alpha') as ctx: + with current_ctx(app, "/test_alpha") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_alpha/a1') as ctx: + with current_ctx(app, "/test_alpha/a1") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_alpha/a-a') as ctx: + with current_ctx(app, "/test_alpha/a-a") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_alpha/a') as ctx: + with current_ctx(app, "/test_alpha/a") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_alpha' + assert route.name == "test_routing.test_route_alpha" - with current_ctx(app, '/test_str') as ctx: + with current_ctx(app, "/test_str") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_str/a/b') as ctx: + with current_ctx(app, "/test_str/a/b") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_str/a1-') as ctx: + with current_ctx(app, "/test_str/a1-") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_str' + assert route.name == "test_routing.test_route_str" - with current_ctx(app, '/test_any') as ctx: + with current_ctx(app, "/test_any") as ctx: route, args = app._router_http.match(ctx.request) assert not route - with current_ctx(app, '/test_any/a/b') as ctx: + with current_ctx(app, "/test_any/a/b") as ctx: route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_any' + assert route.name == "test_routing.test_route_any" -def test_route_args(app): - with current_ctx( - app, '/test_complex/1/1.2/2000-12-01/foo/foo1/bar/baz' - ) as ctx: - route, args = app._router_http.match(ctx.request) - assert route.name == 'test_routing.test_route_complex' - assert args['a'] == 1 - assert args['b'] == 1.2 - assert args['c'] == pendulum.datetime(2000, 12, 1) - assert args['d'] == 'foo' - assert args['e'] == 'foo1' - assert args['f'] == 'bar/baz' +@pytest.mark.asyncio +async def test_route_args(app): + with current_ctx(app, "/test_complex/1/1.2/2000-12-01/foo/foo1/bar/baz") as ctx: + # route, args = app._router_http.match(ctx.request) + # assert route.name == 'test_routing.test_route_complex' + # assert args['a'] == 1 + # assert round(args['b'], 1) == 1.2 + # assert args['c'] == pendulum.datetime(2000, 12, 1) + # assert args['d'] == 'foo' + # assert args['e'] == 'foo1' + # assert args['f'] == 'bar/baz' + await app._router_http.dispatch(ctx.request, sdict()) + assert ctx.request.name == "test_routing.test_route_complex" + args = current._reqargs + assert args["a"] == 1 + assert round(args["b"], 1) == 1.2 + assert args["c"] == pendulum.datetime(2000, 12, 1) + assert args["d"] == "foo" + assert args["e"] == "foo1" + assert args["f"] == "bar/baz" @pytest.mark.asyncio async def test_routing_valid_route(app): - with current_ctx(app, '/it/test_route') as ctx: + with current_ctx(app, "/it/test_route") as ctx: response = await app._router_http.dispatch(ctx.request, ctx.response) assert response.status_code == ctx.response.status == 200 - assert response.body == 'Test Router' - assert ctx.request.language == 'it' + assert response.body == "Test Router" + assert ctx.request.language == "it" @pytest.mark.asyncio async def test_routing_not_found_route(app): - with current_ctx(app, '/') as ctx: - with pytest.raises(HTTP) as excinfo: + with current_ctx(app, "/") as ctx: + with pytest.raises(HTTPResponse) as excinfo: await app._router_http.dispatch(ctx.request, ctx.response) assert excinfo.value.status_code == 404 - assert excinfo.value.body == 'Resource not found\n' + assert excinfo.value.body == b"Resource not found" @pytest.mark.asyncio async def test_routing_exception_route(app): - with current_ctx(app, '/test_404') as ctx: - with pytest.raises(HTTP) as excinfo: + with current_ctx(app, "/test_404") as ctx: + with pytest.raises(HTTPResponse) as excinfo: await app._router_http.dispatch(ctx.request, ctx.response) assert excinfo.value.status_code == 404 - assert excinfo.value.body == 'Not found, dude' + assert excinfo.value.body == "Not found, dude" def test_static_url(app): - link = url('static', 'file') - assert link == '/static/file' + link = url("static", "file") + assert link == "/static/file" app.config.static_version_urls = True - app.config.static_version = '1.0.0' - link = url('static', 'js/foo.js', language='it') - assert link == '/it/static/_1.0.0/js/foo.js' + app.config.static_version = "1.0.0" + link = url("static", "js/foo.js", language="it") + assert link == "/it/static/_1.0.0/js/foo.js" def test_module_url(app): - with current_ctx(app, '/') as ctx: - ctx.request.language = 'it' - link = url('test_route') - assert link == '/it/test_route' - link = url('test_route2') - assert link == '/it/test2' - link = url('test_route2', [2]) - assert link == '/it/test2/2' - link = url('test_route2', [2, 'foo']) - assert link == '/it/test2/2/foo' - link = url('test_route3') - assert link == '/it/test3' - link = url('test_route3', [2]) - assert link == '/it/test3/2/foo' - link = url('test_route3', [2, 'bar']) - assert link == '/it/test3/2/foo/bar' - link = url('test_route3', [2, 'bar', 'json']) - assert link == '/it/test3/2/foo/bar.json' - link = url( - 'test_route3', [2, 'bar', 'json'], {'foo': 'bar', 'bar': 'foo'}) - lsplit = link.split('?') - assert lsplit[0] == '/it/test3/2/foo/bar.json' - assert lsplit[1] in ['foo=bar&bar=foo', 'bar=foo&foo=bar'] + with current_ctx(app, "/") as ctx: + ctx.request.language = "it" + link = url("test_route") + assert link == "/it/test_route" + link = url("test_route2") + assert link == "/it/test2" + link = url("test_route2", [2]) + assert link == "/it/test2/2" + link = url("test_route2", [2, "foo"]) + assert link == "/it/test2/2/foo" + link = url("test_route3") + assert link == "/it/test3" + link = url("test_route3", [2]) + assert link == "/it/test3/2/foo" + link = url("test_route3", [2, "bar"]) + assert link == "/it/test3/2/foo/bar" + link = url("test_route3", [2, "bar", "json"]) + assert link == "/it/test3/2/foo/bar.json" + link = url("test_route3", [2, "bar", "json"], {"foo": "bar", "bar": "foo"}) + lsplit = link.split("?") + assert lsplit[0] == "/it/test3/2/foo/bar.json" + assert lsplit[1] in ["foo=bar&bar=foo", "bar=foo&foo=bar"] def test_global_url_prefix(app): - app._router_http._prefix_main = '/foo' + app._router_http._prefix_main = "/foo" app._router_http._prefix_main_len = 3 - with current_ctx(app, '/') as ctx: + with current_ctx(app, "/") as ctx: app.config.static_version_urls = False - ctx.request.language = 'en' + ctx.request.language = "en" - link = url('test_route') - assert link == '/foo/test_route' + link = url("test_route") + assert link == "/foo/test_route" - link = url('static', 'js/foo.js') - assert link == '/foo/static/js/foo.js' + link = url("static", "js/foo.js") + assert link == "/foo/static/js/foo.js" app.config.static_version_urls = True - app.config.static_version = '1.0.0' + app.config.static_version = "1.0.0" - link = url('static', 'js/foo.js') - assert link == '/foo/static/_1.0.0/js/foo.js' + link = url("static", "js/foo.js") + assert link == "/foo/static/_1.0.0/js/foo.js" app.config.static_version_urls = False - ctx.request.language = 'it' + ctx.request.language = "it" - link = url('test_route') - assert link == '/foo/it/test_route' + link = url("test_route") + assert link == "/foo/it/test_route" - link = url('static', 'js/foo.js') - assert link == '/foo/it/static/js/foo.js' + link = url("static", "js/foo.js") + assert link == "/foo/it/static/js/foo.js" app.config.static_version_urls = True - app.config.static_version = '1.0.0' + app.config.static_version = "1.0.0" - link = url('static', 'js/foo.js') - assert link == '/foo/it/static/_1.0.0/js/foo.js' + link = url("static", "js/foo.js") + assert link == "/foo/it/static/_1.0.0/js/foo.js" diff --git a/tests/test_session.py b/tests/test_session.py deleted file mode 100644 index 91ec0cc0..00000000 --- a/tests/test_session.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- coding: utf-8 -*- -""" - tests.session - ------------- - - Test Emmett session module -""" - -import pytest - -from emmett.asgi.wrappers import Request -from emmett.ctx import RequestContext, current -from emmett.sessions import SessionManager -from emmett.testing.env import ScopeBuilder -from emmett.wrappers.response import Response - - -class FakeRequestContext(RequestContext): - def __init__(self, app, scope): - self.request = Request(scope, None, None) - self.response = Response() - self.session = None - - -@pytest.fixture(scope='module') -def ctx(): - builder = ScopeBuilder() - token = current._init_(FakeRequestContext(None, builder.get_data()[0])) - yield current - current._close_(token) - - -@pytest.mark.asyncio -async def test_session_cookie(ctx): - session_cookie = SessionManager.cookies( - key='sid', - secure=True, - domain='localhost', - cookie_name='foo_session' - ) - assert session_cookie.key == 'sid' - assert session_cookie.secure is True - assert session_cookie.domain == 'localhost' - - await session_cookie.open_request() - assert ctx.session._expiration == 3600 - - await session_cookie.close_request() - cookie = str(ctx.response.cookies) - assert 'foo_session' in cookie - assert 'Domain=localhost;' in cookie - assert 'secure' in cookie.lower() - - ctx.request.cookies = ctx.response.cookies - await session_cookie.open_request() - assert ctx.session._expiration == 3600 diff --git a/tests/test_templates.py b/tests/test_templates.py index b4acf839..cfeb6511 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,21 +1,21 @@ # -*- coding: utf-8 -*- """ - tests.templates - --------------- +tests.templates +--------------- - Test Emmett templating module +Test Emmett templating module """ import pytest - from helpers import current_ctx + from emmett import App -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): app = App(__name__) - app.config.templates_escape = 'all' + app.config.templates_escape = "all" app.config.templates_adjust_indent = False app.config.templates_auto_reload = True return app @@ -24,31 +24,34 @@ def app(): def test_helpers(app): templater = app.templater r = templater._render(source="{{ include_helpers }}") - assert r == '' + \ - '' + assert ( + r + == '' + + '' + ) def test_meta(app): - with current_ctx('/', app) as ctx: + with current_ctx("/", app) as ctx: ctx.response.meta.foo = "bar" ctx.response.meta_prop.foo = "bar" templater = app.templater - r = templater._render( - source="{{ include_meta }}", - context={'current': ctx}) - assert r == '' + \ - '' + r = templater._render(source="{{ include_meta }}", context={"current": ctx}) + assert r == '' + '' def test_static(app): templater = app.templater s = "{{include_static 'foo.js'}}\n{{include_static 'bar.css'}}" r = templater._render(source=s) - assert r == '\n' + assert ( + r + == '\n' + ) rendered_value = """ @@ -78,20 +81,19 @@ def test_static(app): """.format( - helpers="".join([ - f'' - for name in ["jquery.min.js", "helpers.js"] - ]) + helpers="".join( + [ + f'' + for name in ["jquery.min.js", "helpers.js"] + ] + ) ) def test_render(app): - with current_ctx('/', app) as ctx: - ctx.language = 'it' - r = app.templater.render( - 'test.html', { - 'current': ctx, 'posts': [{'title': 'foo'}, {'title': 'bar'}] - } + with current_ctx("/", app) as ctx: + ctx.language = "it" + r = app.templater.render("test.html", {"current": ctx, "posts": [{"title": "foo"}, {"title": "bar"}]}) + assert "\n".join([l.strip() for l in r.splitlines() if l.strip()]) == "\n".join( + [l.strip() for l in rendered_value[1:].splitlines()] ) - assert "\n".join([l.strip() for l in r.splitlines() if l.strip()]) == \ - "\n".join([l.strip() for l in rendered_value[1:].splitlines()]) diff --git a/tests/test_translator.py b/tests/test_translator.py index 925cbbcf..ac868035 100644 --- a/tests/test_translator.py +++ b/tests/test_translator.py @@ -1,9 +1,9 @@ # -*- coding: utf-8 -*- """ - tests.translator - ---------------- +tests.translator +---------------- - Test Emmett translator module +Test Emmett translator module """ import pytest @@ -13,17 +13,17 @@ from emmett.locals import T -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def app(): return App(__name__) def _make_translation(language): - return str(T('partly cloudy', lang=language)) + return str(T("partly cloudy", lang=language)) def test_translation(app): - current.language = 'en' - assert _make_translation('it') == 'nuvolosità variabile' - assert _make_translation('de') == 'teilweise bewölkt' - assert _make_translation('ru') == 'переменная облачность' + current.language = "en" + assert _make_translation("it") == "nuvolosità variabile" + assert _make_translation("de") == "teilweise bewölkt" + assert _make_translation("ru") == "переменная облачность" diff --git a/tests/test_utils.py b/tests/test_utils.py index 896f8bde..18a2f081 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,79 +1,39 @@ # -*- coding: utf-8 -*- """ - tests.utils - ----------- +tests.utils +----------- - Test Emmett utils engine +Test Emmett utils engine """ -import pytest - from emmett.datastructures import sdict -from emmett.utils import ( - cachedprop, _cached_prop_sync, _cached_prop_loop, - dict_to_sdict, is_valid_ip_address) - - -class Class: - def __init__(self): - self.calls = 0 - - @cachedprop - def prop(self): - self.calls += 1 - return 'test_cachedprop_sync' - - @cachedprop - async def prop_loop(self): - self.calls += 1 - return 'test_cachedprop_loop' - - -def test_cachedprop_sync(): - assert isinstance(Class.prop, _cached_prop_sync) - obj = Class() - assert obj.calls == 0 - assert obj.prop == 'test_cachedprop_sync' - assert obj.calls == 1 - assert obj.prop == 'test_cachedprop_sync' - assert obj.calls == 1 - - -@pytest.mark.asyncio -async def test_cachedprop_loop(): - assert isinstance(Class.prop_loop, _cached_prop_loop) - obj = Class() - assert obj.calls == 0 - assert (await obj.prop_loop) == 'test_cachedprop_loop' - assert obj.calls == 1 - assert (await obj.prop_loop) == 'test_cachedprop_loop' - assert obj.calls == 1 +from emmett.utils import dict_to_sdict, is_valid_ip_address def test_is_valid_ip_address(): - result_localhost = is_valid_ip_address('127.0.0.1') + result_localhost = is_valid_ip_address("127.0.0.1") assert result_localhost is True - result_unknown = is_valid_ip_address('unknown') + result_unknown = is_valid_ip_address("unknown") assert result_unknown is False - result_ipv4_valid = is_valid_ip_address('::ffff:192.168.0.1') + result_ipv4_valid = is_valid_ip_address("::ffff:192.168.0.1") assert result_ipv4_valid is True - result_ipv4_valid = is_valid_ip_address('192.168.256.1') + result_ipv4_valid = is_valid_ip_address("192.168.256.1") assert result_ipv4_valid is False - result_ipv6_valid = is_valid_ip_address('fd40:363d:ee85::') + result_ipv6_valid = is_valid_ip_address("fd40:363d:ee85::") assert result_ipv6_valid is True - result_ipv6_valid = is_valid_ip_address('fd40:363d:ee85::1::') + result_ipv6_valid = is_valid_ip_address("fd40:363d:ee85::1::") assert result_ipv6_valid is False def test_dict_to_sdict(): - result_sdict = dict_to_sdict({'test': 'dict'}) + result_sdict = dict_to_sdict({"test": "dict"}) assert isinstance(result_sdict, sdict) - assert result_sdict.test == 'dict' + assert result_sdict.test == "dict" result_number = dict_to_sdict(1) assert not isinstance(result_number, sdict) diff --git a/tests/test_validators.py b/tests/test_validators.py index 2deebd93..fb343529 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,21 +1,47 @@ # -*- coding: utf-8 -*- """ - tests.validators - ---------------- +tests.validators +---------------- - Test Emmett validators over pyDAL. +Test Emmett validators over pyDAL. """ +from datetime import datetime, timedelta + import pytest -from datetime import datetime, timedelta from emmett import App, sdict -from emmett.orm import Database, Model, Field, has_many, belongs_to +from emmett.orm import Database, Field, Model, belongs_to, has_many from emmett.validators import ( - isEmptyOr, hasLength, isInt, isFloat, isDate, isTime, isDatetime, isJSON, - isntEmpty, inSet, inDB, isEmail, isUrl, isIP, isImage, inRange, Equals, - Lower, Upper, Cleanup, Urlify, Crypt, notInDB, Allow, Not, Matches, Any, - isList) + Allow, + Any, + Cleanup, + Crypt, + Equals, + Lower, + Matches, + Not, + Upper, + Urlify, + hasLength, + inDB, + inRange, + inSet, + isDate, + isDatetime, + isEmail, + isEmptyOr, + isFloat, + isImage, + isInt, + isIP, + isJSON, + isList, + isntEmpty, + isTime, + isUrl, + notInDB, +) class A(Model): @@ -49,11 +75,9 @@ class B(Model): tablename = "b" a = Field() - b = Field(validation={'len': {'gte': 5}}) + b = Field(validation={"len": {"gte": 5}}) - validation = { - 'a': {'len': {'gte': 5}} - } + validation = {"a": {"len": {"gte": 5}}} class Consist(Model): @@ -65,12 +89,12 @@ class Consist(Model): emailsplit = Field.string_list() validation = { - 'email': {'is': 'email'}, - 'url': {'is': 'url'}, - 'ip': {'is': 'ip'}, - 'image': {'is': 'image'}, - 'emails': {'is': 'list:email'}, - 'emailsplit': {'is': {'list:email': {'splitter': ',;'}}} + "email": {"is": "email"}, + "url": {"is": "url"}, + "ip": {"is": "ip"}, + "image": {"is": "image"}, + "emails": {"is": "list:email"}, + "emailsplit": {"is": {"list:email": {"splitter": ",;"}}}, } @@ -81,10 +105,10 @@ class Len(Model): d = Field() validation = { - 'a': {'len': 5}, - 'b': {'len': {'gt': 4, 'lt': 13}}, - 'c': {'len': {'gte': 5, 'lte': 12}}, - 'd': {'len': {'range': (5, 13)}} + "a": {"len": 5}, + "b": {"len": {"gt": 4, "lt": 13}}, + "c": {"len": {"gte": 5, "lte": 12}}, + "d": {"len": {"range": (5, 13)}}, } @@ -92,10 +116,7 @@ class Inside(Model): a = Field() b = Field.int() - validation = { - 'a': {'in': ['a', 'b']}, - 'b': {'in': {'range': (1, 5)}} - } + validation = {"a": {"in": ["a", "b"]}, "b": {"in": {"range": (1, 5)}}} class Num(Model): @@ -103,11 +124,7 @@ class Num(Model): b = Field.int() c = Field.int() - validation = { - 'a': {'gt': 0}, - 'b': {'lt': 5}, - 'c': {'gt': 0, 'lte': 4} - } + validation = {"a": {"gt": 0}, "b": {"lt": 5}, "c": {"gt": 0, "lte": 4}} class Proc(Model): @@ -119,12 +136,12 @@ class Proc(Model): f = Field.password() validation = { - 'a': {'lower': True}, - 'b': {'upper': True}, - 'c': {'clean': True}, - 'd': {'urlify': True}, - 'e': {'len': {'range': (6, 25)}, 'crypt': True}, - 'f': {'len': {'gt': 5, 'lt': 25}, 'crypt': 'md5'} + "a": {"lower": True}, + "b": {"upper": True}, + "c": {"clean": True}, + "d": {"urlify": True}, + "e": {"len": {"range": (6, 25)}, "crypt": True}, + "f": {"len": {"gt": 5, "lt": 25}, "crypt": "md5"}, } @@ -133,60 +150,47 @@ class Eq(Model): b = Field.int() c = Field.float() - validation = { - 'a': {'equals': 'asd'}, - 'b': {'equals': 2}, - 'c': {'not': {'equals': 2.4}} - } + validation = {"a": {"equals": "asd"}, "b": {"equals": 2}, "c": {"not": {"equals": 2.4}}} class Match(Model): a = Field() b = Field() - validation = { - 'a': {'match': 'ab'}, - 'b': {'match': {'expression': 'ab', 'strict': True}} - } + validation = {"a": {"match": "ab"}, "b": {"match": {"expression": "ab", "strict": True}}} class Anyone(Model): a = Field() - validation = { - 'a': {'any': {'is': 'email', 'in': ['foo', 'bar']}} - } + validation = {"a": {"any": {"is": "email", "in": ["foo", "bar"]}}} class Person(Model): - has_many('things') + has_many("things") - name = Field(validation={'empty': False}) - surname = Field(validation={'presence': True}) + name = Field(validation={"empty": False}) + surname = Field(validation={"presence": True}) class Thing(Model): - belongs_to('person') + belongs_to("person") name = Field() color = Field() uid = Field(unique=True) - validation = { - 'name': {'presence': True}, - 'color': {'in': ['blue', 'red']}, - 'uid': {'empty': False} - } + validation = {"name": {"presence": True}, "color": {"in": ["blue", "red"]}, "uid": {"empty": False}} class Allowed(Model): - a = Field(validation={'in': ['a', 'b'], 'allow': None}) - b = Field(validation={'in': ['a', 'b'], 'allow': 'empty'}) - c = Field(validation={'in': ['a', 'b'], 'allow': 'blank'}) + a = Field(validation={"in": ["a", "b"], "allow": None}) + b = Field(validation={"in": ["a", "b"], "allow": "empty"}) + c = Field(validation={"in": ["a", "b"], "allow": "blank"}) class Mixed(Model): - belongs_to('person') + belongs_to("person") date = Field.date() type = Field() @@ -197,27 +201,21 @@ class Mixed(Model): psw = Field.password() validation = { - 'date': {'format': '%d/%m/%Y', 'gt': lambda: datetime.utcnow().date()}, - 'type': {'in': ['a', 'b'], 'allow': None}, - 'inside': {'in': ['asd', 'lol']}, - 'number': {'allow': 'blank'}, - 'dont': {'empty': True}, - 'yep': {'presence': True}, - 'psw': {'len': {'range': (6, 25)}, 'crypt': True} + "date": {"format": "%d/%m/%Y", "gt": lambda: datetime.utcnow().date()}, + "type": {"in": ["a", "b"], "allow": None}, + "inside": {"in": ["asd", "lol"]}, + "number": {"allow": "blank"}, + "dont": {"empty": True}, + "yep": {"presence": True}, + "psw": {"len": {"range": (6, 25)}, "crypt": True}, } -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def db(): app = App(__name__) - db = Database( - app, config=sdict( - uri='sqlite://validators.db', auto_connect=True, - auto_migrate=True)) - db.define_models([ - A, AA, AAA, B, Consist, Len, Inside, Num, Eq, Match, Anyone, Proc, - Person, Thing, Allowed, Mixed - ]) + db = Database(app, config=sdict(uri="sqlite://validators.db", auto_connect=True, auto_migrate=True)) + db.define_models([A, AA, AAA, B, Consist, Len, Inside, Num, Eq, Match, Anyone, Proc, Person, Thing, Allowed, Mixed]) return db @@ -349,243 +347,243 @@ def test_allow(db): def test_validation(db): #: 'is' is_data = { - 'name': 'foo', - 'val': 1, - 'fval': 1.5, - 'text': 'Lorem ipsum', - 'password': 'notverysecret', - 'd': '{:%Y-%m-%d}'.format(datetime.utcnow()), - 't': '15:23', - 'dt': '2015-12-23T15:23:00', - 'json': '{}' + "name": "foo", + "val": 1, + "fval": 1.5, + "text": "Lorem ipsum", + "password": "notverysecret", + "d": "{:%Y-%m-%d}".format(datetime.utcnow()), + "t": "15:23", + "dt": "2015-12-23T15:23:00", + "json": "{}", } errors = A.validate(is_data) assert not errors d = dict(is_data) - d['val'] = 'foo' + d["val"] = "foo" errors = A.validate(d) - assert 'val' in errors and len(errors) == 1 + assert "val" in errors and len(errors) == 1 d = dict(is_data) - d['fval'] = 'bar' + d["fval"] = "bar" errors = A.validate(d) - assert 'fval' in errors and len(errors) == 1 + assert "fval" in errors and len(errors) == 1 d = dict(is_data) - d['d'] = 'foo' + d["d"] = "foo" errors = A.validate(d) - assert 'd' in errors and len(errors) == 1 + assert "d" in errors and len(errors) == 1 d = dict(is_data) - d['t'] = 'bar' + d["t"] = "bar" errors = A.validate(d) - assert 't' in errors and len(errors) == 1 + assert "t" in errors and len(errors) == 1 d = dict(is_data) - d['dt'] = 'foo' + d["dt"] = "foo" errors = A.validate(d) - assert 'dt' in errors and len(errors) == 1 + assert "dt" in errors and len(errors) == 1 d = dict(is_data) - d['json'] = 'bar' + d["json"] = "bar" errors = A.validate(d) - assert 'json' in errors and len(errors) == 1 - errors = Consist.validate({'email': 'foo'}) - assert 'email' in errors - errors = Consist.validate({'url': 'notanurl'}) - assert 'url' in errors - errors = Consist.validate({'url': 'http://domain.com/'}) - assert 'url' not in errors - errors = Consist.validate({'url': 'http://domain.com'}) - assert 'url' not in errors - errors = Consist.validate({'url': 'domain.com'}) - assert 'url' not in errors - errors = Consist.validate({'ip': 'foo'}) - assert 'ip' in errors - errors = Consist.validate({'emails': 'foo'}) - assert 'emails' in errors - errors = Consist.validate({'emailsplit': 'foo'}) - assert 'emailsplit' in errors - errors = Consist.validate({'emailsplit': '1@asd.com, 2@asd.com'}) - assert 'emailsplit' not in errors - errors = Consist.validate({'emails': ['1@asd.com', '2@asd.com']}) - assert 'emails' not in errors + assert "json" in errors and len(errors) == 1 + errors = Consist.validate({"email": "foo"}) + assert "email" in errors + errors = Consist.validate({"url": "notanurl"}) + assert "url" in errors + errors = Consist.validate({"url": "http://domain.com/"}) + assert "url" not in errors + errors = Consist.validate({"url": "http://domain.com"}) + assert "url" not in errors + errors = Consist.validate({"url": "domain.com"}) + assert "url" not in errors + errors = Consist.validate({"ip": "foo"}) + assert "ip" in errors + errors = Consist.validate({"emails": "foo"}) + assert "emails" in errors + errors = Consist.validate({"emailsplit": "foo"}) + assert "emailsplit" in errors + errors = Consist.validate({"emailsplit": "1@asd.com, 2@asd.com"}) + assert "emailsplit" not in errors + errors = Consist.validate({"emails": ["1@asd.com", "2@asd.com"]}) + assert "emails" not in errors #: 'len' - len_data = {'a': '12345', 'b': '12345', 'c': '12345', 'd': '12345'} + len_data = {"a": "12345", "b": "12345", "c": "12345", "d": "12345"} errors = Len.validate(len_data) assert not errors d = dict(len_data) - d['a'] = 'ciao' + d["a"] = "ciao" errors = Len.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(len_data) - d['b'] = 'ciao' + d["b"] = "ciao" errors = Len.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(len_data) - d['b'] = '1234567890123' + d["b"] = "1234567890123" errors = Len.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(len_data) - d['c'] = 'ciao' + d["c"] = "ciao" errors = Len.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 d = dict(len_data) - d['c'] = '1234567890123' + d["c"] = "1234567890123" errors = Len.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 d = dict(len_data) - d['d'] = 'ciao' + d["d"] = "ciao" errors = Len.validate(d) - assert 'd' in errors and len(errors) == 1 + assert "d" in errors and len(errors) == 1 d = dict(len_data) - d['d'] = '1234567890123' + d["d"] = "1234567890123" errors = Len.validate(d) - assert 'd' in errors and len(errors) == 1 + assert "d" in errors and len(errors) == 1 #: 'in' - in_data = {'a': 'a', 'b': 2} + in_data = {"a": "a", "b": 2} errors = Inside.validate(in_data) assert not errors d = dict(in_data) - d['a'] = 'c' + d["a"] = "c" errors = Inside.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(in_data) - d['b'] = 0 + d["b"] = 0 errors = Inside.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(in_data) - d['b'] = 7 + d["b"] = 7 errors = Inside.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 #: 'gt', 'lt', 'gte', 'lte' - num_data = {'a': 1, 'b': 4, 'c': 2} + num_data = {"a": 1, "b": 4, "c": 2} errors = Num.validate(num_data) assert not errors d = dict(num_data) - d['a'] = 0 + d["a"] = 0 errors = Num.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(num_data) - d['b'] = 5 + d["b"] = 5 errors = Num.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(num_data) - d['c'] = 0 + d["c"] = 0 errors = Num.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 d = dict(num_data) - d['c'] = 5 + d["c"] = 5 errors = Num.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 #: 'equals' - eq_data = {'a': 'asd', 'b': 2, 'c': 2.3} + eq_data = {"a": "asd", "b": 2, "c": 2.3} errors = Eq.validate(eq_data) assert not errors d = dict(eq_data) - d['a'] = 'lol' + d["a"] = "lol" errors = Eq.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(eq_data) - d['b'] = 3 + d["b"] = 3 errors = Eq.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 #: 'not' d = dict(eq_data) - d['c'] = 2.4 + d["c"] = 2.4 errors = Eq.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 #: 'match' - match_data = {'a': 'abc', 'b': 'ab'} + match_data = {"a": "abc", "b": "ab"} errors = Match.validate(match_data) assert not errors d = dict(match_data) - d['a'] = 'lol' + d["a"] = "lol" errors = Match.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(match_data) - d['b'] = 'abc' + d["b"] = "abc" errors = Match.validate(d) - assert 'b' in errors and len(errors) == 1 - d['b'] = 'lol' + assert "b" in errors and len(errors) == 1 + d["b"] = "lol" errors = Match.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 #: 'any' - errors = Anyone.validate({'a': 'foo'}) + errors = Anyone.validate({"a": "foo"}) assert not errors - errors = Anyone.validate({'a': 'walter@massivedynamics.com'}) + errors = Anyone.validate({"a": "walter@massivedynamics.com"}) assert not errors - errors = Anyone.validate({'a': 'lol'}) - assert 'a' in errors + errors = Anyone.validate({"a": "lol"}) + assert "a" in errors #: 'allow' - allow_data = {'a': 'a', 'b': 'a', 'c': 'a'} + allow_data = {"a": "a", "b": "a", "c": "a"} errors = Allowed.validate(allow_data) assert not errors d = dict(allow_data) - d['a'] = None + d["a"] = None errors = Allowed.validate(d) assert not errors - d['a'] = 'foo' + d["a"] = "foo" errors = Allowed.validate(d) - assert 'a' in errors and len(errors) == 1 + assert "a" in errors and len(errors) == 1 d = dict(allow_data) - d['b'] = '' + d["b"] = "" errors = Allowed.validate(d) assert not errors - d['b'] = None + d["b"] = None errors = Allowed.validate(d) assert not errors - d['b'] = 'foo' + d["b"] = "foo" errors = Allowed.validate(d) - assert 'b' in errors and len(errors) == 1 + assert "b" in errors and len(errors) == 1 d = dict(allow_data) - d['c'] = '' + d["c"] = "" errors = Allowed.validate(d) assert not errors - d['c'] = None + d["c"] = None errors = Allowed.validate(d) assert not errors - d['c'] = 'foo' + d["c"] = "foo" errors = Allowed.validate(d) - assert 'c' in errors and len(errors) == 1 + assert "c" in errors and len(errors) == 1 #: processing validators - assert Proc.a.validate('Asd')[0] == 'asd' - assert Proc.b.validate('Asd')[0] == 'ASD' - assert Proc.d.validate('Two Words')[0] == 'two-words' - psw = str(Proc.e.validate('somepassword')[0]) - assert psw[:23] == 'pbkdf2(1000,20,sha512)$' - assert psw[23:] != 'somepassword' - psw = str(Proc.f.validate('somepassword')[0]) - assert psw[:4] == 'md5$' - assert psw[4:] != 'somepassword' + assert Proc.a.validate("Asd")[0] == "asd" + assert Proc.b.validate("Asd")[0] == "ASD" + assert Proc.d.validate("Two Words")[0] == "two-words" + psw = str(Proc.e.validate("somepassword")[0]) + assert psw[:23] == "pbkdf2(1000,20,sha512)$" + assert psw[23:] != "somepassword" + psw = str(Proc.f.validate("somepassword")[0]) + assert psw[:4] == "md5$" + assert psw[4:] != "somepassword" #: 'presence' - mario = {'name': 'mario'} + mario = {"name": "mario"} errors = Person.validate(mario) - assert 'surname' in errors + assert "surname" in errors assert len(errors) == 1 #: 'presence' with reference, 'unique' - thing = {'name': 'a', 'person': 5, 'color': 'blue', 'uid': 'lol'} + thing = {"name": "a", "person": 5, "color": "blue", "uid": "lol"} errors = Thing.validate(thing) - assert 'person' in errors + assert "person" in errors assert len(errors) == 1 - mario = {'name': 'mario', 'surname': 'draghi'} + mario = {"name": "mario", "surname": "draghi"} mario = Person.create(mario) assert len(mario.errors.keys()) == 0 assert mario.id == 1 - thing = {'name': 'euro', 'person': mario.id, 'color': 'red', 'uid': 'lol'} + thing = {"name": "euro", "person": mario.id, "color": "red", "uid": "lol"} thing = Thing.create(thing) assert len(thing.errors.keys()) == 0 - thing = {'name': 'euro2', 'person': mario.id, 'color': 'red', 'uid': 'lol'} + thing = {"name": "euro2", "person": mario.id, "color": "red", "uid": "lol"} errors = Thing.validate(thing) assert len(errors) == 1 - assert 'uid' in errors + assert "uid" in errors def test_multi(db): p = db.Person(name="mario") base_data = { - 'date': '{0:%d/%m/%Y}'.format(datetime.utcnow()+timedelta(days=1)), - 'type': 'a', - 'inside': 'asd', - 'number': 1, - 'yep': 'asd', - 'psw': 'password', - 'person': p.id + "date": "{0:%d/%m/%Y}".format(datetime.utcnow() + timedelta(days=1)), + "type": "a", + "inside": "asd", + "number": 1, + "yep": "asd", + "psw": "password", + "person": p.id, } #: everything ok res = Mixed.create(base_data) @@ -593,67 +591,67 @@ def test_multi(db): assert len(res.errors.keys()) == 0 #: invalid belongs vals = dict(base_data) - del vals['person'] + del vals["person"] res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'person' in res.errors + assert "person" in res.errors #: invalid date range vals = dict(base_data) - vals['date'] = '{0:%d/%m/%Y}'.format(datetime.utcnow()-timedelta(days=2)) + vals["date"] = "{0:%d/%m/%Y}".format(datetime.utcnow() - timedelta(days=2)) res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'date' in res.errors + assert "date" in res.errors #: invalid date format - vals['date'] = '76-12-1249' + vals["date"] = "76-12-1249" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'date' in res.errors + assert "date" in res.errors #: invalid in vals = dict(base_data) - vals['type'] = ' ' + vals["type"] = " " res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'type' in res.errors + assert "type" in res.errors #: empty number vals = dict(base_data) - vals['number'] = None + vals["number"] = None res = Mixed.create(vals) assert res.id == 2 assert len(res.errors.keys()) == 0 #: invalid number vals = dict(base_data) - vals['number'] = 'asd' + vals["number"] = "asd" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'number' in res.errors + assert "number" in res.errors #: invalid empty vals = dict(base_data) - vals['dont'] = '2' + vals["dont"] = "2" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'dont' in res.errors + assert "dont" in res.errors #: invalid presence vals = dict(base_data) - vals['yep'] = '' + vals["yep"] = "" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'yep' in res.errors + assert "yep" in res.errors #: invalid password vals = dict(base_data) - vals['psw'] = '' + vals["psw"] = "" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'psw' in res.errors - vals['psw'] = 'aksjfalsdkjflkasjdflkajsldkjfalslkdfjaslkdjf' + assert "psw" in res.errors + vals["psw"] = "aksjfalsdkjflkasjdflkajsldkjfalslkdfjaslkdjf" res = Mixed.create(vals) assert res.id is None assert len(res.errors.keys()) == 1 - assert 'psw' in res.errors + assert "psw" in res.errors diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index be7704f2..a600d449 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,36 +1,37 @@ # -*- coding: utf-8 -*- """ - tests.wrappers - -------------- +tests.wrappers +-------------- - Test Emmett wrappers module +Test Emmett wrappers module """ +from emmett_core.protocols.rsgi.test_client.scope import ScopeBuilder from helpers import current_ctx -from emmett.asgi.wrappers import Request -from emmett.testing.env import ScopeBuilder + +from emmett.rsgi.wrappers import Request from emmett.wrappers.response import Response def test_request(): scope, _ = ScopeBuilder( - path='/?foo=bar', - method='GET', + path="/?foo=bar", + method="GET", ).get_data() request = Request(scope, None, None) - assert request.query_params == {'foo': 'bar'} - assert request.client == '127.0.0.1' + assert request.query_params == {"foo": "bar"} + assert request.client == "127.0.0.1" def test_response(): response = Response() assert response.status == 200 - assert response.headers['content-type'] == 'text/plain' + assert response.headers["content-type"] == "text/plain" def test_req_ctx(): - with current_ctx('/?foo=bar') as ctx: + with current_ctx("/?foo=bar") as ctx: assert isinstance(ctx.request, Request) assert isinstance(ctx.response, Response)