Skip to content

Commit 46dd55d

Browse files
authored
Merge pull request #37 from Yelrose/master
Add Graph Normalization Layers
2 parents 2a8223a + 50ab64b commit 46dd55d

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

pgl/layers/graph_pool.py

+26-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pgl.utils import paddle_helper
2020
from pgl.utils import op
2121

22-
__all__ = ['graph_pooling']
22+
__all__ = ['graph_pooling', 'graph_norm']
2323

2424

2525
def graph_pooling(gw, node_feat, pool_type):
@@ -40,3 +40,28 @@ def graph_pooling(gw, node_feat, pool_type):
4040
graph_feat = op.nested_lod_reset(node_feat, gw.graph_lod)
4141
graph_feat = fluid.layers.sequence_pool(graph_feat, pool_type)
4242
return graph_feat
43+
44+
45+
def graph_norm(gw, feature):
    """Implementation of graph normalization.

    Each node's feature vector is divided by sqrt(num_nodes) of the graph
    it belongs to, so that graphs of different sizes contribute features
    on a comparable scale.

    Reference Paper: BENCHMARKING GRAPH NEURAL NETWORKS

    Args:
        gw: Graph wrapper object (:code:`StaticGraphWrapper` or
            :code:`GraphWrapper`).

        feature: A tensor with shape (num_nodes, hidden_size).

    Return:
        A tensor with shape (num_nodes, hidden_size), the normalized features.
    """
    # One scalar "1" per node; sum-pooling these per graph yields each
    # graph's node count.
    nodes = fluid.layers.fill_constant(
        [gw.num_nodes, 1], dtype="float32", value=1.0)
    norm = graph_pooling(gw, nodes, pool_type="sum")
    norm = fluid.layers.sqrt(norm)
    # Attach the per-graph LoD to the node features so the per-graph norm
    # can be expanded back to one value per node.
    feature_lod = op.nested_lod_reset(feature, gw.graph_lod)
    norm = fluid.layers.sequence_expand_as(norm, feature_lod)
    # The normalizer is a constant of the graph structure, not a learned
    # quantity — block gradients from flowing through it.
    norm.stop_gradient = True
    return feature_lod / norm

0 commit comments

Comments
 (0)