Skip to content

Commit

Permalink
Merge pull request #51 from cloudblue/feature/LITE-24359
Browse files Browse the repository at this point in the history
LITE-24359 Adapted library for integration of various django extensions
  • Loading branch information
maxipavlovic authored Jul 13, 2022
2 parents fd707b9 + 94c7cc7 commit 42d1969
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 60 deletions.
29 changes: 15 additions & 14 deletions dj_rql/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,25 @@


class FilterTypes(FT):
mapper = [
(models.AutoField, FT.INT),
(models.BooleanField, FT.BOOLEAN),
(models.NullBooleanField, FT.BOOLEAN),
(models.DateTimeField, FT.DATETIME),
(models.DateField, FT.DATE),
(models.DecimalField, FT.DECIMAL),
(models.FloatField, FT.FLOAT),
(models.IntegerField, FT.INT),
(models.TextField, FT.STRING),
(models.UUIDField, FT.STRING),
(models.CharField, FT.STRING),
]

@classmethod
def field_filter_type(cls, field):
mapper = [
(models.AutoField, cls.INT),
(models.BooleanField, cls.BOOLEAN),
(models.NullBooleanField, cls.BOOLEAN),
(models.DateTimeField, cls.DATETIME),
(models.DateField, cls.DATE),
(models.DecimalField, cls.DECIMAL),
(models.FloatField, cls.FLOAT),
(models.IntegerField, cls.INT),
(models.TextField, cls.STRING),
(models.UUIDField, cls.STRING),
(models.CharField, cls.STRING),
]
return next(
(
filter_type for base_cls, filter_type in mapper
filter_type for base_cls, filter_type in cls.mapper
if issubclass(field.__class__, base_cls)
),
cls.STRING,
Expand Down
95 changes: 64 additions & 31 deletions dj_rql/filter_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class RQLFilterClass:
QUERIES_CACHE_SIZE = 20
"""Default number of cached queries."""

Q_CLS = Q
"""Class for building nodes of the query, generated by django."""

FILTER_TYPES_CLS = FilterTypes
"""Class for the mapping of model field types to filter types."""

def __init__(self, queryset, instance=None):
self.queryset = queryset
self._is_distinct = self.DISTINCT
Expand All @@ -82,9 +88,13 @@ def __init__(self, queryset, instance=None):
self._validate_init()
self._default_init(self._get_init_filters())

@classmethod
def _is_valid_model_cls(cls, model):
return issubclass(model, Model)

def _validate_init(self):
e = 'Django model must be set for Filter Class.'
assert self.MODEL and issubclass(self.MODEL, Model), e
assert self.MODEL and self._is_valid_model_cls(self.MODEL), e

e = 'Wrong filter settings type for Filter Class.'
assert (self.FILTERS is None) or isinstance(self.FILTERS, iterable_types), e
Expand Down Expand Up @@ -258,7 +268,7 @@ def build_q_for_filter(self, data):

base_item = self.get_filter_base_item(filter_name)
if not base_item:
return Q()
return self.Q_CLS()

if base_item.get('distinct'):
self._is_distinct = True
Expand Down Expand Up @@ -311,7 +321,7 @@ def build_q_for_filter(self, data):
return self._build_django_q(filter_item, django_lookup, filter_lookup, typed_value)

# filter has different DB field 'sources'
q = Q()
q = self.Q_CLS()
for item in filter_item:
item_q = self._build_django_q(item, django_lookup, filter_lookup, typed_value)
if filter_lookup == FilterLookups.NE:
Expand Down Expand Up @@ -432,7 +442,7 @@ def _build_q_for_search(self, operator, str_value):

unquoted_value = self.remove_quotes(str_value)
if not unquoted_value:
return Q()
return self.Q_CLS()

if not unquoted_value.startswith(RQL_ANY_SYMBOL):
unquoted_value = '*' + unquoted_value
Expand All @@ -449,7 +459,7 @@ def _build_q_for_search(self, operator, str_value):
return q

def _build_q_for_extended_search(self, str_value):
q = Q()
q = self.Q_CLS()
extended_search_filter_lookup = FilterLookups.I_LIKE

for django_orm_route in self.EXTENDED_SEARCH_ORM_ROUTES:
Expand Down Expand Up @@ -591,9 +601,9 @@ def _build_filters(self, filters, **kwargs):
orm_field_name = item.get('source', namespace)
related_orm_route = '{0}{1}__'.format(orm_route, orm_field_name)

related_model = self._get_field(
related_model = self._get_field_related_model(self._get_field(
_model, orm_field_name, get_related=True,
).related_model
))

qs = item.get('qs')
tree, p_qs = self._fill_select_tree(
Expand Down Expand Up @@ -735,6 +745,14 @@ def _extend_annotations(self):

self.annotations.update(dict(extended_annotations))

@classmethod
def _is_field_supported(cls, field):
return isinstance(field, SUPPORTED_FIELD_TYPES)

@classmethod
def _get_field_related_model(cls, field):
return field.related_model

@classmethod
def _get_field(cls, base_model, field_name, get_related=False):
""" Django ORM field getter.
Expand All @@ -750,10 +768,10 @@ def _get_field(cls, base_model, field_name, get_related=False):
current_field = cls._get_model_field(current_model, part)
if index == field_name_parts_length:
e = 'Unsupported field type: {0}.'.format(field_name)
assert get_related or isinstance(current_field, SUPPORTED_FIELD_TYPES), e
assert get_related or cls._is_field_supported(current_field), e

return current_field
current_model = current_field.related_model
current_model = cls._get_field_related_model(current_field)

@staticmethod
def _get_field_name_parts(field_name):
Expand All @@ -762,6 +780,10 @@ def _get_field_name_parts(field_name):

return field_name.split('.' if '.' in field_name else '__')

@classmethod
def _is_field_nullable(cls, field):
return field.null or cls._is_pk_field(field)

@classmethod
def _build_mapped_item(cls, field, field_orm_route, **kwargs):
lookups = kwargs.get('lookups')
Expand All @@ -771,8 +793,8 @@ def _build_mapped_item(cls, field, field_orm_route, **kwargs):
openapi = kwargs.get('openapi')
hidden = kwargs.get('hidden')

possible_lookups = lookups or FilterTypes.default_field_filter_lookups(field)
if not (field.null or cls._is_pk_field(field)):
possible_lookups = lookups or cls.FILTER_TYPES_CLS.default_field_filter_lookups(field)
if not cls._is_field_nullable(field):
possible_lookups.discard(FilterLookups.NULL)

result = {
Expand Down Expand Up @@ -924,40 +946,50 @@ def _escape_regex_special_symbols(str_value):

@classmethod
def _convert_value(cls, django_field, str_value, use_repr=False):
ft_cls = cls.FILTER_TYPES_CLS
val = cls.remove_quotes(str_value)
filter_type = FilterTypes.field_filter_type(django_field)
filter_type = ft_cls.field_filter_type(django_field)

if filter_type == FilterTypes.FLOAT:
if filter_type == ft_cls.FLOAT:
return float(val)

elif filter_type == FilterTypes.DECIMAL:
if '.' in val:
integer_part, fractional_part = val.split('.', 1)
val = integer_part + '.' + fractional_part[:django_field.decimal_places]
return decimal.Decimal(val)
elif filter_type == ft_cls.DECIMAL:
return cls._convert_decimal_value(val, django_field)

elif filter_type == FilterTypes.DATE:
elif filter_type == ft_cls.DATE:
return cls._convert_date_value(val)

elif filter_type == FilterTypes.DATETIME:
elif filter_type == ft_cls.DATETIME:
return cls._convert_datetime_value(val)

elif filter_type == FilterTypes.BOOLEAN:
elif filter_type == ft_cls.BOOLEAN:
return cls._convert_boolean_value(val)

if val == RQL_EMPTY:
if (filter_type == FilterTypes.INT) or (not django_field.blank):
if (filter_type == ft_cls.INT) or (not django_field.blank):
raise ValueError
return ''

choices = getattr(django_field, 'choices', None)
if not choices:
if filter_type == FilterTypes.INT:
if filter_type == ft_cls.INT:
return int(val)
return val

return cls._get_choices_field_db_value(val, choices, filter_type, use_repr)

@classmethod
def _convert_decimal_value(cls, value, field):
if '.' in value:
integer_part, fractional_part = value.split('.', 1)
value = integer_part + '.' + fractional_part[:cls._get_decimal_field_precision(field)]

return decimal.Decimal(value)

@classmethod
def _get_decimal_field_precision(cls, field):
return field.decimal_places

@staticmethod
def _convert_date_value(value):
dt = parse_date(value)
Expand Down Expand Up @@ -1002,8 +1034,8 @@ def _get_choices_field_db_value(cls, value, choices, filter_type, use_repr):
except StopIteration:
raise ValueError

@staticmethod
def _get_choice_class_db_value(value, choices, filter_type, use_repr):
@classmethod
def _get_choice_class_db_value(cls, value, choices, filter_type, use_repr):
if use_repr:
try:
db_value = next(
Expand All @@ -1013,7 +1045,7 @@ def _get_choice_class_db_value(value, choices, filter_type, use_repr):
except StopIteration:
raise ValueError

if filter_type == FilterTypes.INT:
if filter_type == cls.FILTER_TYPES_CLS.INT:
db_value = int(value)
else:
db_value = value
Expand All @@ -1024,8 +1056,8 @@ def _get_choice_class_db_value(value, choices, filter_type, use_repr):
return db_value

def _build_django_q(self, filter_item, django_lookup, filter_lookup, typed_value):
kwargs = {'{0}__{1}'.format(filter_item['orm_route'], django_lookup): typed_value}
return ~Q(**kwargs) if filter_lookup == FilterLookups.NE else Q(**kwargs)
q = self.Q_CLS(**{'{0}__{1}'.format(filter_item['orm_route'], django_lookup): typed_value})
return ~q if filter_lookup == FilterLookups.NE else q

@staticmethod
def _get_filter_lookup_by_operator(grammar_operator):
Expand Down Expand Up @@ -1082,9 +1114,10 @@ def _check_dynamic(filter_item, filter_name, filter_route):
e = "{0}: common filters can't have 'field' set.".format(filter_name)
assert not filter_item.get('custom', False) and field is None, e

@staticmethod
def _check_search(filter_item, filter_name, field):
is_non_string_field_type = FilterTypes.field_filter_type(field) != FilterTypes.STRING
@classmethod
def _check_search(cls, filter_item, filter_name, field):
ft_cls = cls.FILTER_TYPES_CLS
is_non_string_field_type = ft_cls.field_filter_type(field) != ft_cls.STRING

e = "{0}: 'search' can be applied only to text filters.".format(filter_name)
assert not (filter_item.get('search') and is_non_string_field_type), e
Expand Down
33 changes: 19 additions & 14 deletions dj_rql/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from dj_rql._dataclasses import FilterArgs

from django.db.models import Q

from lark import Tree

from py_rql.constants import (
Expand Down Expand Up @@ -44,6 +42,10 @@ def __init__(self, filter_cls_instance):

self.__visit_tokens__ = False

@property
def _q(self):
return self._filter_cls_instance.Q_CLS

def _push_namespace(self, tree):
if tree.data in self.NAMESPACE_PROVIDERS:
self._namespace.append(None)
Expand All @@ -69,12 +71,11 @@ def _transform_tree(self, tree):
self._pop_namespace(tree)
return ret_value

@staticmethod
def _get_value(obj):
def _get_value(self, obj):
while isinstance(obj, Tree):
obj = obj.children[0]

if isinstance(obj, Q):
if isinstance(obj, self._q):
return obj

return obj.value
Expand All @@ -95,7 +96,7 @@ def start(self, args):
def comp(self, args):
prop, operation, value = self._extract_comparison(args)

if isinstance(value, Q):
if isinstance(value, self._q):
if operation == ComparisonOperators.EQ:
return value
else:
Expand All @@ -106,17 +107,21 @@ def comp(self, args):
return self._filter_cls_instance.build_q_for_filter(filter_args)

def tuple(self, args):
return Q(*args)
return self._q(*args)

def logical(self, args):
operation = args[0].data
children = args[0].children
if operation == LogicalOperators.get_grammar_key(LogicalOperators.NOT):
return ~Q(children[0])
return ~children[0]

q = self._q()
if operation == LogicalOperators.get_grammar_key(LogicalOperators.AND):
return Q(*children)
for child in children:
q &= child

return q

q = Q()
for child in children:
q |= child

Expand All @@ -127,10 +132,10 @@ def listing(self, args):
operation, prop = self._get_value(args[0]), self._get_value(args[1])
f_op = ComparisonOperators.EQ if operation == ListOperators.IN else ComparisonOperators.NE

q = Q()
q = self._q()
for value_tree in args[2:]:
value = self._get_value(value_tree)
if isinstance(value, Q):
if isinstance(value, self._q):
if f_op == ComparisonOperators.EQ:
field_q = value
else:
Expand Down Expand Up @@ -164,7 +169,7 @@ def ordering(self, args):
for prop in props:
self._filtered_props.add(prop.replace('-', '').replace('+', ''))

return Q()
return self._q()

def select(self, args):
assert not self._select
Expand All @@ -177,7 +182,7 @@ def select(self, args):
if not prop.startswith('-'):
self._filtered_props.add(prop.replace('+', ''))

return Q()
return self._q()


class RQLLimitOffsetTransformer(BaseRQLTransformer):
Expand Down
2 changes: 1 addition & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
lib-rql==1.1.2
lib-rql>=1.1.3,<2
Django>=2.2.19

0 comments on commit 42d1969

Please sign in to comment.