Commit 03cb362

Merge pull request #4 from PaddlePaddle/develop
Develop
2 parents b505f7d + 2a8223a commit 03cb362

60 files changed: +5622 -657 lines changed

examples/GATNE/Dataset.py

+1 -1
```diff
@@ -21,7 +21,7 @@
 import numpy as np
 import logging
 import random
-from pgl.contrib import heter_graph
+from pgl import heter_graph
 import pickle as pkl
 
 
```
examples/GATNE/model.py

+1 -1
```diff
@@ -21,7 +21,7 @@
 
 import paddle.fluid as fluid
 import paddle.fluid.layers as fl
-from pgl.contrib import heter_graph_wrapper
+from pgl import heter_graph_wrapper
 
 
 class GATNE(object):
```
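Both GATNE files make the same change: `heter_graph` and `heter_graph_wrapper` move from the `pgl.contrib` namespace to the top-level `pgl` package. If a script has to run against both older and newer PGL releases, one option is a small compatibility shim; the sketch below is illustrative only and is not part of this commit.

```python
# Hypothetical compatibility shim (not part of this commit): prefer the new
# top-level imports and fall back to pgl.contrib for older PGL releases.
try:
    from pgl import heter_graph, heter_graph_wrapper
except ImportError:
    from pgl.contrib import heter_graph, heter_graph_wrapper
```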
+43
@@ -0,0 +1,43 @@
# Distributed metapath2vec in PGL

[metapath2vec](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) is an algorithmic framework for representation learning in heterogeneous networks, which contain multiple types of nodes and links. Given a heterogeneous graph, metapath2vec first generates meta-path-based random walks and then uses the skipgram model to train a language model over those walks. Based on PGL, we reproduce the metapath2vec algorithm in distributed mode.
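To make the two stages concrete, here is a minimal, self-contained sketch of a meta-path-guided walk and the (center, context) pair generation that feeds skipgram training. It is illustrative only and does not use the PGL API; the toy graph and function names are made up, while the meta path, window size, and starting node type mirror the defaults in the `config.yaml` shown further below.

```python
import random

# Toy heterogeneous graph: node -> list of (neighbor, edge_type).
# Edge types follow the "c2p-p2a-a2p-p2c" meta path scheme from config.yaml.
graph = {
    "c1": [("p1", "c2p"), ("p2", "c2p")],
    "p1": [("a1", "p2a"), ("c1", "p2c")],
    "p2": [("a1", "p2a"), ("a2", "p2a"), ("c1", "p2c")],
    "a1": [("p1", "a2p"), ("p2", "a2p")],
    "a2": [("p2", "a2p")],
}

def metapath_walk(start, meta_path, walk_len):
    """Walk from `start`, cycling through the edge types of the meta path."""
    etypes = meta_path.split("-")
    walk = [start]
    for i in range(walk_len - 1):
        etype = etypes[i % len(etypes)]
        candidates = [nbr for nbr, t in graph[walk[-1]] if t == etype]
        if not candidates:
            break
        walk.append(random.choice(candidates))
    return walk

def skipgram_pairs(walk, win_size):
    """Yield (center, context) pairs within a sliding window over the walk."""
    for i, center in enumerate(walk):
        for j in range(max(0, i - win_size), min(len(walk), i + win_size + 1)):
            if j != i:
                yield center, walk[j]

walk = metapath_walk("c1", "c2p-p2a-a2p-p2c", walk_len=8)
print(walk)                                   # e.g. ['c1', 'p2', 'a1', 'p1', 'c1', ...]
print(list(skipgram_pairs(walk, win_size=2))[:5])
```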

## Datasets
DBLP: The dataset contains 14376 papers (P), 20 conferences (C), 14475 authors (A), and 8920 terms (T). There are 33791 nodes in this dataset.
You can download the dataset from [here](https://github.com/librahu/HIN-Datasets-for-Recommendation-and-Network-Embedding).

We use the ```DBLP``` dataset as an example. After downloading it, put it, let's say, in ```./data/DBLP/```.

## Dependencies
- paddlepaddle>=1.6
- pgl>=1.0.0

## How to run
Before training, run the command below to preprocess the data.
```sh
python data_process.py --data_path ./data/DBLP --output_path ./data/data_processed
```

We adopt [PaddlePaddle Fleet](https://github.com/PaddlePaddle/Fleet) as our distributed training framework. ```config.yaml``` is the configuration file for the metapath2vec hyperparameters, and ```local_config``` is the configuration file for the PaddlePaddle parameter servers. By default, we use 2 pservers and 2 trainers. You can use ```cloud_run.sh``` to start up the parameter servers and model trainers.

For example, to train metapath2vec in distributed mode on the DBLP dataset:
```sh
# train metapath2vec in distributed mode.
sh cloud_run.sh

# multiclass task example
python multi_class.py --dataset ./data/data_processed/author_label.txt --ckpt_path ./checkpoints/2000 --num_nodes 33791
```

## Hyperparameters
All the hyperparameters are saved in the ```config.yaml``` file, so before training you can open ```config.yaml``` and modify them as you like (a sketch of how such a file can be loaded follows the list below).

Some important hyperparameters in config.yaml:
- **edge_path**: the directory of the graph data that you want to load
- **lr**: learning rate
- **neg_num**: number of negative samples
- **num_walks**: number of walks started from each node
- **walk_len**: walk length
- **meta_path**: meta path scheme
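`cluster_train.py` (shown later in this commit) reads this file through `utils.load_config` and then accesses the values as attributes (e.g. `config.num_nodes`). The loader itself is not part of this diff, so the following is only a sketch under the assumption that it behaves roughly like a plain YAML load wrapped in an attribute-style namespace.

```python
# Sketch of a config loader; the repo's utils.load_config is not shown in this
# diff, so this is an assumption about its behavior, not the actual code.
import yaml                                  # requires PyYAML
from types import SimpleNamespace

def load_config_sketch(path):
    with open(path) as f:
        raw = yaml.safe_load(f)              # plain dict of hyperparameters
    return SimpleNamespace(**raw)            # allows config.num_nodes, config.lr, ...

config = load_config_sketch("./config.yaml")
print(config.num_nodes, config.lr, config.meta_path)
```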
@@ -0,0 +1,25 @@
```sh
#!/bin/bash
set -x
mode=${1}

source ./utils.sh
unset http_proxy https_proxy

source ./local_config
if [ ! -d ${log_dir} ]; then
    mkdir ${log_dir}
fi

for((i=0;i<${PADDLE_PSERVERS_NUM};i++))
do
    echo "start ps server: ${i}"
    echo $log_dir
    TRAINING_ROLE="PSERVER" PADDLE_TRAINER_ID=${i} sh job.sh &> $log_dir/pserver.$i.log &
done
sleep 10s
for((j=0;j<${PADDLE_TRAINERS_NUM};j++))
do
    echo "start ps work: ${j}"
    TRAINING_ROLE="TRAINER" PADDLE_TRAINER_ID=${j} sh job.sh &> $log_dir/worker.$j.log &
done
tail -f $log_dir/worker.0.log
```
@@ -0,0 +1,197 @@
```python
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time
import os
import math
import numpy as np

import paddle.fluid as F
import paddle.fluid.layers as L
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from pgl.utils.logger import log

from model import Metapath2vecModel
from graph import m2vGraph
from utils import load_config
from walker import multiprocess_data_generator


def init_role():
    # reset the place according to role of parameter server
    training_role = os.getenv("TRAINING_ROLE", "TRAINER")
    paddle_role = role_maker.Role.WORKER
    place = F.CPUPlace()
    if training_role == "PSERVER":
        paddle_role = role_maker.Role.SERVER

    # set the fleet runtime environment according to configure
    ports = os.getenv("PADDLE_PORT", "6174").split(",")
    pserver_ips = os.getenv("PADDLE_PSERVERS").split(",")  # ip,ip...
    eplist = []
    if len(ports) > 1:
        # local debug mode, multi port
        for port in ports:
            eplist.append(':'.join([pserver_ips[0], port]))
    else:
        # distributed mode, multi ip
        for ip in pserver_ips:
            eplist.append(':'.join([ip, ports[0]]))

    pserver_endpoints = eplist  # ip:port,ip:port...
    worker_num = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
    role = role_maker.UserDefinedRoleMaker(
        current_id=trainer_id,
        role=paddle_role,
        worker_num=worker_num,
        server_endpoints=pserver_endpoints)
    fleet.init(role)


def optimization(base_lr, loss, train_steps, optimizer='sgd'):
    decayed_lr = L.learning_rate_scheduler.polynomial_decay(
        learning_rate=base_lr,
        decay_steps=train_steps,
        end_learning_rate=0.0001 * base_lr,
        power=1.0,
        cycle=False)
    if optimizer == 'sgd':
        optimizer = F.optimizer.SGD(decayed_lr)
    elif optimizer == 'adam':
        optimizer = F.optimizer.Adam(decayed_lr, lazy_mode=True)
    else:
        raise ValueError

    log.info('learning rate:%f' % (base_lr))
    # create the DistributeTranspiler configure
    config = DistributeTranspilerConfig()
    config.sync_mode = False
    #config.runtime_split_send_recv = False

    config.slice_var_up = False
    # create the distributed optimizer
    optimizer = fleet.distributed_optimizer(optimizer, config)
    optimizer.minimize(loss)


def build_complied_prog(train_program, model_loss):
    num_threads = int(os.getenv("CPU_NUM", 10))
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))
    exec_strategy = F.ExecutionStrategy()
    exec_strategy.num_threads = num_threads
    #exec_strategy.use_experimental_executor = True
    build_strategy = F.BuildStrategy()
    build_strategy.enable_inplace = True
    #build_strategy.memory_optimize = True
    build_strategy.memory_optimize = False
    build_strategy.remove_unnecessary_lock = False
    if num_threads > 1:
        build_strategy.reduce_strategy = F.BuildStrategy.ReduceStrategy.Reduce

    compiled_prog = F.compiler.CompiledProgram(
        train_program).with_data_parallel(loss_name=model_loss.name)
    return compiled_prog


def train_prog(exe, program, loss, node2vec_pyreader, args, train_steps):
    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
    step = 0
    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    while True:
        try:
            begin_time = time.time()
            loss_val, = exe.run(program, fetch_list=[loss])
            log.info("step %s: loss %.5f speed: %.5f s/step" %
                     (step, np.mean(loss_val), time.time() - begin_time))
            step += 1
        except F.core.EOFException:
            node2vec_pyreader.reset()

        if step % args.steps_per_save == 0 or step == train_steps:
            save_path = args.save_path
            if trainer_id == 0:
                model_path = os.path.join(save_path, "%s" % step)
                fleet.save_persistables(exe, model_path)

        if step == train_steps:
            break


def main(args):
    log.info("start")

    worker_num = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    num_devices = int(os.getenv("CPU_NUM", 10))

    model = Metapath2vecModel(config=args)
    pyreader = model.pyreader
    loss = model.forward()

    # init fleet
    init_role()

    train_steps = math.ceil(args.num_nodes * args.epochs / args.batch_size /
                            num_devices / worker_num)
    log.info("Train step: %s" % train_steps)

    real_batch_size = args.batch_size * args.walk_len * args.win_size
    if args.optimizer == "sgd":
        args.lr *= real_batch_size
    optimization(args.lr, loss, train_steps, args.optimizer)

    # init and run server or worker
    if fleet.is_server():
        fleet.init_server(args.warm_start_from_dir)
        fleet.run_server()

    if fleet.is_worker():
        log.info("start init worker done")
        fleet.init_worker()
        # just the worker, load the sample
        log.info("init worker done")

        exe = F.Executor(F.CPUPlace())
        exe.run(fleet.startup_program)
        log.info("Startup done")

        dataset = m2vGraph(args)
        log.info("Build graph done.")

        data_generator = multiprocess_data_generator(args, dataset)

        cur_time = time.time()
        for idx, _ in enumerate(data_generator()):
            log.info("iter %s: %s s" % (idx, time.time() - cur_time))
            cur_time = time.time()
            if idx == 100:
                break

        pyreader.decorate_tensor_provider(data_generator)
        pyreader.start()

        compiled_prog = build_complied_prog(fleet.main_program, loss)
        train_prog(exe, compiled_prog, loss, pyreader, args, train_steps)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='metapath2vec')
    parser.add_argument("-c", "--config", type=str, default="./config.yaml")
    args = parser.parse_args()
    config = load_config(args.config)
    log.info(config)
    main(config)
```
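To make the step arithmetic in `main` concrete, here is a quick worked example. The values come from the `config.yaml` below; the 2-trainer count is an assumption taken from the README's default setup.

```python
import math

# Values from config.yaml; worker_num = 2 assumes the README's default 2-trainer setup.
num_nodes, epochs, batch_size = 37791, 10, 4
num_devices, worker_num = 16, 2                       # CPU_NUM and PADDLE_TRAINERS_NUM
walk_len, win_size, lr = 24, 5, 1.0

train_steps = math.ceil(num_nodes * epochs / batch_size / num_devices / worker_num)
real_batch_size = batch_size * walk_len * win_size    # mirrors real_batch_size in main()

print(train_steps)           # 2953
print(real_batch_size)       # 480
print(lr * real_batch_size)  # 480.0 -- the SGD learning rate is scaled this way in main()
```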
@@ -0,0 +1,41 @@
```yaml
# graph data config
edge_path: "./data/data_processed"
edge_files: "p2a:paper_author.txt,p2c:paper_conference.txt,p2t:paper_type.txt"
node_types_file: "node_types.txt"
num_nodes: 37791
symmetry: True

# skipgram pair data config
win_size: 5
neg_num: 5
# average; m2v_plus
neg_sample_type: "average"

# random walk config
# m2v; multi_m2v;
walk_mode: "m2v"
meta_path: "c2p-p2a-a2p-p2c"
first_node_type: "c"
walk_len: 24
batch_size: 4
node_shuffle: True
node_files: null
num_sample_workers: 2

# model config
embed_dim: 64
is_sparse: True
# only used when num_nodes > 100,000,000; slower than normal embedding
is_distributed: False

# training config
epochs: 10
optimizer: "sgd"
lr: 1.0
warm_start_from_dir: null
walkpath_files: "None"
train_files: "None"
steps_per_save: 1000
save_path: "./checkpoints"
log_dir: "./logs"
CPU_NUM: 16
```
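The `edge_files` entry packs edge types and file names into a single string of comma-separated `type:filename` pairs. The actual parsing presumably happens in `m2vGraph` (`graph.py`, not included in this excerpt); the snippet below is only a sketch of the format, not the project's loader.

```python
# Sketch only: split the edge_files string from config.yaml into
# {edge_type: filename}; the real parsing lives in graph.py (not shown here).
edge_files = "p2a:paper_author.txt,p2c:paper_conference.txt,p2t:paper_type.txt"

edge_file_map = dict(item.split(":", 1) for item in edge_files.split(","))
print(edge_file_map)
# {'p2a': 'paper_author.txt', 'p2c': 'paper_conference.txt', 'p2t': 'paper_type.txt'}
```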
