
Commit 584bb25

Merge remote-tracking branch 'upstream/main' into main
2 parents: 996f7a9 + ee61e01

20 files changed: +1367 -75 lines

examples/citation_benchmark/README.md

Lines changed: 62 additions & 0 deletions
# Easy Paper Reproduction for Citation Networks (Cora/Pubmed/Citeseer)

This page reproduces popular **Graph Neural Network** papers on the citation networks (Cora/Pubmed/Citeseer), the **hello world** datasets (**small** and **fast**) of graph neural networks. Even on these small datasets, it is hard to achieve very high performance.

All experiments use the public split under the **semi-supervised** setting, and we report the average accuracy over 10 runs.

# Experiment Results

| Model | Cora | Pubmed | Citeseer | Remarks |
| --- | --- | --- | --- | --- |
| [Vanilla GCN (Kipf 2017)](https://openreview.net/pdf?id=SJU4ayYgl) | 0.807 (0.010) | 0.794 (0.003) | 0.710 (0.007) | |
| [GAT (Veličković 2017)](https://arxiv.org/pdf/1710.10903.pdf) | 0.834 (0.004) | 0.772 (0.004) | 0.700 (0.006) | |
| [SGC (Wu 2019)](https://arxiv.org/pdf/1902.07153.pdf) | 0.818 (0.000) | 0.782 (0.000) | 0.708 (0.000) | |
| [APPNP (Klicpera 2018)](https://arxiv.org/abs/1810.05997) | 0.846 (0.003) | 0.803 (0.002) | 0.719 (0.003) | Almost the same as the results reported in Appendix E of the paper. |
| [GCNII (64 layers, 1500 epochs, Chen 2020)](https://arxiv.org/pdf/2007.02133.pdf) | 0.846 (0.003) | 0.798 (0.003) | 0.724 (0.006) | |

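Each cell reads as mean accuracy with, presumably, the standard deviation over the 10 runs in parentheses. A minimal sketch (not part of the repository) of how such a summary can be computed once the per-run accuracies are collected:

```python
import numpy as np

# Hypothetical per-run test accuracies from 10 runs of train.py.
run_accuracies = [0.806, 0.809, 0.805, 0.812, 0.798,
                  0.810, 0.807, 0.803, 0.811, 0.809]

mean = np.mean(run_accuracies)
std = np.std(run_accuracies)
print("%.3f(%.3f)" % (mean, std))  # prints 0.807(0.004), the format used in the table
```
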
# How to Run the Experiments

```shell
# Choose the GPU device
export CUDA_VISIBLE_DEVICES=0

# GCN
python train.py --conf config/gcn.yaml --use_cuda --dataset cora
python train.py --conf config/gcn.yaml --use_cuda --dataset pubmed
python train.py --conf config/gcn.yaml --use_cuda --dataset citeseer

# GAT
python train.py --conf config/gat.yaml --use_cuda --dataset cora
python train.py --conf config/gat.yaml --use_cuda --dataset pubmed
python train.py --conf config/gat.yaml --use_cuda --dataset citeseer

# SGC (slow version)
python train.py --conf config/sgc.yaml --use_cuda --dataset cora
python train.py --conf config/sgc.yaml --use_cuda --dataset pubmed
python train.py --conf config/sgc.yaml --use_cuda --dataset citeseer

# APPNP
python train.py --conf config/appnp.yaml --use_cuda --dataset cora
python train.py --conf config/appnp.yaml --use_cuda --dataset pubmed
python train.py --conf config/appnp.yaml --use_cuda --dataset citeseer

# GCNII (the original implementation uses 1500 epochs)
python train.py --conf config/gcnii.yaml --use_cuda --dataset cora --epoch 1500
python train.py --conf config/gcnii.yaml --use_cuda --dataset pubmed --epoch 1500
python train.py --conf config/gcnii.yaml --use_cuda --dataset citeseer --epoch 1500
```
Lines changed: 43 additions & 0 deletions
import pgl
import model
from pgl import data_loader
import paddle.fluid as fluid
import numpy as np
import time


def build_model(dataset, config, phase, main_prog):
    # Graph wrapper that declares the input graph tensors.
    gw = pgl.graph_wrapper.GraphWrapper(
        name="graph",
        node_feat=dataset.graph.node_feat_info())

    # Instantiate the model named in the config (GCN / GAT / APPNP / SGC / GCNII).
    GraphModel = getattr(model, config.model_name)
    m = GraphModel(config=config, num_class=dataset.num_classes)
    logits = m.forward(gw, gw.node_feat["words"], phase)

    # Gather predictions and labels for the nodes of the current split.
    node_index = fluid.layers.data(
        "node_index",
        shape=[None, 1],
        dtype="int64",
        append_batch_size=False)
    node_label = fluid.layers.data(
        "node_label",
        shape=[None, 1],
        dtype="int64",
        append_batch_size=False)

    pred = fluid.layers.gather(logits, node_index)
    loss, pred = fluid.layers.softmax_with_cross_entropy(
        logits=pred, label=node_label, return_softmax=True)
    acc = fluid.layers.accuracy(input=pred, label=node_label, k=1)
    loss = fluid.layers.mean(loss)

    if phase == "train":
        adam = fluid.optimizer.Adam(
            learning_rate=config.learning_rate,
            regularization=fluid.regularizer.L2DecayRegularizer(
                regularization_coeff=config.weight_decay))
        adam.minimize(loss)
    return gw, loss, acc
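Not part of this commit, but for context, a minimal sketch of how `build_model` might be driven from a Fluid program. The `dataset`, `config`, `train_index`, and `train_label` objects are assumed placeholders (int64 arrays of shape `[num_train_nodes, 1]` for the last two), not names from the repository's `train.py`:

```python
import paddle.fluid as fluid

train_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(train_prog, startup_prog):
    gw, loss, acc = build_model(dataset, config, phase="train", main_prog=train_prog)

exe = fluid.Executor(fluid.CUDAPlace(0))
exe.run(startup_prog)

# The GraphWrapper converts the in-memory graph into the tensors it declared.
feed_dict = gw.to_feed(dataset.graph)
feed_dict["node_index"] = train_index
feed_dict["node_label"] = train_label

for epoch in range(200):
    train_loss, train_acc = exe.run(
        train_prog, feed=feed_dict, fetch_list=[loss, acc])
    print(epoch, train_loss[0], train_acc[0])
```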
Lines changed: 9 additions & 0 deletions
model_name: APPNP
k_hop: 10
alpha: 0.1
num_layers: 1
learning_rate: 0.01
dropout: 0.5
hidden_size: 64
weight_decay: 0.0005
edge_dropout: 0.0
Lines changed: 9 additions & 0 deletions
model_name: GAT
learning_rate: 0.005
weight_decay: 0.0005
num_layers: 1
feat_drop: 0.6
attn_drop: 0.6
num_heads: 8
hidden_size: 8
edge_dropout: 0.0
Lines changed: 7 additions & 0 deletions
model_name: GCN
num_layers: 1
dropout: 0.5
hidden_size: 16
learning_rate: 0.01
weight_decay: 0.0005
edge_dropout: 0.0
Lines changed: 9 additions & 0 deletions
model_name: GCNII
k_hop: 64
alpha: 0.1
num_layers: 1
learning_rate: 0.01
dropout: 0.6
hidden_size: 64
weight_decay: 0.0005
edge_dropout: 0.0
Lines changed: 5 additions & 0 deletions
model_name: SGC
num_layers: 2
learning_rate: 0.2
weight_decay: 0.000005
feature_pre_normalize: False
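The models read these values both as attributes (`config.model_name`, `config.learning_rate`) and via `config.get(...)`, so `train.py` presumably wraps the parsed YAML in a dict-like object with attribute access. A sketch of one way to do that, assuming `easydict` (an assumption, not necessarily what the repository uses):

```python
import yaml
from easydict import EasyDict  # assumption: any dict wrapper with attribute access works


def load_config(path):
    """Load a YAML config such as config/gcn.yaml into an attribute-accessible dict."""
    with open(path) as f:
        return EasyDict(yaml.safe_load(f))


config = load_config("config/gcn.yaml")
print(config.model_name, config.get("hidden_size", 64))  # -> GCN 16
```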

examples/citation_benchmark/model.py

Lines changed: 195 additions & 0 deletions
import pgl
import paddle.fluid.layers as L
import pgl.layers.conv as conv


def get_norm(indegree):
    """Compute the symmetric normalization term D^{-1/2} from node in-degrees."""
    float_degree = L.cast(indegree, dtype="float32")
    float_degree = L.clamp(float_degree, min=1.0)
    norm = L.pow(float_degree, factor=-0.5)
    return norm


class GCN(object):
    """Implement of GCN"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.hidden_size = config.get("hidden_size", 64)
        self.dropout = config.get("dropout", 0.5)
        self.edge_dropout = config.get("edge_dropout", 0.0)

    def forward(self, graph_wrapper, feature, phase):
        for i in range(self.num_layers):
            if phase == "train":
                ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
                norm = get_norm(ngw.indegree())
            else:
                ngw = graph_wrapper
                norm = graph_wrapper.node_feat["norm"]

            feature = pgl.layers.gcn(ngw,
                feature,
                self.hidden_size,
                activation="relu",
                norm=norm,
                name="layer_%s" % i)

            feature = L.dropout(
                feature,
                self.dropout,
                dropout_implementation='upscale_in_train')

        if phase == "train":
            ngw = pgl.sample.edge_drop(graph_wrapper, self.edge_dropout)
            norm = get_norm(ngw.indegree())
        else:
            ngw = graph_wrapper
            norm = graph_wrapper.node_feat["norm"]

        feature = conv.gcn(ngw,
            feature,
            self.num_class,
            activation=None,
            norm=norm,
            name="output")

        return feature


class GAT(object):
    """Implement of GAT"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.num_heads = config.get("num_heads", 8)
        self.hidden_size = config.get("hidden_size", 8)
        self.feat_dropout = config.get("feat_drop", 0.6)
        self.attn_dropout = config.get("attn_drop", 0.6)
        self.edge_dropout = config.get("edge_dropout", 0.0)

    def forward(self, graph_wrapper, feature, phase):
        # Edge dropout is only applied during training.
        if phase == "train":
            edge_dropout = self.edge_dropout
        else:
            edge_dropout = 0

        for i in range(self.num_layers):
            ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)

            feature = conv.gat(ngw,
                feature,
                self.hidden_size,
                activation="elu",
                name="gat_layer_%s" % i,
                num_heads=self.num_heads,
                feat_drop=self.feat_dropout,
                attn_drop=self.attn_dropout)

        ngw = pgl.sample.edge_drop(graph_wrapper, edge_dropout)
        feature = conv.gat(ngw,
            feature,
            self.num_class,
            num_heads=1,
            activation=None,
            feat_drop=self.feat_dropout,
            attn_drop=self.attn_dropout,
            name="output")
        return feature


class APPNP(object):
    """Implement of APPNP"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.hidden_size = config.get("hidden_size", 64)
        self.dropout = config.get("dropout", 0.5)
        self.alpha = config.get("alpha", 0.1)
        self.k_hop = config.get("k_hop", 10)
        self.edge_dropout = config.get("edge_dropout", 0.0)

    def forward(self, graph_wrapper, feature, phase):
        # Edge dropout is only applied during training.
        if phase == "train":
            edge_dropout = self.edge_dropout
        else:
            edge_dropout = 0

        # MLP ("predict" step) ...
        for i in range(self.num_layers):
            feature = L.dropout(
                feature,
                self.dropout,
                dropout_implementation='upscale_in_train')
            feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)

        feature = L.dropout(
            feature,
            self.dropout,
            dropout_implementation='upscale_in_train')

        feature = L.fc(feature, self.num_class, act=None, name="output")

        # ... followed by k_hop steps of personalized-PageRank propagation.
        feature = conv.appnp(graph_wrapper,
            feature=feature,
            edge_dropout=edge_dropout,
            alpha=self.alpha,
            k_hop=self.k_hop)
        return feature


class SGC(object):
    """Implement of SGC"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)

    def forward(self, graph_wrapper, feature, phase):
        # K-step feature propagation (APPNP with alpha=0) followed by a linear classifier.
        feature = conv.appnp(graph_wrapper,
            feature=feature,
            edge_dropout=0,
            alpha=0,
            k_hop=self.num_layers)
        # The propagated features are fixed; only the linear layer is trained.
        feature.stop_gradient = True
        feature = L.fc(feature, self.num_class, act=None, bias_attr=False, name="output")
        return feature


class GCNII(object):
    """Implement of GCNII"""

    def __init__(self, config, num_class):
        self.num_class = num_class
        self.num_layers = config.get("num_layers", 1)
        self.hidden_size = config.get("hidden_size", 64)
        self.dropout = config.get("dropout", 0.6)
        self.alpha = config.get("alpha", 0.1)
        self.lambda_l = config.get("lambda_l", 0.5)
        self.k_hop = config.get("k_hop", 64)
        self.edge_dropout = config.get("edge_dropout", 0.0)

    def forward(self, graph_wrapper, feature, phase):
        # Edge dropout is only applied during training.
        if phase == "train":
            edge_dropout = self.edge_dropout
        else:
            edge_dropout = 0

        for i in range(self.num_layers):
            feature = L.fc(feature, self.hidden_size, act="relu", name="lin%s" % i)
            feature = L.dropout(
                feature,
                self.dropout,
                dropout_implementation='upscale_in_train')

        feature = conv.gcnii(graph_wrapper,
            feature=feature,
            name="gcnii",
            activation="relu",
            lambda_l=self.lambda_l,
            alpha=self.alpha,
            dropout=self.dropout,
            k_hop=self.k_hop)

        feature = L.fc(feature, self.num_class, act=None, name="output")
        return feature
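For reference, the propagation that `conv.appnp` is built around (judging from its `alpha` and `k_hop` arguments) is the personalized-PageRank scheme from the APPNP paper, applied after the MLP in the `APPNP` class above:

$$
H^{(0)} = f_\theta(X), \qquad
H^{(k+1)} = (1-\alpha)\,\hat{A}\,H^{(k)} + \alpha\,H^{(0)}, \qquad
\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}
$$

With $\alpha = 0$ the iteration collapses to $H^{(K)} = \hat{A}^{K} X$, which is exactly the fixed propagation the `SGC` class computes (under `stop_gradient`) before its single linear layer.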
