@@ -19,7 +19,7 @@
 from pgl.utils import paddle_helper
 from pgl.utils import op
 
-__all__ = ['graph_pooling']
+__all__ = ['graph_pooling', 'graph_norm']
 
 
 def graph_pooling(gw, node_feat, pool_type):
@@ -40,3 +40,30 @@ def graph_pooling(gw, node_feat, pool_type):
     graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod)
     graph_feat = fluid.layers.sequence_pool(graph_feat, pool_type)
     return graph_feat
+
+
+def graph_norm(gw, feature):
+    """Implementation of graph normalization.
+
+    Reference Paper: BENCHMARKING GRAPH NEURAL NETWORKS
+
+    Each node's features are divided by sqrt(num_nodes) of the graph it belongs to.
+
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+
+        feature: A tensor with shape (num_nodes, hidden_size)
+
+    Return:
+        A tensor with shape (num_nodes, hidden_size)
+    """
+    nodes = fluid.layers.fill_constant(
+        [gw.num_nodes, 1], dtype="float32", value=1.0)
+    norm = graph_pooling(gw, nodes, pool_type="sum")
+    norm = fluid.layers.sqrt(norm)
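+    # Broadcast each graph's sqrt(num_nodes) back to all of its nodes.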
+    feature_lod = op.nested_lod_reset(feature, gw.graph_lod)
+    norm = fluid.layers.sequence_expand_as(norm, feature_lod)
+    norm.stop_gradient = True
+    return feature_lod / norm
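For reference, the normalization above reduces to simple per-graph arithmetic. The following NumPy sketch (the toy feature and graph_lod values are illustrative, not part of PGL) mirrors what graph_norm computes: each graph in the batch owns a contiguous slice of the flattened node tensor, delimited by the cumulative offsets in graph_lod, and every node feature in that slice is divided by sqrt(num_nodes).

import numpy as np

# Toy batch of two graphs (3 nodes and 2 nodes) flattened into one tensor;
# graph_lod holds cumulative node offsets, like gw.graph_lod in PGL.
feature = np.arange(10, dtype="float32").reshape(5, 2)
graph_lod = [0, 3, 5]

normed = feature.copy()
for start, end in zip(graph_lod[:-1], graph_lod[1:]):
    num_nodes = end - start
    # Divide every node feature in this graph by sqrt(num_nodes).
    normed[start:end] /= np.sqrt(num_nodes)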