Bugfix + Unify digraph and multidigraph behaviour #46

Merged
139 changes: 110 additions & 29 deletions grandcypher/__init__.py
@@ -83,11 +83,12 @@
 
 
 return_clause : "return"i distinct_return? return_item ("," return_item)*
-return_item : entity_id | aggregation_function | entity_id "." attribute_id
+return_item : (entity_id | aggregation_function | entity_id "." attribute_id) ( "AS"i alias )?
 
 aggregation_function : AGGREGATE_FUNC "(" entity_id ( "." attribute_id )? ")"
 AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN"
 attribute_id : CNAME
+alias : CNAME
 
 distinct_return : "DISTINCT"i
 limit_clause : "limit"i NUMBER
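For context, the new `( "AS"i alias )?` production lets a return item be renamed. A minimal sketch of a query this grammar accepts, assuming the `GrandCypher(host).run(...)` entry point from the project README (the host graph and expected output shape are illustrative):

```python
import networkx as nx
from grandcypher import GrandCypher

# Hypothetical host graph
host = nx.DiGraph()
host.add_node("a", name="Alice")
host.add_node("b", name="Bob")
host.add_edge("a", "b", paid=70)

# The alias rule allows renaming returned columns with AS:
results = GrandCypher(host).run("""
MATCH (n)-[r]->(m)
RETURN n.name AS payer, m.name AS payee
""")
print(results)  # expected shape: {'payer': ['Alice'], 'payee': ['Bob']}
```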
Expand All @@ -97,7 +98,7 @@

order_items : order_item ("," order_item)*

order_item : entity_id order_direction?
order_item : (entity_id | aggregation_function) order_direction?

order_direction : "ASC"i -> asc
| "DESC"i -> desc
@@ -363,7 +364,7 @@ def inner(
 
 
 def _data_path_to_entity_name_attribute(data_path):
-    if not isinstance(data_path, str):
+    if isinstance(data_path, Token):
         data_path = data_path.value
     if "." in data_path:
         entity_name, entity_attribute = data_path.split(".")
@@ -376,7 +377,9 @@ def _data_path_to_entity_name_attribute(data_path):
 
 
 class _GrandCypherTransformer(Transformer):
     def __init__(self, target_graph: nx.Graph, limit=None):
-        self._target_graph = target_graph
+        self._target_graph = nx.MultiDiGraph(target_graph)
Member:
This is a super smart change and simplifies a ton — good thinking! There's probably a ton of business logic we can strip out as a result... thinking out loud, maybe it makes sense to put in a test coverage library to auto-detect those chunks.

Any performance hits you can think of as a result of doing this?

Contributor Author:
> a test coverage library to auto-detect those chunks

That sounds like a great idea! I haven't used many test coverage libraries myself, so I'm open to suggestions :)

Also, w.r.t. the performance hit, I'm unsure about the impact of changing to MultiDiGraph -- at least in practice it appears to be similar. Probably a good idea to benchmark future versions.

Member:
I've been liking codspeed (e.g., aplbrain/grand#48) — maybe a cool thing to extend to this repo someday!
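For readers following the thread: `nx.MultiDiGraph(target_graph)` is plain NetworkX copy-conversion, so every input flavor (Graph, DiGraph, MultiGraph, MultiDiGraph) is normalized up front. A minimal sketch of the behavior being relied on:

```python
import networkx as nx

g = nx.Graph()
g.add_edge("a", "b", weight=1)

# Copy-constructing a MultiDiGraph from an undirected Graph yields a directed
# multigraph with both edge directions, each under parallel-edge key 0.
mg = nx.MultiDiGraph(g)
assert mg.is_directed() and mg.is_multigraph()
print(list(mg.edges(keys=True, data=True)))
# [('a', 'b', 0, {'weight': 1}), ('b', 'a', 0, {'weight': 1})]
```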

+        self._entity2alias = dict()
+        self._alias2entity = dict()
         self._paths = []
         self._where_condition: CONDITION = None
         self._motif = nx.MultiDiGraph()
@@ -385,6 +388,7 @@ def __init__(self, target_graph: nx.Graph, limit=None):
         self._return_requests = []
         self._return_edges = {}
         self._aggregate_functions = []
+        self._aggregation_attributes = set()
         self._distinct = False
         self._order_by = None
         self._order_by_attributes = set()
@@ -491,12 +495,15 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
             ret_with_attr = []
             for r in ret:
                 r_attr = {}
-                for i, v in r.items():
-                    r_attr[(i, list(v.get("__labels__"))[0])] = v.get(
-                        entity_attribute, None
-                    )
-                # eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}]
-                ret_with_attr.append(r_attr)
+                if isinstance(r, dict):
+                    r = [r]
+                for el in r:
+                    for i, v in el.items():
+                        r_attr[(i, list(v.get("__labels__", [i]))[0])] = v.get(
+                            entity_attribute, None
+                        )
+                # eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}]
+                ret_with_attr.append(r_attr)
 
             ret = ret_with_attr
 
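The `isinstance(r, dict)` guard exists because MultiDiGraph edge payloads are one level deeper than DiGraph ones: attributes are keyed by the parallel-edge index, which is where the `(edge_key, label)` tuple keys like `(0, 'paid')` in the comment come from. A small illustration (`__labels__` follows this codebase's edge-label convention):

```python
import networkx as nx

mg = nx.MultiDiGraph()
mg.add_edge("a", "b", __labels__={"paid"}, amount=70)
mg.add_edge("a", "b", __labels__={"paid"}, amount=90)

# Parallel edges are keyed 0, 1, ... so the payload is a dict of dicts:
print(mg.get_edge_data("a", "b"))
# {0: {'__labels__': {'paid'}, 'amount': 70}, 1: {'__labels__': {'paid'}, 'amount': 90}}
```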
@@ -508,31 +515,73 @@ def return_clause(self, clause):
         # collect all entity identifiers to be returned
         for item in clause:
             if item:
+                alias = self._extract_alias(item)
+                item = item.children[0] if isinstance(item, Tree) else item
                 if isinstance(item, Tree) and item.data == "aggregation_function":
-                    func = str(item.children[0].value)  # AGGREGATE_FUNC
-                    entity = str(item.children[1].value)
-                    if len(item.children) > 2:
-                        entity += "." + str(item.children[2].children[0].value)
+                    func, entity = self._parse_aggregation_token(item)
+                    if alias:
+                        self._entity2alias[self._format_aggregation_key(func, entity)] = alias
+                    self._aggregation_attributes.add(entity)
                     self._aggregate_functions.append((func, entity))
                     self._return_requests.append(entity)
                 else:
                     if not isinstance(item, str):
                         item = str(item.value)
 
+                    if alias:
+                        self._entity2alias[item] = alias
                     self._return_requests.append(item)
 
+        self._alias2entity.update({v: k for k, v in self._entity2alias.items()})
+
+    def _extract_alias(self, item: Tree):
+        '''
+        Extract the alias from the return item (if it exists)
+        '''
+
+        if len(item.children) == 1:
+            return None
+        item_keys = [it.data if isinstance(it, Tree) else None for it in item.children]
+        if any(k == 'alias' for k in item_keys):
+            # get the index of the alias
+            alias_index = item_keys.index('alias')
+            return str(item.children[alias_index].children[0].value)
+
+        return None
+
+    def _parse_aggregation_token(self, item: Tree):
+        '''
+        Parse the aggregation function token and return the function and entity
+        input: Tree('aggregation_function', [Token('AGGREGATE_FUNC', 'SUM'), Token('CNAME', 'r'), Tree('attribute_id', [Token('CNAME', 'value')])])
+        output: ('SUM', 'r.value')
+        '''
+        func = str(item.children[0].value)  # AGGREGATE_FUNC
+        entity = str(item.children[1].value)
+        if len(item.children) > 2:
+            entity += "." + str(item.children[2].children[0].value)
+
+        return func, entity
+
+    def _format_aggregation_key(self, func, entity):
+        return f"{func}({entity})"
 
     def order_clause(self, order_clause):
         self._order_by = []
         for item in order_clause[0].children:
-            field = str(item.children[0])  # assuming the field name is the first child
+            if isinstance(item.children[0], Tree) and item.children[0].data == "aggregation_function":
+                func, entity = self._parse_aggregation_token(item.children[0])
+                field = self._format_aggregation_key(func, entity)
+                self._order_by_attributes.add(entity)
+            else:
+                field = str(item.children[0])  # assuming the field name is the first child
+                self._order_by_attributes.add(field)
 
             # Default to 'ASC' if not specified
             if len(item.children) > 1 and str(item.children[1].data).lower() != "desc":
                 direction = "ASC"
             else:
                 direction = "DESC"
 
             self._order_by.append((field, direction))  # [('n.age', 'DESC'), ...]
-            self._order_by_attributes.add(field)
 
     def distinct_return(self, distinct):
         self._distinct = True
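The docstring's example can be replayed by hand to see the new helper pair in action; a standalone sketch constructing the documented input with `lark` types directly:

```python
from lark import Token, Tree

# The exact input shape documented in _parse_aggregation_token's docstring:
item = Tree("aggregation_function", [
    Token("AGGREGATE_FUNC", "SUM"),
    Token("CNAME", "r"),
    Tree("attribute_id", [Token("CNAME", "value")]),
])

func = str(item.children[0].value)    # 'SUM'
entity = str(item.children[1].value)  # 'r'
if len(item.children) > 2:            # attribute present -> dotted path
    entity += "." + str(item.children[2].children[0].value)

print(func, entity)         # SUM r.value
print(f"{func}({entity})")  # SUM(r.value), the _format_aggregation_key form
```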
@@ -616,8 +665,11 @@ def _collate_data(data, unique_labels, func):
 
     def returns(self, ignore_limit=False):
 
+        data_paths = self._return_requests + list(self._order_by_attributes) + list(self._aggregation_attributes)
+        # aliases should already be requested in their original form, so we will remove them for lookup
+        data_paths = [d for d in data_paths if d not in self._alias2entity]
         results = self._lookup(
-            self._return_requests + list(self._order_by_attributes),
+            data_paths,
            offset_limit=slice(0, None),
         )
         if len(self._aggregate_functions) > 0:
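The net effect of the alias bookkeeping in `returns` is a late rename: lookups run under original entity names, keys are then swapped to aliases, and the final filter maps aliases back before consulting `_return_requests`. A toy walk-through with hypothetical values:

```python
# Hypothetical state after parsing `RETURN n.name AS payer`:
entity2alias = {"n.name": "payer"}
alias2entity = {v: k for k, v in entity2alias.items()}
return_requests = ["n.name"]

results = {"n.name": ["Alice", "Bob"]}  # _lookup works on original names

# rename to aliases...
results = {entity2alias.get(k, k): v for k, v in results.items()}
# ...and filter by resolving aliases back to their original entities
results = {k: v for k, v in results.items()
           if alias2entity.get(k, k) in return_requests}
print(results)  # {'payer': ['Alice', 'Bob']}
```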
@@ -630,46 +682,75 @@
             aggregated_results = {}
             for func, entity in self._aggregate_functions:
                 aggregated_data = self.aggregate(func, results, entity, group_keys)
-                func_key = f"{func}({entity})"
+                func_key = self._format_aggregation_key(func, entity)
                 aggregated_results[func_key] = aggregated_data
                 self._return_requests.append(func_key)
             results.update(aggregated_results)
 
+        # update the results with the given alias(es)
+        results = {self._entity2alias.get(k, k): v for k, v in results.items()}
+
         if self._order_by:
             results = self._apply_order_by(results)
         if self._distinct:
             results = self._apply_distinct(results)
         results = self._apply_pagination(results, ignore_limit)
 
-        # Exclude order-by-only attributes from the final results
+        # Only include keys that were asked for in `RETURN` in the final results
         results = {
             key: values
             for key, values in results.items()
-            if key in self._return_requests
+            if self._alias2entity.get(key, key) in self._return_requests
         }
 
         return results
 
     def _apply_order_by(self, results):
         if self._order_by:
             sort_lists = [
-                (results[field], direction)
+                (results[field], field, direction)
                 for field, direction in self._order_by
                 if field in results
             ]
 
             if sort_lists:
                 # Generate a list of indices sorted by the specified fields
                 indices = range(
                     len(next(iter(results.values())))
                 )  # Safe because all lists are assumed to be of the same length
-                for sort_list, direction in reversed(
+                for (sort_list, field, direction) in reversed(
                     sort_lists
                 ):  # reverse to ensure the first sort key is primary
-                    indices = sorted(
-                        indices,
-                        key=lambda i: sort_list[i],
-                        reverse=(direction == "DESC"),
-                    )
-
+                    if all(isinstance(item, dict) for item in sort_list):
+                        # (for edge attributes) If all items in sort_list are dictionaries
+                        # example: ([{(0, 'paid'): 9, (1, 'paid'): 40}, {(0, 'paid'): 14}], 'DESC')
+
+                        # sort within each edge first
+                        sorted_sublists = []
+                        for sublist in sort_list:
+                            sorted_sublist = sorted(
+                                sublist.items(),
+                                key=lambda x: x[1] or 0,  # 0 if `None`
+                                reverse=(direction == "DESC"),
+                            )
+                            sorted_sublists.append({k: v for k, v in sorted_sublist})
+                        sort_list = sorted_sublists
+
+                        # then sort the indices based on the sorted sublists
+                        indices = sorted(
+                            indices,
+                            key=lambda i: list(sort_list[i].values())[0] or 0,  # 0 if `None`
+                            reverse=(direction == "DESC"),
+                        )
+                        # update results with sorted edge attributes list
+                        results[field] = sort_list
+                    else:
+                        # (for node attributes) single values
+                        indices = sorted(
+                            indices,
+                            key=lambda i: sort_list[i],
+                            reverse=(direction == "DESC"),
+                        )
 
                 # Reorder all lists in results using sorted indices
                 for key in results:
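The node-attribute branch of `_apply_order_by` is a stable multi-key argsort: iterate the sort keys in reverse, then reorder every column by the final index permutation. The idiom in isolation, with toy data:

```python
results = {"n.age": [30, 25, 35], "n.name": ["a", "b", "c"]}
order_by = [("n.age", "DESC")]

indices = range(len(next(iter(results.values()))))
for field, direction in reversed(order_by):  # last key first, so the first key wins
    indices = sorted(
        indices,
        key=lambda i: results[field][i],
        reverse=(direction == "DESC"),
    )

# Reorder every column with the same permutation
results = {k: [vals[i] for i in indices] for k, vals in results.items()}
print(results)  # {'n.age': [35, 30, 25], 'n.name': ['c', 'a', 'b']}
```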