1
1
from __future__ import annotations
2
+ from itertools import chain
2
3
3
4
import colorsys
4
5
import random
@@ -82,7 +83,6 @@ def node_count(self) -> int:
82
83
"""
83
84
Returns:
84
85
the number of nodes in the graph
85
-
86
86
"""
87
87
return self ._graph_info (["nodeCount" ]) # type: ignore
88
88
@@ -191,7 +191,6 @@ def drop(self, failIfMissing: bool = False) -> "Series[str]":
191
191
192
192
Returns:
193
193
the result of the drop operation
194
-
195
194
"""
196
195
result = self ._query_runner .call_procedure (
197
196
endpoint = "gds.graph.drop" ,
@@ -205,7 +204,6 @@ def creation_time(self) -> Any: # neo4j.time.DateTime not exported
205
204
"""
206
205
Returns:
207
206
the creation time of the graph
208
-
209
207
"""
210
208
return self ._graph_info (["creationTime" ])
211
209
@@ -236,12 +234,56 @@ def __repr__(self) -> str:
236
234
237
235
def visualize (
238
236
self ,
239
- notebook : bool = True ,
240
237
node_count : int = 100 ,
238
+ directed : bool = True ,
241
239
center_nodes : Optional [List [int ]] = None ,
242
- include_node_properties : List [str ] = None ,
243
240
color_property : Optional [str ] = None ,
241
+ size_property : Optional [str ] = None ,
242
+ include_node_properties : Optional [List [str ]] = None ,
243
+ rel_weight_property : Optional [str ] = None ,
244
+ notebook : bool = True ,
245
+ px_height : int = 750 ,
246
+ theme : str = "dark" ,
244
247
) -> Any :
248
+ """
249
+ Visualize the `Graph` in an interactive graphical interface.
250
+ The graph will be sampled down to specified `node_count` to limit computationally expensive rendering.
251
+
252
+ Args:
253
+ node_count: number of nodes in the graph to be visualized
254
+ directed: whether or not to display relationships as directed
255
+ center_nodes: nodes around subgraph will be sampled, if sampling is necessary
256
+ color_property: node property that determines node categories for coloring. Default is to use node labels
257
+ size_property: node property that determines the size of nodes. Default is to compute a page rank for this
258
+ include_node_properties: node properties to include for mouse-over inspection
259
+ rel_weight_property: relationship property that determines width of relationships
260
+ notebook: whether or not the code is run in a notebook
261
+ px_height: the height of the graphic containing output the visualization
262
+ theme: coloring theme for the visualization. "light" or "dark"
263
+
264
+ Returns:
265
+ an interactive graphical visualization of the specified graph
266
+ """
267
+
268
+ actual_node_properties = list (chain .from_iterable (self .node_properties ().to_dict ().values ()))
269
+ if (color_property is not None ) and (color_property not in actual_node_properties ):
270
+ raise ValueError (f"There is no node property '{ color_property } ' in graph '{ self ._name } '" )
271
+
272
+ if size_property is not None and size_property not in actual_node_properties :
273
+ raise ValueError (f"There is no node property '{ size_property } ' in graph '{ self ._name } '" )
274
+
275
+ if include_node_properties is not None :
276
+ for prop in include_node_properties :
277
+ if prop not in actual_node_properties :
278
+ raise ValueError (f"There is no node property '{ prop } ' in graph '{ self ._name } '" )
279
+
280
+ actual_rel_properties = list (chain .from_iterable (self .relationship_properties ().to_dict ().values ()))
281
+ if rel_weight_property is not None and rel_weight_property not in actual_rel_properties :
282
+ raise ValueError (f"There is no relationship property '{ rel_weight_property } ' in graph '{ self ._name } '" )
283
+
284
+ if theme not in {"light" , "dark" }:
285
+ raise ValueError (f"Color `theme` '{ theme } ' is not allowed. Must be either 'light' or 'dark'" )
286
+
245
287
visual_graph = self ._name
246
288
if self .node_count () > node_count :
247
289
visual_graph = str (uuid4 ())
@@ -256,14 +298,19 @@ def visualize(
256
298
custom_error = False ,
257
299
)
258
300
259
- pr_prop = str (uuid4 ())
260
- self ._query_runner .call_procedure (
261
- endpoint = "gds.pageRank.mutate" ,
262
- params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = pr_prop )),
263
- custom_error = False ,
264
- )
301
+ # Make sure we always have at least a size property so that we can run `gds.graph.nodeProperties.stream`
302
+ if size_property is None :
303
+ size_property = str (uuid4 ())
304
+ self ._query_runner .call_procedure (
305
+ endpoint = "gds.pageRank.mutate" ,
306
+ params = CallParameters (graph_name = visual_graph , config = dict (mutateProperty = size_property )),
307
+ custom_error = False ,
308
+ )
309
+ clean_up_size_prop = True
310
+ else :
311
+ clean_up_size_prop = False
265
312
266
- node_properties = [pr_prop ]
313
+ node_properties = [size_property ]
267
314
if include_node_properties is not None :
268
315
node_properties .extend (include_node_properties )
269
316
@@ -295,11 +342,18 @@ def visualize(
295
342
result .columns .name = None
296
343
node_properties_df = result
297
344
298
- relationships_df = self ._query_runner .call_procedure (
299
- endpoint = "gds.graph.relationships.stream" ,
300
- params = CallParameters (graph_name = visual_graph ),
301
- custom_error = False ,
302
- )
345
+ if rel_weight_property is None :
346
+ relationships_df = self ._query_runner .call_procedure (
347
+ endpoint = "gds.graph.relationships.stream" ,
348
+ params = CallParameters (graph_name = visual_graph ),
349
+ custom_error = False ,
350
+ )
351
+ else :
352
+ relationships_df = self ._query_runner .call_procedure (
353
+ endpoint = "gds.graph.relationshipProperty.stream" ,
354
+ params = CallParameters (graph_name = visual_graph , properties = rel_weight_property ),
355
+ custom_error = False ,
356
+ )
303
357
304
358
# Clean up
305
359
if visual_graph != self ._name :
@@ -308,10 +362,10 @@ def visualize(
308
362
params = CallParameters (graph_name = visual_graph ),
309
363
custom_error = False ,
310
364
)
311
- else :
365
+ elif clean_up_size_prop :
312
366
self ._query_runner .call_procedure (
313
367
endpoint = "gds.graph.nodeProperties.drop" ,
314
- params = CallParameters (graph_name = visual_graph , nodeProperties = pr_prop ),
368
+ params = CallParameters (graph_name = visual_graph , nodeProperties = size_property ),
315
369
custom_error = False ,
316
370
)
317
371
@@ -320,19 +374,21 @@ def visualize(
320
374
net = Network (
321
375
notebook = True if notebook else False ,
322
376
cdn_resources = "remote" if notebook else "local" ,
323
- bgcolor = "#222222" , # Dark background
324
- font_color = "white" ,
325
- height = "750px" , # Modify according to your screen size
377
+ directed = directed ,
378
+ bgcolor = "#222222" if theme == "dark" else "#F2F2F2" ,
379
+ font_color = "white" if theme == "dark" else "black" ,
380
+ height = f"{ px_height } px" ,
326
381
width = "100%" ,
327
382
)
328
383
329
384
if color_property is None :
330
- color_map = {label : self ._random_bright_color ( ) for label in self .node_labels ()}
385
+ color_map = {label : self ._random_themed_color ( theme ) for label in self .node_labels ()}
331
386
else :
332
387
color_map = {
333
- prop_val : self ._random_bright_color ( ) for prop_val in node_properties_df [color_property ].unique ()
388
+ prop_val : self ._random_themed_color ( theme ) for prop_val in node_properties_df [color_property ].unique ()
334
389
}
335
390
391
+ # Add all the nodes
336
392
for _ , node in node_properties_df .iterrows ():
337
393
title = f"Node ID: { node ['nodeId' ]} \n Labels: { node ['nodeLabels' ]} "
338
394
if include_node_properties is not None :
@@ -347,17 +403,22 @@ def visualize(
347
403
348
404
net .add_node (
349
405
int (node ["nodeId" ]),
350
- value = node [pr_prop ],
406
+ value = node [size_property ],
351
407
color = color ,
352
408
title = title ,
353
409
)
354
410
355
411
# Add all the relationships
356
- net .add_edges (zip (relationships_df ["sourceNodeId" ], relationships_df ["targetNodeId" ]))
412
+ for _ , rel in relationships_df .iterrows ():
413
+ if rel_weight_property is None :
414
+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = f"Type: { rel ['relationshipType' ]} " )
415
+ else :
416
+ title = f"Type: { rel ['relationshipType' ]} \n { rel_weight_property } = { rel ['rel_weight_property' ]} "
417
+ net .add_edge (rel ["sourceNodeId" ], rel ["targetNodeId" ], title = title , value = rel [rel_weight_property ])
357
418
358
419
return net .show (f"{ self ._name } .html" )
359
420
360
421
@staticmethod
361
- def _random_bright_color ( ) -> str :
362
- h = random . randint ( 0 , 255 ) / 255.0
363
- return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (h , 0.7 , 1.0 )))
422
+ def _random_themed_color ( theme ) -> str :
423
+ l = 0.7 if theme == "dark" else 0.4
424
+ return "#%02X%02X%02X" % tuple (map (lambda x : int (x * 255 ), colorsys .hls_to_rgb (random . random (), l , 1.0 )))
0 commit comments