diff --git a/hiku/validate/query.py b/hiku/validate/query.py index e2a0c586..c00dd077 100644 --- a/hiku/validate/query.py +++ b/hiku/validate/query.py @@ -533,3 +533,53 @@ def validate(graph: Graph, query: QueryNode) -> t.List[str]: query_validator = QueryValidator(graph) query_validator.visit(query) return query_validator.errors.list + + +# TODO: add tests +class QueryComplexityValidator(QueryVisitor): + def __init__(self): + self.complexity = 0 + + def validate(self, obj: QueryNode) -> int: + self.visit(obj) + # TODO: should we return complexity or accept max_complexity in init and raise/return error ? + return self.complexity + + def visit_field(self, obj: QueryField) -> None: + # TODO: consider options presece + self.complexity += 1 + + def visit_link(self, obj: QueryLink) -> None: + # TODO do we need to inclide link ? What if complex field - what complexity for records ? + self.complexity += 1 + super().visit_link(obj) + + +# TODO: add tests +class QueryDepthValidator(QueryVisitor): + def __init__(self): + self.current_depth = 0 + self.max_depth = 0 + + def validate(self, obj: QueryNode) -> int: + self.visit(obj) + # TODO: should we return complexity or accept max_complexity in init and raise/return error ? + return self.max_depth + + def visit_field(self, obj: QueryField) -> None: + pass # Do nothing for individual fields. + + def visit_link(self, obj: QueryLink) -> None: + # TODO: either calculate query depth in Link or in Node, but not in both + self.current_depth += 1 + super().visit_link(obj) + self.current_depth -= 1 + + def visit_node(self, obj: QueryNode) -> None: + self.current_depth += 1 + self.max_depth = max(self.max_depth, self.current_depth) + super().visit_node(obj) + self.current_depth -= 1 + + def visit_fragment(self, obj: Fragment) -> None: + super().visit_fragment(obj) \ No newline at end of file diff --git a/tests/test_validate_query.py b/tests/test_validate_query.py index d13058a5..55a914f3 100644 --- a/tests/test_validate_query.py +++ b/tests/test_validate_query.py @@ -4,7 +4,7 @@ from hiku.graph import Graph, Node, Field, Link, Option, Root from hiku.types import Integer, Record, Sequence, Optional, TypeRef, Boolean from hiku.types import String, Mapping, Any -from hiku.validate.query import validate +from hiku.validate.query import QueryComplexityValidator, QueryDepthValidator, validate def _(): @@ -796,3 +796,38 @@ def test_any_in_option(): 'Invalid value for option "root.get:foo", ' '"str" instead of Mapping[String, Any]' ] + + +# TODO test complex field +# TODO test with options +# TODO: test with records +# TODO: test with fragments +# TODO: research about query complexity techniques +def test_query_complexity_validator(): + validator = QueryComplexityValidator() + + query = q.Node([ + q.Field("a"), + q.Field("b"), + q.Link("c", q.Node([ + q.Field("d"), + q.Field("e"), + ])), + ]) + + assert validator.validate(query) == 5 + + +def test_query_depth_validator(): + validator = QueryDepthValidator() + + query = q.Node([ + q.Field("a"), + q.Field("b"), + q.Link("c", q.Node([ + q.Field("d"), + q.Field("e"), + ])), + ]) + + assert validator.validate(query) == 3