From ffab504da086720912bf706923fe611b9ad2008a Mon Sep 17 00:00:00 2001
From: Jacobe2169 <jacques.fize@ecrins-parcnational.fr>
Date: Fri, 8 Dec 2023 15:13:57 +0100
Subject: [PATCH] fix test + change decorator

---
 src/utils_flask_sqla/models.py             | 19 +++++++++++--------
 src/utils_flask_sqla/tests/test_qfilter.py | 15 +++++++++++++--
 2 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/src/utils_flask_sqla/models.py b/src/utils_flask_sqla/models.py
index 6f0faa8..936c9a6 100644
--- a/src/utils_flask_sqla/models.py
+++ b/src/utils_flask_sqla/models.py
@@ -2,6 +2,12 @@
 from flask_sqlalchemy.model import DefaultMeta
 from sqlalchemy.sql import select, Select
 
+AUTHORIZED_WHERECLAUSE_TYPES = [bool, BooleanClauseList, BinaryExpression]
+
+
+def is_whereclause_compatible(object):
+    return any([isinstance(object, type_) for type_ in AUTHORIZED_WHERECLAUSE_TYPES])
+
 
 def qfilter(*args_dec, **kwargs_dec):
     """
@@ -22,7 +28,7 @@ def filter_by_params(cls,**kwargs):
                 filters = []
                 if "id_station" in kwargs:
                     filters.append(Station.id_station == kwargs["id_station"])
-                return query.whereclause
+                return filters
             # If you wish the method to return a query
             @qfilter(query=True)
             def filter_by_paramsQ(cls,**kwargs):
@@ -60,8 +66,8 @@ def filter_by_paramsQ(cls,**kwargs):
         return _qfilter(*args_dec, **kwargs_dec)
 
 
-def _qfilter(*args_dec, **kwargs_dec):
-    is_query = kwargs_dec.get("query", False)
+def _qfilter(query=False):
+    is_query = query
 
     def _qfilter_decorator(method):
         def _(*args, **kwargs):
@@ -83,13 +89,10 @@ def _(*args, **kwargs):
             if is_query and not isinstance(result, Select):
                 raise ValueError("Your method must return a SQLAlchemy Select object ")
 
-            authorise_whereclause_type = [bool, BooleanClauseList, BinaryExpression]
-            if not is_query and not any(
-                [isinstance(result, type_) for type_ in authorise_whereclause_type]
-            ):
+            if not is_query and not is_whereclause_compatible(result):
                 raise ValueError(
                     "Your method must return an object in the following types: {} ".format(
-                        ", ".join(map(lambda cls: cls.__name__, authorise_whereclause_type))
+                        ", ".join(map(lambda cls: cls.__name__, AUTHORIZED_WHERECLAUSE_TYPES))
                     )
                 )
             # if filter is wanted as where clause
diff --git a/src/utils_flask_sqla/tests/test_qfilter.py b/src/utils_flask_sqla/tests/test_qfilter.py
index 97f375c..d2c9612 100644
--- a/src/utils_flask_sqla/tests/test_qfilter.py
+++ b/src/utils_flask_sqla/tests/test_qfilter.py
@@ -1,6 +1,6 @@
 import pytest
 from flask import Flask
-from sqlalchemy import func
+from sqlalchemy import func, and_
 
 from flask_sqlalchemy import SQLAlchemy
 
@@ -26,6 +26,10 @@ def where_pk_query(cls, pk, **kwargs):
         query = kwargs["query"]
         return query.where(BarModel.pk == pk)
 
+    @qfilter
+    def where_pk_list(cls, pk, **kwargs):
+        return and_(*[BarModel.pk == pk])
+
 
 @pytest.fixture(scope="session")
 def app():
@@ -54,7 +58,7 @@ def bar(app):
 
 
 class TestQfilter:
-    def test_qfilter_returns_whereclause(self, bar):
+    def test_qfilter(self, bar):
         assert db.session.scalars(BarModel.where_pk_query(bar.pk)).one_or_none() is bar
         assert (
             db.session.scalars(db.select(BarModel).where(BarModel.where_pk(bar.pk))).one_or_none()
@@ -68,3 +72,10 @@ def test_qfilter_returns_whereclause(self, bar):
             ).one_or_none()
             is not bar
         )
+
+        assert (
+            db.session.scalars(
+                db.select(BarModel).where(BarModel.where_pk_list(bar.pk))
+            ).one_or_none()
+            is bar
+        )