diff --git a/flask_rest_jsonapi/data_layers/alchemy.py b/flask_rest_jsonapi/data_layers/alchemy.py index f08b2ec..2cf231b 100644 --- a/flask_rest_jsonapi/data_layers/alchemy.py +++ b/flask_rest_jsonapi/data_layers/alchemy.py @@ -452,7 +452,12 @@ def sort_query(self, query, sort_info): field = sort_opt['field'] if not hasattr(self.model, field): raise InvalidSort("{} has no attribute {}".format(self.model.__name__, field)) - query = query.order_by(getattr(getattr(self.model, field), sort_opt['order'])()) + if sort_opt['relationship']: + relationField = getattr(self.model, field) + relationClass = relationField.mapper.class_ + query = query.join(relationField).order_by(getattr(getattr(relationClass, 'id'), sort_opt['order'])()) + else: + query = query.order_by(getattr(getattr(self.model, field), sort_opt['order'])()) return query def paginate_query(self, query, paginate_info): diff --git a/flask_rest_jsonapi/querystring.py b/flask_rest_jsonapi/querystring.py index 3bfc2b5..6bebbaa 100644 --- a/flask_rest_jsonapi/querystring.py +++ b/flask_rest_jsonapi/querystring.py @@ -165,14 +165,16 @@ def sorting(self): if self.qs.get('sort'): sorting_results = [] for sort_field in self.qs['sort'].split(','): + relationship = False field = sort_field.replace('-', '') if field not in self.schema._declared_fields: raise InvalidSort("{} has no attribute {}".format(self.schema.__name__, field)) if field in get_relationships(self.schema): - raise InvalidSort("You can't sort on {} because it is a relationship field".format(field)) + #raise InvalidSort("You can't sort on {} because it is a relationship field".format(field)) + relationship = True field = get_model_field(self.schema, field) order = 'desc' if sort_field.startswith('-') else 'asc' - sorting_results.append({'field': field, 'order': order}) + sorting_results.append({'field': field, 'order': order, 'relationship': relationship}) return sorting_results return [] diff --git a/tests/test_sqlalchemy_data_layer.py b/tests/test_sqlalchemy_data_layer.py index d5f62a3..11b9202 100644 --- a/tests/test_sqlalchemy_data_layer.py +++ b/tests/test_sqlalchemy_data_layer.py @@ -389,9 +389,9 @@ def test_query_string_manager(person_schema): qsm = QSManager(query_string, person_schema) with pytest.raises(BadRequest): qsm.pagination - qsm.qs['sort'] = 'computers' - with pytest.raises(InvalidSort): - qsm.sorting + #qsm.qs['sort'] = 'computers' + #with pytest.raises(InvalidSort): + #qsm.sorting def test_resource(app, person_model, person_schema, session, monkeypatch): @@ -474,6 +474,46 @@ def test_get_list(client, register_routes, person, person_2): response = client.get('/persons' + '?' + querystring, content_type='application/vnd.api+json') assert response.status_code == 200 +def test_get_list_sort_relationship(client, register_routes, person, person_2): + with client: + querystring = urlencode({'page[number]': 1, + 'page[size]': 1, + 'fields[person]': 'name,birth_date', + 'sort': '-computers', + 'include': 'computers.owner', + 'filter': json.dumps( + [ + { + 'and': [ + { + 'name': 'computers', + 'op': 'any', + 'val': { + 'name': 'serial', + 'op': 'eq', + 'val': '0000' + } + }, + { + 'or': [ + { + 'name': 'name', + 'op': 'like', + 'val': '%test%' + }, + { + 'name': 'name', + 'op': 'like', + 'val': '%test2%' + } + ] + } + ] + } + ])}) + response = client.get('/persons' + '?' + querystring, content_type='application/vnd.api+json') + assert response.status_code == 200 + def test_get_list_disable_pagination(client, register_routes): with client: