Skip to content

Commit b01e44c

Browse files
committed
Even more features for G.visualize
1 parent 41f5f1f commit b01e44c

File tree

1 file changed

+90
-29
lines changed

1 file changed

+90
-29
lines changed

graphdatascience/graph/graph_object.py

Lines changed: 90 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
from itertools import chain
23

34
import colorsys
45
import random
@@ -82,7 +83,6 @@ def node_count(self) -> int:
8283
"""
8384
Returns:
8485
the number of nodes in the graph
85-
8686
"""
8787
return self._graph_info(["nodeCount"]) # type: ignore
8888

@@ -191,7 +191,6 @@ def drop(self, failIfMissing: bool = False) -> "Series[str]":
191191
192192
Returns:
193193
the result of the drop operation
194-
195194
"""
196195
result = self._query_runner.call_procedure(
197196
endpoint="gds.graph.drop",
@@ -205,7 +204,6 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
205204
"""
206205
Returns:
207206
the creation time of the graph
208-
209207
"""
210208
return self._graph_info(["creationTime"])
211209

@@ -236,12 +234,56 @@ def __repr__(self) -> str:
236234

237235
def visualize(
238236
self,
239-
notebook: bool = True,
240237
node_count: int = 100,
238+
directed: bool = True,
241239
center_nodes: Optional[List[int]] = None,
242-
include_node_properties: List[str] = None,
243240
color_property: Optional[str] = None,
241+
size_property: Optional[str] = None,
242+
include_node_properties: Optional[List[str]] = None,
243+
rel_weight_property: Optional[str] = None,
244+
notebook: bool = True,
245+
px_height: int = 750,
246+
theme: str = "dark",
244247
) -> Any:
248+
"""
249+
Visualize the `Graph` in an interactive graphical interface.
250+
The graph will be sampled down to specified `node_count` to limit computationally expensive rendering.
251+
252+
Args:
253+
node_count: number of nodes in the graph to be visualized
254+
directed: whether or not to display relationships as directed
255+
center_nodes: nodes around subgraph will be sampled, if sampling is necessary
256+
color_property: node property that determines node categories for coloring. Default is to use node labels
257+
size_property: node property that determines the size of nodes. Default is to compute a page rank for this
258+
include_node_properties: node properties to include for mouse-over inspection
259+
rel_weight_property: relationship property that determines width of relationships
260+
notebook: whether or not the code is run in a notebook
261+
px_height: the height of the graphic containing output the visualization
262+
theme: coloring theme for the visualization. "light" or "dark"
263+
264+
Returns:
265+
an interactive graphical visualization of the specified graph
266+
"""
267+
268+
actual_node_properties = list(chain.from_iterable(self.node_properties().to_dict().values()))
269+
if (color_property is not None) and (color_property not in actual_node_properties):
270+
raise ValueError(f"There is no node property '{color_property}' in graph '{self._name}'")
271+
272+
if size_property is not None and size_property not in actual_node_properties:
273+
raise ValueError(f"There is no node property '{size_property}' in graph '{self._name}'")
274+
275+
if include_node_properties is not None:
276+
for prop in include_node_properties:
277+
if prop not in actual_node_properties:
278+
raise ValueError(f"There is no node property '{prop}' in graph '{self._name}'")
279+
280+
actual_rel_properties = list(chain.from_iterable(self.relationship_properties().to_dict().values()))
281+
if rel_weight_property is not None and rel_weight_property not in actual_rel_properties:
282+
raise ValueError(f"There is no relationship property '{rel_weight_property}' in graph '{self._name}'")
283+
284+
if theme not in {"light", "dark"}:
285+
raise ValueError(f"Color `theme` '{theme}' is not allowed. Must be either 'light' or 'dark'")
286+
245287
visual_graph = self._name
246288
if self.node_count() > node_count:
247289
visual_graph = str(uuid4())
@@ -256,14 +298,19 @@ def visualize(
256298
custom_error=False,
257299
)
258300

259-
pr_prop = str(uuid4())
260-
self._query_runner.call_procedure(
261-
endpoint="gds.pageRank.mutate",
262-
params=CallParameters(graph_name=visual_graph, config=dict(mutateProperty=pr_prop)),
263-
custom_error=False,
264-
)
301+
# Make sure we always have at least a size property so that we can run `gds.graph.nodeProperties.stream`
302+
if size_property is None:
303+
size_property = str(uuid4())
304+
self._query_runner.call_procedure(
305+
endpoint="gds.pageRank.mutate",
306+
params=CallParameters(graph_name=visual_graph, config=dict(mutateProperty=size_property)),
307+
custom_error=False,
308+
)
309+
clean_up_size_prop = True
310+
else:
311+
clean_up_size_prop = False
265312

266-
node_properties = [pr_prop]
313+
node_properties = [size_property]
267314
if include_node_properties is not None:
268315
node_properties.extend(include_node_properties)
269316

@@ -295,11 +342,18 @@ def visualize(
295342
result.columns.name = None
296343
node_properties_df = result
297344

298-
relationships_df = self._query_runner.call_procedure(
299-
endpoint="gds.graph.relationships.stream",
300-
params=CallParameters(graph_name=visual_graph),
301-
custom_error=False,
302-
)
345+
if rel_weight_property is None:
346+
relationships_df = self._query_runner.call_procedure(
347+
endpoint="gds.graph.relationships.stream",
348+
params=CallParameters(graph_name=visual_graph),
349+
custom_error=False,
350+
)
351+
else:
352+
relationships_df = self._query_runner.call_procedure(
353+
endpoint="gds.graph.relationshipProperty.stream",
354+
params=CallParameters(graph_name=visual_graph, properties=rel_weight_property),
355+
custom_error=False,
356+
)
303357

304358
# Clean up
305359
if visual_graph != self._name:
@@ -308,10 +362,10 @@ def visualize(
308362
params=CallParameters(graph_name=visual_graph),
309363
custom_error=False,
310364
)
311-
else:
365+
elif clean_up_size_prop:
312366
self._query_runner.call_procedure(
313367
endpoint="gds.graph.nodeProperties.drop",
314-
params=CallParameters(graph_name=visual_graph, nodeProperties=pr_prop),
368+
params=CallParameters(graph_name=visual_graph, nodeProperties=size_property),
315369
custom_error=False,
316370
)
317371

@@ -320,19 +374,21 @@ def visualize(
320374
net = Network(
321375
notebook=True if notebook else False,
322376
cdn_resources="remote" if notebook else "local",
323-
bgcolor="#222222", # Dark background
324-
font_color="white",
325-
height="750px", # Modify according to your screen size
377+
directed=directed,
378+
bgcolor="#222222" if theme == "dark" else "#F2F2F2",
379+
font_color="white" if theme == "dark" else "black",
380+
height=f"{px_height}px",
326381
width="100%",
327382
)
328383

329384
if color_property is None:
330-
color_map = {label: self._random_bright_color() for label in self.node_labels()}
385+
color_map = {label: self._random_themed_color(theme) for label in self.node_labels()}
331386
else:
332387
color_map = {
333-
prop_val: self._random_bright_color() for prop_val in node_properties_df[color_property].unique()
388+
prop_val: self._random_themed_color(theme) for prop_val in node_properties_df[color_property].unique()
334389
}
335390

391+
# Add all the nodes
336392
for _, node in node_properties_df.iterrows():
337393
title = f"Node ID: {node['nodeId']}\nLabels: {node['nodeLabels']}"
338394
if include_node_properties is not None:
@@ -347,17 +403,22 @@ def visualize(
347403

348404
net.add_node(
349405
int(node["nodeId"]),
350-
value=node[pr_prop],
406+
value=node[size_property],
351407
color=color,
352408
title=title,
353409
)
354410

355411
# Add all the relationships
356-
net.add_edges(zip(relationships_df["sourceNodeId"], relationships_df["targetNodeId"]))
412+
for _, rel in relationships_df.iterrows():
413+
if rel_weight_property is None:
414+
net.add_edge(rel["sourceNodeId"], rel["targetNodeId"], title=f"Type: {rel['relationshipType']}")
415+
else:
416+
title = f"Type: {rel['relationshipType']}\n{rel_weight_property} = {rel['rel_weight_property']}"
417+
net.add_edge(rel["sourceNodeId"], rel["targetNodeId"], title=title, value=rel[rel_weight_property])
357418

358419
return net.show(f"{self._name}.html")
359420

360421
@staticmethod
361-
def _random_bright_color() -> str:
362-
h = random.randint(0, 255) / 255.0
363-
return "#%02X%02X%02X" % tuple(map(lambda x: int(x * 255), colorsys.hls_to_rgb(h, 0.7, 1.0)))
422+
def _random_themed_color(theme) -> str:
423+
l = 0.7 if theme == "dark" else 0.4
424+
return "#%02X%02X%02X" % tuple(map(lambda x: int(x * 255), colorsys.hls_to_rgb(random.random(), l, 1.0)))

0 commit comments

Comments
 (0)