Skip to content

Commit

Permalink
Even more features for G.visualize
Browse files Browse the repository at this point in the history
  • Loading branch information
adamnsch committed Oct 18, 2024
1 parent 41f5f1f commit b01e44c
Showing 1 changed file with 90 additions and 29 deletions.
119 changes: 90 additions & 29 deletions graphdatascience/graph/graph_object.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from itertools import chain

import colorsys
import random
Expand Down Expand Up @@ -82,7 +83,6 @@ def node_count(self) -> int:
"""
Returns:
the number of nodes in the graph
"""
return self._graph_info(["nodeCount"]) # type: ignore

Expand Down Expand Up @@ -191,7 +191,6 @@ def drop(self, failIfMissing: bool = False) -> "Series[str]":
Returns:
the result of the drop operation
"""
result = self._query_runner.call_procedure(
endpoint="gds.graph.drop",
Expand All @@ -205,7 +204,6 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
"""
Returns:
the creation time of the graph
"""
return self._graph_info(["creationTime"])

Expand Down Expand Up @@ -236,12 +234,56 @@ def __repr__(self) -> str:

def visualize(
self,
notebook: bool = True,
node_count: int = 100,
directed: bool = True,
center_nodes: Optional[List[int]] = None,
include_node_properties: List[str] = None,
color_property: Optional[str] = None,
size_property: Optional[str] = None,
include_node_properties: Optional[List[str]] = None,
rel_weight_property: Optional[str] = None,
notebook: bool = True,
px_height: int = 750,
theme: str = "dark",
) -> Any:
"""
Visualize the `Graph` in an interactive graphical interface.
The graph will be sampled down to specified `node_count` to limit computationally expensive rendering.
Args:
node_count: number of nodes in the graph to be visualized
directed: whether or not to display relationships as directed
center_nodes: nodes around subgraph will be sampled, if sampling is necessary
color_property: node property that determines node categories for coloring. Default is to use node labels
size_property: node property that determines the size of nodes. Default is to compute a page rank for this
include_node_properties: node properties to include for mouse-over inspection
rel_weight_property: relationship property that determines width of relationships
notebook: whether or not the code is run in a notebook
px_height: the height of the graphic containing output the visualization
theme: coloring theme for the visualization. "light" or "dark"
Returns:
an interactive graphical visualization of the specified graph
"""

actual_node_properties = list(chain.from_iterable(self.node_properties().to_dict().values()))
if (color_property is not None) and (color_property not in actual_node_properties):
raise ValueError(f"There is no node property '{color_property}' in graph '{self._name}'")

if size_property is not None and size_property not in actual_node_properties:
raise ValueError(f"There is no node property '{size_property}' in graph '{self._name}'")

if include_node_properties is not None:
for prop in include_node_properties:
if prop not in actual_node_properties:
raise ValueError(f"There is no node property '{prop}' in graph '{self._name}'")

actual_rel_properties = list(chain.from_iterable(self.relationship_properties().to_dict().values()))
if rel_weight_property is not None and rel_weight_property not in actual_rel_properties:
raise ValueError(f"There is no relationship property '{rel_weight_property}' in graph '{self._name}'")

if theme not in {"light", "dark"}:
raise ValueError(f"Color `theme` '{theme}' is not allowed. Must be either 'light' or 'dark'")

visual_graph = self._name
if self.node_count() > node_count:
visual_graph = str(uuid4())
Expand All @@ -256,14 +298,19 @@ def visualize(
custom_error=False,
)

pr_prop = str(uuid4())
self._query_runner.call_procedure(
endpoint="gds.pageRank.mutate",
params=CallParameters(graph_name=visual_graph, config=dict(mutateProperty=pr_prop)),
custom_error=False,
)
# Make sure we always have at least a size property so that we can run `gds.graph.nodeProperties.stream`
if size_property is None:
size_property = str(uuid4())
self._query_runner.call_procedure(
endpoint="gds.pageRank.mutate",
params=CallParameters(graph_name=visual_graph, config=dict(mutateProperty=size_property)),
custom_error=False,
)
clean_up_size_prop = True
else:
clean_up_size_prop = False

node_properties = [pr_prop]
node_properties = [size_property]
if include_node_properties is not None:
node_properties.extend(include_node_properties)

Expand Down Expand Up @@ -295,11 +342,18 @@ def visualize(
result.columns.name = None
node_properties_df = result

relationships_df = self._query_runner.call_procedure(
endpoint="gds.graph.relationships.stream",
params=CallParameters(graph_name=visual_graph),
custom_error=False,
)
if rel_weight_property is None:
relationships_df = self._query_runner.call_procedure(
endpoint="gds.graph.relationships.stream",
params=CallParameters(graph_name=visual_graph),
custom_error=False,
)
else:
relationships_df = self._query_runner.call_procedure(
endpoint="gds.graph.relationshipProperty.stream",
params=CallParameters(graph_name=visual_graph, properties=rel_weight_property),
custom_error=False,
)

# Clean up
if visual_graph != self._name:
Expand All @@ -308,10 +362,10 @@ def visualize(
params=CallParameters(graph_name=visual_graph),
custom_error=False,
)
else:
elif clean_up_size_prop:
self._query_runner.call_procedure(
endpoint="gds.graph.nodeProperties.drop",
params=CallParameters(graph_name=visual_graph, nodeProperties=pr_prop),
params=CallParameters(graph_name=visual_graph, nodeProperties=size_property),
custom_error=False,
)

Expand All @@ -320,19 +374,21 @@ def visualize(
net = Network(
notebook=True if notebook else False,
cdn_resources="remote" if notebook else "local",
bgcolor="#222222", # Dark background
font_color="white",
height="750px", # Modify according to your screen size
directed=directed,
bgcolor="#222222" if theme == "dark" else "#F2F2F2",
font_color="white" if theme == "dark" else "black",
height=f"{px_height}px",
width="100%",
)

if color_property is None:
color_map = {label: self._random_bright_color() for label in self.node_labels()}
color_map = {label: self._random_themed_color(theme) for label in self.node_labels()}
else:
color_map = {
prop_val: self._random_bright_color() for prop_val in node_properties_df[color_property].unique()
prop_val: self._random_themed_color(theme) for prop_val in node_properties_df[color_property].unique()
}

# Add all the nodes
for _, node in node_properties_df.iterrows():
title = f"Node ID: {node['nodeId']}\nLabels: {node['nodeLabels']}"
if include_node_properties is not None:
Expand All @@ -347,17 +403,22 @@ def visualize(

net.add_node(
int(node["nodeId"]),
value=node[pr_prop],
value=node[size_property],
color=color,
title=title,
)

# Add all the relationships
net.add_edges(zip(relationships_df["sourceNodeId"], relationships_df["targetNodeId"]))
for _, rel in relationships_df.iterrows():
if rel_weight_property is None:
net.add_edge(rel["sourceNodeId"], rel["targetNodeId"], title=f"Type: {rel['relationshipType']}")
else:
title = f"Type: {rel['relationshipType']}\n{rel_weight_property} = {rel['rel_weight_property']}"
net.add_edge(rel["sourceNodeId"], rel["targetNodeId"], title=title, value=rel[rel_weight_property])

return net.show(f"{self._name}.html")

@staticmethod
def _random_bright_color() -> str:
h = random.randint(0, 255) / 255.0
return "#%02X%02X%02X" % tuple(map(lambda x: int(x * 255), colorsys.hls_to_rgb(h, 0.7, 1.0)))
def _random_themed_color(theme) -> str:
l = 0.7 if theme == "dark" else 0.4
return "#%02X%02X%02X" % tuple(map(lambda x: int(x * 255), colorsys.hls_to_rgb(random.random(), l, 1.0)))

0 comments on commit b01e44c

Please sign in to comment.