
Commit fff9186

Merge pull request #38 from Liwb5/develop
add gin layer
2 parents 46dd55d + 15853d5 commit fff9186

3 files changed: +162, -6 lines


ogb_examples/linkproppred/main_pgl.py

Lines changed: 9 additions & 5 deletions
@@ -96,7 +96,7 @@ def forward(self, graph):
 
         loss = fluid.layers.sigmoid_cross_entropy_with_logits(pred,
                                                               self.edge_label)
-        loss = fluid.layers.reduce_mean(loss)
+        loss = fluid.layers.reduce_sum(loss)
 
         return pred, prob, loss

@@ -223,8 +223,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge):
         "float32").reshape(-1, 1)
     y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
     input_dict = {
-        "y_true": splitted_edge["valid_edge_label"],
-        "y_pred": y_pred.reshape(-1, ),
+        "y_pred_pos":
+        y_pred[splitted_edge["valid_edge_label"] == 1].reshape(-1, ),
+        "y_pred_neg":
+        y_pred[splitted_edge["valid_edge_label"] == 0].reshape(-1, )
     }
     result["valid"] = evaluator.eval(input_dict)

@@ -234,8 +236,10 @@ def test(exe, val_program, prob, evaluator, feed, splitted_edge):
         "float32").reshape(-1, 1)
     y_pred = exe.run(val_program, feed=feed, fetch_list=[prob])[0]
     input_dict = {
-        "y_true": splitted_edge["test_edge_label"],
-        "y_pred": y_pred.reshape(-1, ),
+        "y_pred_pos":
+        y_pred[splitted_edge["test_edge_label"] == 1].reshape(-1, ),
+        "y_pred_neg":
+        y_pred[splitted_edge["test_edge_label"] == 0].reshape(-1, )
     }
     result["test"] = evaluator.eval(input_dict)
     return result
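
A note on the evaluator change above: the OGB link-prediction evaluator ranks the scores of true edges against the scores of sampled negatives (for metrics such as Hits@K), so it expects the predictions pre-split into "y_pred_pos" and "y_pred_neg" rather than a flat ("y_true", "y_pred") pair. A minimal NumPy sketch of that split, using toy arrays that are illustrative only and not part of this commit:

    import numpy as np

    # Toy stand-ins: label 1 marks a true edge, 0 a sampled negative.
    edge_label = np.array([1, 0, 1, 0, 0])
    y_pred = np.array([0.9, 0.2, 0.7, 0.4, 0.1], dtype="float32")

    # Boolean masks split the flat score vector into the two groups
    # the evaluator expects, mirroring the diff above.
    input_dict = {
        "y_pred_pos": y_pred[edge_label == 1].reshape(-1, ),
        "y_pred_neg": y_pred[edge_label == 0].reshape(-1, ),
    }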

pgl/layers/conv.py

Lines changed: 71 additions & 1 deletion
@@ -18,7 +18,7 @@
 from pgl import graph_wrapper
 from pgl.utils import paddle_helper
 
-__all__ = ['gcn', 'gat']
+__all__ = ['gcn', 'gat', 'gin']
 
 
 def gcn(gw, feature, hidden_size, activation, name, norm=None):

@@ -178,3 +178,73 @@ def reduce_attention(msg):
         bias.stop_gradient = True
     output = fluid.layers.elementwise_add(output, bias, act=activation)
     return output
+
+
+def gin(gw,
+        feature,
+        hidden_size,
+        activation,
+        name,
+        init_eps=0.0,
+        train_eps=False):
+    """Implementation of Graph Isomorphism Network (GIN) layer.
+
+    This is an implementation of the paper How Powerful are Graph Neural Networks?
+    (https://arxiv.org/pdf/1810.00826.pdf).
+
+    In their implementation, all MLPs have 2 layers. Batch normalization is applied
+    on every hidden layer.
+
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+
+        feature: A tensor with shape (num_nodes, feature_size).
+
+        name: GIN layer names.
+
+        hidden_size: The hidden size for gin.
+
+        activation: The activation for the output.
+
+        init_eps: float, optional
+            Initial :math:`\epsilon` value, default is 0.
+
+        train_eps: bool, optional
+            if True, :math:`\epsilon` will be a learnable parameter.
+
+    Return:
+        A tensor with shape (num_nodes, hidden_size).
+    """
+
+    def send_src_copy(src_feat, dst_feat, edge_feat):
+        return src_feat["h"]
+
+    epsilon = fluid.layers.create_parameter(
+        shape=[1, 1],
+        dtype="float32",
+        attr=fluid.ParamAttr(name="%s_eps" % name),
+        default_initializer=fluid.initializer.ConstantInitializer(
+            value=init_eps))
+
+    if not train_eps:
+        epsilon.stop_gradient = True
+
+    msg = gw.send(send_src_copy, nfeat_list=[("h", feature)])
+    output = gw.recv(msg, "sum") + (1.0 + epsilon) * feature
+
+    output = fluid.layers.fc(output,
+                             size=hidden_size,
+                             act=None,
+                             param_attr=fluid.ParamAttr(name="%s_w_0" % name),
+                             bias_attr=fluid.ParamAttr(name="%s_b_0" % name))
+
+    output = fluid.layers.batch_norm(output)
+    output = getattr(fluid.layers, activation)(output)
+
+    output = fluid.layers.fc(output,
+                             size=hidden_size,
+                             act=activation,
+                             param_attr=fluid.ParamAttr(name="%s_w_1" % name),
+                             bias_attr=fluid.ParamAttr(name="%s_b_1" % name))
+
+    return output
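
For orientation, the new layer computes the GIN update h_v' = MLP((1 + epsilon) * h_v + sum over neighbors u of h_u): send_src_copy copies each source node's feature onto its outgoing edges, recv sums those messages per destination node, and the two fc layers with batch normalization in between form the 2-layer MLP described in the docstring. A toy NumPy sketch of the aggregation step alone, with illustrative names that are not part of the PGL API:

    import numpy as np

    num_nodes, feat_size, eps = 3, 4, 0.0
    h = np.random.rand(num_nodes, feat_size).astype("float32")
    edges = [(0, 1), (1, 2), (2, 0)]  # directed (src, dst) pairs

    # Equivalent of gw.send(send_src_copy, ...) followed by
    # gw.recv(msg, "sum"): each destination sums its sources' features.
    agg = np.zeros_like(h)
    for src, dst in edges:
        agg[dst] += h[src]

    # Matches the line: gw.recv(msg, "sum") + (1.0 + epsilon) * feature
    out = agg + (1.0 + eps) * h
    # `out` would then pass through the two fc layers shown in the diff.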

pgl/tests/test_gin.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This file is for testing gin layer.
+"""
+from __future__ import division
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import unicode_literals
+import unittest
+import numpy as np
+
+import paddle.fluid as F
+import paddle.fluid.layers as L
+
+from pgl.layers.conv import gin
+from pgl import graph
+from pgl import graph_wrapper
+
+
+class GinTest(unittest.TestCase):
+    """GinTest
+    """
+
+    def test_gin(self):
+        """test_gin
+        """
+        np.random.seed(1)
+        hidden_size = 8
+
+        num_nodes = 10
+
+        edges = [(1, 4), (0, 5), (1, 9), (1, 8), (2, 8), (2, 5), (3, 6),
+                 (3, 7), (3, 4), (3, 8)]
+        inver_edges = [(v, u) for u, v in edges]
+        edges.extend(inver_edges)
+
+        node_feat = {"feature": np.random.rand(10, 4).astype("float32")}
+
+        g = graph.Graph(num_nodes=num_nodes, edges=edges, node_feat=node_feat)
+
+        use_cuda = False
+        place = F.CUDAPlace(0) if use_cuda else F.CPUPlace()
+
+        prog = F.Program()
+        startup_prog = F.Program()
+        with F.program_guard(prog, startup_prog):
+            gw = graph_wrapper.GraphWrapper(
+                name='graph',
+                place=place,
+                node_feat=g.node_feat_info(),
+                edge_feat=g.edge_feat_info())
+
+            output = gin(gw,
+                         gw.node_feat['feature'],
+                         hidden_size=hidden_size,
+                         activation='relu',
+                         name='gin',
+                         init_eps=1,
+                         train_eps=True)
+
+        exe = F.Executor(place)
+        exe.run(startup_prog)
+        ret = exe.run(prog, feed=gw.to_feed(g), fetch_list=[output])
+
+        self.assertEqual(ret[0].shape[0], num_nodes)
+        self.assertEqual(ret[0].shape[1], hidden_size)
+
+
+if __name__ == "__main__":
+    unittest.main()
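
The test builds a small bidirectional 10-node graph, runs one forward pass of the layer, and checks the output shape. Assuming PaddlePaddle and PGL are importable, it should run from the repository root with:

    python -m unittest pgl.tests.test_gin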
