Commit 299328a

Merge pull request #94 from kirayummy/master: deeper_gcn

2 parents d016a9d + 60db1b4

File tree

7 files changed: +542 −3 lines changed

examples/deeper_gcn/README.md

+33
@@ -0,0 +1,33 @@
# DeeperGCN: All You Need to Train Deeper GCNs

See the paper for more information: https://arxiv.org/pdf/2006.07739.pdf

### Datasets

The datasets contain three citation networks: CORA, PUBMED, CITESEER. The details for these three datasets can be found in the [paper](https://arxiv.org/abs/1609.02907).

### Dependencies

- paddlepaddle>=1.6
- pgl

### Performance

We train our models for 200 epochs and report the accuracy on the test dataset.

| Dataset | Accuracy |
| --- | --- |
| Cora | ~77% |

### How to run

For example, to train DeeperGCN on the cora dataset using a GPU:

```
python train.py --dataset cora --use_cuda
```

#### Hyperparameters

- dataset: The citation dataset: "cora", "citeseer", or "pubmed".
- use_cuda: Train on GPU if --use_cuda is specified.
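The same entry point runs on CPU when `--use_cuda` is omitted; for example, to train on pubmed:

```
python train.py --dataset pubmed
```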

examples/deeper_gcn/model.py

+89
@@ -0,0 +1,89 @@
import pgl
import paddle.fluid as fluid


def DeeperGCN(gw, feature, num_layers,
              hidden_size, num_tasks, name, dropout_prob):
    """Implementation of DeeperGCN, see the paper
    "DeeperGCN: All You Need to Train Deeper GCNs" in
    https://arxiv.org/pdf/2006.07739.pdf

    Args:
        gw: Graph wrapper object

        feature: A tensor with shape (num_nodes, feature_size)

        num_layers: number of layers in DeeperGCN

        hidden_size: hidden size in DeeperGCN

        num_tasks: number of output tasks (the final prediction size)

        name: prefix for the deeper gcn layer names

        dropout_prob: dropout probability in DeeperGCN

    Return:
        A tensor with shape (num_nodes, num_tasks)
    """

    beta = "dynamic"
    # Project the input features to the hidden size.
    feature = fluid.layers.fc(feature,
                              hidden_size,
                              bias_attr=False,
                              param_attr=fluid.ParamAttr(name=name + '_weight'))

    output = pgl.layers.gen_conv(gw, feature, name=name + "_gen_conv_0", beta=beta)

    for layer in range(num_layers):
        # Pre-activation residual block: LN/BN -> ReLU -> GraphConv -> Res
        old_output = output
        # 1. Layer norm
        output = fluid.layers.layer_norm(
            output,
            begin_norm_axis=1,
            param_attr=fluid.ParamAttr(
                name="norm_scale_%s_%d" % (name, layer),
                initializer=fluid.initializer.Constant(1.0)),
            bias_attr=fluid.ParamAttr(
                name="norm_bias_%s_%d" % (name, layer),
                initializer=fluid.initializer.Constant(0.0)))

        # 2. ReLU
        output = fluid.layers.relu(output)

        # 3. Dropout
        output = fluid.layers.dropout(
            output,
            dropout_prob=dropout_prob,
            dropout_implementation="upscale_in_train")

        # 4. gen_conv (numbered from 1 so the parameters are not shared
        #    with the initial "_gen_conv_0" layer above)
        output = pgl.layers.gen_conv(
            gw, output, name=name + "_gen_conv_%d" % (layer + 1), beta=beta)

        # 5. Residual connection
        output = output + old_output

    # Final block: LN + ReLU + dropout
    output = fluid.layers.layer_norm(
        output,
        begin_norm_axis=1,
        param_attr=fluid.ParamAttr(
            name="norm_scale_%s_%d" % (name, num_layers),
            initializer=fluid.initializer.Constant(1.0)),
        bias_attr=fluid.ParamAttr(
            name="norm_bias_%s_%d" % (name, num_layers),
            initializer=fluid.initializer.Constant(0.0)))
    output = fluid.layers.relu(output)
    output = fluid.layers.dropout(
        output,
        dropout_prob=dropout_prob,
        dropout_implementation="upscale_in_train")

    # Final prediction
    output = fluid.layers.fc(output,
                             num_tasks,
                             bias_attr=False,
                             param_attr=fluid.ParamAttr(name=name + '_final_weight'))

    return output
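For orientation, here is a minimal sketch of wiring `DeeperGCN` into a fluid program. It mirrors what `train.py` below does; the Cora dataset and the "words" node feature are assumptions taken from that script:

```python
import pgl
import paddle.fluid as fluid
from pgl import data_loader
from model import DeeperGCN

# A minimal sketch, assuming the Cora dataset with a "words" node feature
# (both taken from train.py in this same commit).
dataset = data_loader.CoraDataset()
train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
    gw = pgl.graph_wrapper.GraphWrapper(
        name="graph", node_feat=dataset.graph.node_feat_info())
    # 7 residual GENConv blocks, 64 hidden units, dropout 0.1.
    logits = DeeperGCN(gw, gw.node_feat["words"], 7,
                       64, dataset.num_classes, "deepergcn", 0.1)
```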

examples/deeper_gcn/train.py

+155
@@ -0,0 +1,155 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#-*- coding: utf-8 -*-
import pgl
from pgl import data_loader
from pgl.utils.logger import log
import paddle.fluid as fluid
import numpy as np
import time
import argparse
from pgl.utils.log_writer import LogWriter  # vdl (VisualDL)
from model import DeeperGCN


def load(name):
    if name == 'cora':
        dataset = data_loader.CoraDataset()
    elif name == "pubmed":
        dataset = data_loader.CitationDataset("pubmed", symmetry_edges=False)
    elif name == "citeseer":
        dataset = data_loader.CitationDataset("citeseer", symmetry_edges=False)
    else:
        raise ValueError(name + " dataset doesn't exist")
    return dataset


def main(args):
    # VisualDL writer for training curves
    writer = LogWriter("checkpoints/train_history")

    dataset = load(args.dataset)
    place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
    train_program = fluid.Program()
    startup_program = fluid.Program()
    test_program = fluid.Program()
    hidden_size = 64
    num_layers = 7

    with fluid.program_guard(train_program, startup_program):
        gw = pgl.graph_wrapper.GraphWrapper(
            name="graph",
            node_feat=dataset.graph.node_feat_info())

        output = DeeperGCN(gw,
                           gw.node_feat["words"],
                           num_layers,
                           hidden_size,
                           dataset.num_classes,
                           "deepergcn",
                           0.1)

        node_index = fluid.layers.data(
            "node_index",
            shape=[None, 1],
            dtype="int64",
            append_batch_size=False)
        node_label = fluid.layers.data(
            "node_label",
            shape=[None, 1],
            dtype="int64",
            append_batch_size=False)

        pred = fluid.layers.gather(output, node_index)
        loss, pred = fluid.layers.softmax_with_cross_entropy(
            logits=pred, label=node_label, return_softmax=True)
        acc = fluid.layers.accuracy(input=pred, label=node_label, k=1)
        loss = fluid.layers.mean(loss)

    test_program = train_program.clone(for_test=True)
    with fluid.program_guard(train_program, startup_program):
        adam = fluid.optimizer.Adam(
            regularization=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=0.0005),
            learning_rate=0.005)
        adam.minimize(loss)

    exe = fluid.Executor(place)
    exe.run(startup_program)

    feed_dict = gw.to_feed(dataset.graph)

    train_index = dataset.train_index
    train_label = np.expand_dims(dataset.y[train_index], -1)
    train_index = np.expand_dims(train_index, -1)

    val_index = dataset.val_index
    val_label = np.expand_dims(dataset.y[val_index], -1)
    val_index = np.expand_dims(val_index, -1)

    test_index = dataset.test_index
    test_label = np.expand_dims(dataset.y[test_index], -1)
    test_index = np.expand_dims(test_index, -1)

    # Collect the learnable beta parameters of the GENConv layers so
    # their values can be logged over the course of training.
    beta_param_list = []
    for param in fluid.io.get_program_parameter(train_program):
        if param.name.endswith("_beta"):
            beta_param_list.append(param)

    dur = []
    for epoch in range(200):
        if epoch >= 3:
            t0 = time.time()
        feed_dict["node_index"] = np.array(train_index, dtype="int64")
        feed_dict["node_label"] = np.array(train_label, dtype="int64")
        train_loss, train_acc = exe.run(train_program,
                                        feed=feed_dict,
                                        fetch_list=[loss, acc],
                                        return_numpy=True)
        for param in beta_param_list:
            beta = np.array(fluid.global_scope().find_var(param.name).get_tensor())
            writer.add_scalar("beta/" + param.name, beta, epoch)

        if epoch >= 3:
            time_per_epoch = 1.0 * (time.time() - t0)
            dur.append(time_per_epoch)

        feed_dict["node_index"] = np.array(val_index, dtype="int64")
        feed_dict["node_label"] = np.array(val_label, dtype="int64")
        val_loss, val_acc = exe.run(test_program,
                                    feed=feed_dict,
                                    fetch_list=[loss, acc],
                                    return_numpy=True)

        log.info("Epoch %d " % epoch + "(%.5lf sec) " % np.mean(dur) +
                 "Train Loss: %f " % train_loss + "Train Acc: %f " % train_acc
                 + "Val Loss: %f " % val_loss + "Val Acc: %f " % val_acc)

    feed_dict["node_index"] = np.array(test_index, dtype="int64")
    feed_dict["node_label"] = np.array(test_label, dtype="int64")
    test_loss, test_acc = exe.run(test_program,
                                  feed=feed_dict,
                                  fetch_list=[loss, acc],
                                  return_numpy=True)
    log.info("Accuracy: %f" % test_acc)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DeeperGCN')
    parser.add_argument(
        "--dataset", type=str, default="cora",
        help="dataset (cora, citeseer, pubmed)")
    parser.add_argument("--use_cuda", action='store_true', help="use_cuda")
    args = parser.parse_args()
    log.info(args)
    main(args)
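One detail worth calling out: the script computes predictions for every node in one pass, then uses `fluid.layers.gather` with the split indices to select the rows for the current split. In numpy terms the selection is plain fancy indexing (the shapes below assume Cora: 2708 nodes, 7 classes):

```python
import numpy as np

logits = np.random.rand(2708, 7)   # per-node predictions for the whole graph
train_index = np.array([0, 5, 9])  # node ids in the training split
pred = logits[train_index]         # what fluid.layers.gather(output, node_index) selects
assert pred.shape == (3, 7)
```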

pgl/__init__.py

+1
@@ -21,3 +21,4 @@
 from pgl import heter_graph
 from pgl import heter_graph_wrapper
 from pgl import contrib
+from pgl import message_passing

pgl/layers/conv.py

+54 −2
@@ -15,10 +15,10 @@
 graph neural networks.
 """
 import paddle.fluid as fluid
-from pgl import graph_wrapper
 from pgl.utils import paddle_helper
+from pgl import message_passing
 
-__all__ = ['gcn', 'gat', 'gin', 'gaan']
+__all__ = ['gcn', 'gat', 'gin', 'gaan', 'gen_conv']
 
 
 def gcn(gw, feature, hidden_size, activation, name, norm=None):
@@ -352,3 +352,55 @@ def recv_func(message):
     output = fluid.layers.dropout(output, dropout_prob=0.1)
 
     return output
+
+
+def gen_conv(gw,
+             feature,
+             name,
+             beta=None):
+    """Implementation of GENeralized Graph Convolution (GENConv), see the paper
+    "DeeperGCN: All You Need to Train Deeper GCNs" in
+    https://arxiv.org/pdf/2006.07739.pdf
+
+    Args:
+        gw: Graph wrapper object (:code:`StaticGraphWrapper` or :code:`GraphWrapper`)
+
+        feature: A tensor with shape (num_nodes, feature_size).
+
+        name: deeper gcn layer names.
+
+        beta: [0, +infinity] or "dynamic" or None
+
+    Return:
+        A tensor with shape (num_nodes, feature_size)
+    """
+
+    if beta == "dynamic":
+        # Learnable inverse-temperature for the softmax aggregator.
+        beta = fluid.layers.create_parameter(
+            shape=[1],
+            dtype='float32',
+            default_initializer=
+            fluid.initializer.ConstantInitializer(value=1.0),
+            name=name + '_beta')
+
+    # Message passing: copy node features along the edges, then
+    # aggregate the messages with a beta-scaled softmax.
+    msg = gw.send(message_passing.copy_send, nfeat_list=[("h", feature)])
+    output = gw.recv(msg, message_passing.softmax_agg(beta))
+
+    # Message norm, followed by a residual connection.
+    output = message_passing.msg_norm(feature, output, name)
+    output = feature + output
+
+    # Two-layer MLP update function.
+    output = fluid.layers.fc(output,
+                             feature.shape[-1],
+                             bias_attr=False,
+                             act="relu",
+                             param_attr=fluid.ParamAttr(name=name + '_weight1'))
+
+    output = fluid.layers.fc(output,
+                             feature.shape[-1],
+                             bias_attr=False,
+                             param_attr=fluid.ParamAttr(name=name + '_weight2'))
+
+    return output
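`gen_conv` leans on three helpers from the new `pgl/message_passing.py` module, which is among the 7 changed files but is not shown in this excerpt. As a rough, non-authoritative sketch of what they compute — the bodies below are assumptions reconstructed from the paper's SoftMax_Agg and MsgNorm definitions, not the module's actual code:

```python
import paddle.fluid as fluid
from pgl.utils import paddle_helper


def copy_send(src_feat, dst_feat, edge_feat):
    # Send the source node's feature "h" along each edge unchanged.
    return src_feat["h"]


def softmax_agg(beta):
    # SoftMax_Agg: weight each node's incoming messages by a softmax
    # over its neighborhood, scaled by the (possibly learnable) beta.
    def softmax_agg_inside(msg):
        alpha = paddle_helper.sequence_softmax(msg * beta)
        return fluid.layers.sequence_pool(msg * alpha, "sum")
    return softmax_agg_inside


def msg_norm(x, msg, name):
    # MsgNorm: rescale the l2-normalized aggregated message by ||x||_2
    # and a learnable scalar s before the residual addition.
    s = fluid.layers.create_parameter(
        shape=[1],
        dtype='float32',
        name=name + '_msg_norm_s',
        default_initializer=fluid.initializer.ConstantInitializer(value=1.0))
    msg = fluid.layers.l2_normalize(msg, axis=1)
    x_norm = fluid.layers.sqrt(
        fluid.layers.reduce_sum(x * x, dim=1, keep_dim=True))
    return msg * x_norm * s
```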
