From b5e306a241bf5020a518c53db85464b115ebb6f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89lie=20Bouttier?= Date: Wed, 6 Dec 2023 16:21:46 +0100 Subject: [PATCH] SQLAlchemy 1.4 (#46) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update flask to 3.0 * Update SqlAlchemy to 1.4 * Abandon support for debian 10 * Abandon support for python 3.7 * Custom select class to replace query_class (to be removed) --------- Co-authored-by: TheoLechemia Co-authored-by: Élie Bouttier Co-authored-by: Jacques Fize <4259846+jacquesfize@users.noreply.github.com> Co-authored-by: Pierre Narcisi --- .github/workflows/pytest.yml | 20 ++--- requirements.in | 2 +- src/utils_flask_sqla/generic.py | 4 +- src/utils_flask_sqla/models.py | 15 ++++ src/utils_flask_sqla/sqlalchemy.py | 19 ++++ .../tests/test_custom_select.py | 89 +++++++++++++++++++ src/utils_flask_sqla/utils.py | 33 +++++++ 7 files changed, 166 insertions(+), 16 deletions(-) create mode 100644 src/utils_flask_sqla/models.py create mode 100644 src/utils_flask_sqla/sqlalchemy.py create mode 100644 src/utils_flask_sqla/tests/test_custom_select.py diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 9a513e0..f21ae24 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -16,21 +16,15 @@ jobs: strategy: matrix: - python-version: ['3.9'] - sqlalchemy-version: [ '1.3', '1.4' ] + python-version: [ '3.9', '3.11' ] + sqlalchemy-version: [ '1.4' ] include: - - sqlalchemy-version: '1.3' - sqlalchemy-lt-version: '1.4' - flask-sqlalchemy-version: '2.0' - flask-sqlalchemy-lt-version: '3.0' - flask-version: '2' - flask-lt-version: '3' - sqlalchemy-version: '1.4' sqlalchemy-lt-version: '2.0' flask-sqlalchemy-version: '3.0' - flask-sqlalchemy-lt-version: '4.0' - flask-version: '3' - flask-lt-version: '4' + flask-sqlalchemy-lt-version: '3.1' + flask-version: '2.2' + flask-lt-version: '4.0' name: Python ${{ matrix.python-version }} - SQLAlchemy ${{ matrix.sqlalchemy-version }} @@ -48,14 +42,14 @@ jobs: python -m pip install -e .[tests] pytest-cov \ "sqlalchemy>=${{ matrix.sqlalchemy-version }},<${{ matrix.sqlalchemy-lt-version }}" \ "flask-sqlalchemy>=${{ matrix.flask-sqlalchemy-version }},<${{ matrix.flask-sqlalchemy-lt-version }}" \ - "flask>=${{ matrix.flask-version }},<${{ matrix.flask-lt-version }}" + "flask>=${{ matrix.flask-version }},<${{ matrix.flask-lt-version }}" - name: Test with pytest run: | pytest -v --cov --cov-report xml - name: Upload coverage to Codecov - if: ${{ matrix.python-version == '3.9' && matrix.sqlalchemy-version == '1.4'}} + if: ${{ matrix.python-version == '3.11' && matrix.sqlalchemy-version == '1.4'}} uses: codecov/codecov-action@v3 with: flags: pytest diff --git a/requirements.in b/requirements.in index f53221a..0e5e934 100644 --- a/requirements.in +++ b/requirements.in @@ -3,4 +3,4 @@ flask-sqlalchemy flask-migrate marshmallow python-dateutil -sqlalchemy>=1.3,<2 +sqlalchemy<2 diff --git a/src/utils_flask_sqla/generic.py b/src/utils_flask_sqla/generic.py index b29985a..c5bbeda 100644 --- a/src/utils_flask_sqla/generic.py +++ b/src/utils_flask_sqla/generic.py @@ -101,8 +101,8 @@ def __init__(self, tableName, schemaName, engine): - engine : sqlalchemy instance engine for exemple : DB.engine if DB = Sqlalchemy() """ - meta = MetaData(schema=schemaName, bind=engine) - meta.reflect(views=True) + meta = MetaData(schema=schemaName) + meta.reflect(views=True, bind=engine) try: self.tableDef = meta.tables["{}.{}".format(schemaName, tableName)] diff --git a/src/utils_flask_sqla/models.py b/src/utils_flask_sqla/models.py new file mode 100644 index 0000000..8f6b6a8 --- /dev/null +++ b/src/utils_flask_sqla/models.py @@ -0,0 +1,15 @@ +from .sqlalchemy import CustomSelect +from flask_sqlalchemy.model import Model + + +class SelectModel(Model): + __abstract__ = True + + @classmethod + @property + def select(cls): + if hasattr(cls, "__select_class__"): + select_cls = cls.__select_class__ + else: + select_cls = CustomSelect + return select_cls._create_future_select(cls) # SQLA 2.0: _create_future_select → _create diff --git a/src/utils_flask_sqla/sqlalchemy.py b/src/utils_flask_sqla/sqlalchemy.py new file mode 100644 index 0000000..d44f817 --- /dev/null +++ b/src/utils_flask_sqla/sqlalchemy.py @@ -0,0 +1,19 @@ +from flask_sqlalchemy import SQLAlchemy +from sqlalchemy.util.langhelpers import public_factory +from sqlalchemy.sql.expression import Select + + +class CustomSelect(Select): + inherit_cache = True + + def where_if(self, condition_to_execute_where, whereclause): + if condition_to_execute_where: + return self.where(whereclause) + else: + return self + + +class CustomSQLAlchemy(SQLAlchemy): + @staticmethod + def select(*entities): + return CustomSelect._create_future_select(*entities) diff --git a/src/utils_flask_sqla/tests/test_custom_select.py b/src/utils_flask_sqla/tests/test_custom_select.py new file mode 100644 index 0000000..2906348 --- /dev/null +++ b/src/utils_flask_sqla/tests/test_custom_select.py @@ -0,0 +1,89 @@ +import pytest +from flask import Flask +from sqlalchemy import func + +from utils_flask_sqla.sqlalchemy import CustomSQLAlchemy, CustomSelect +from utils_flask_sqla.models import SelectModel + + +db = CustomSQLAlchemy(model_class=SelectModel) + + +class FooModel(db.Model): + pk = db.Column(db.Integer, primary_key=True) + + +class BarSelect(CustomSelect): + inherit_cache = True + + def where_pk(self, pk): + return self.where(BarModel.pk == pk) + + +class BarModel(db.Model): + __select_class__ = BarSelect + + pk = db.Column(db.Integer, primary_key=True) + + +@pytest.fixture(scope="session") +def app(): + app = Flask("utils-flask-sqla") + app.config["SQLALCHEMY_DATABASE_URI"] = f"sqlite:///" + db.init_app(app) + with app.app_context(): + db.create_all() + yield app + + +@pytest.fixture(scope="session") +def foo(app): + foo = FooModel() + db.session.add(foo) + db.session.commit() + return foo + + +@pytest.fixture(scope="session") +def bar(app): + bar = BarModel() + db.session.add(bar) + db.session.commit() + return bar + + +class TestCustomSelect: + def test_select_where_if(self, foo): + # Filter does not apply, we get foo + assert ( + foo + in db.session.scalars(db.select(FooModel).where_if(False, FooModel.pk != foo.pk)).all() + ) + # Filter apply, we does not get foo + assert ( + foo + not in db.session.scalars( + db.select(FooModel).where_if(True, FooModel.pk != foo.pk) + ).all() + ) + + def test_model_where_if(self, foo): + # Filter does not apply, we get foo + assert ( + foo in db.session.scalars(FooModel.select.where_if(False, FooModel.pk != foo.pk)).all() + ) + # Filter apply, we does not get foo + assert ( + foo + not in db.session.scalars(FooModel.select.where_if(True, FooModel.pk != foo.pk)).all() + ) + + def test_model_select_class(self, bar): + assert db.session.scalars(BarModel.select.where_pk(bar.pk)).one_or_none() is bar + assert db.session.scalars(BarModel.select.where_pk(bar.pk + 1)).one_or_none() is not bar + + def test_chain_custom_where(self, bar): + assert ( + db.session.scalars(BarModel.select.where_pk(bar.pk).where_pk(bar.pk)).one_or_none() + is bar + ) diff --git a/src/utils_flask_sqla/utils.py b/src/utils_flask_sqla/utils.py index 937207d..347b35d 100644 --- a/src/utils_flask_sqla/utils.py +++ b/src/utils_flask_sqla/utils.py @@ -5,6 +5,8 @@ from tempfile import TemporaryDirectory from shutil import copyfileobj from urllib.request import urlopen +from contextlib import suppress +from sqlalchemy.sql import visitors class remote_file(ExitStack): @@ -32,3 +34,34 @@ def __enter__(self): with urlopen(self.url) as response, open(remote_file_path, "wb") as remote_file: copyfileobj(response, remote_file) return remote_file_path + + +def is_already_joined(my_class, query): + """ + Check if the given class is already present is the current query + _class: SQLAlchemy class + query: SQLAlchemy query + return boolean + """ + for visitor in visitors.iterate(query.statement): + # Checking for `.join(Parent.child)` clauses + if visitor.__visit_name__ == "binary": + for vis in visitors.iterate(visitor): + # Visitor might not have table attribute + with suppress(AttributeError): + # Verify if already present based on table name + if my_class.__table__.fullname == vis.table.fullname: + return True + # Checking for `.join(Child)` clauses + if visitor.__visit_name__ == "table": + # Visitor might be of ColumnCollection or so, + # which cannot be compared to model + with suppress(TypeError): + if my_class == visitor.entity_namespace: + return True + # Checking for `Model.column` clauses + if visitor.__visit_name__ == "column": + with suppress(AttributeError): + if my_class.__table__.fullname == visitor.table.fullname: + return True + return False