Skip to content

Commit ef30d35

Browse files
authored
Merge pull request #274 from PaddlePaddle/liwb
PGL Distributed Graph Engine Demo
2 parents 0203099 + 2d7fccd commit ef30d35

22 files changed

+2361
-9
lines changed

.gitignore

+7-1
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,14 @@ coverage.xml
5757

5858
# Sphinx documentation
5959
/docs/_build/
60+
docs/build
6061

6162
# tutorials: jupyter log
6263
tutorials/.ipynb_checkpoints/
6364
.ipynb_checkpoints
64-
*.ipynb
65+
**checkpoints
66+
**logs
67+
**outputs
68+
pgl/tests/local*
69+
jupyter
70+
.gitignore.bak

README.md

+4
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ Leaderboards can be found [here](https://ogb.stanford.edu/kddcup2021/results/).
2323
- Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification, to appear in **IJCAI2021**.
2424
- HGAMN: Heterogeneous Graph Attention Matching Network for Multilingual POI Retrieval at Baidu Maps, to appear in **KDD2021**.
2525

26+
**PGL Dstributed Graph Engine API released!!**
27+
28+
- Our Dstributed Graph Engine API has been released and we developed a [tutorial](./tutorial/working_with_distributed_graph_engine.ipynb) to show how to launch a graph engine and a [demo](./examples/metapath2vec) for training model using graph engine.
29+
2630

2731
PGL v2.1 2021.02.02
2832

examples/metapath2vec/README.md

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# metapath2vec: Scalable Representation Learning for Heterogeneous Networks
2+
3+
[metapath2vec](https://ericdongyx.github.io/papers/KDD17-dong-chawla-swami-metapath2vec.pdf) is a algorithm framework for representation learning in heterogeneous networks which contains multiple types of nodes and links. Given a heterogeneous graph, metapath2vec algorithm first generates meta-path-based random walks and then use skipgram model to train a language model. Based on PGL, we reproduce metapath2vec algorithm using PGL graph engine for scalable representation learning.
4+
5+
6+
## Dependencies
7+
8+
- paddlepaddle>=2.1.0
9+
10+
- pgl>=2.1.4
11+
12+
- OpenMPI==1.4.1
13+
14+
## Datasets
15+
16+
You can download datasets from [here](https://ericdongyx.github.io/metapath2vec/m2v.html).
17+
18+
We use the "aminer" data for example. After downloading the aminer data, put them, let's say, in `./data/net_aminer/`. We also need to move the `label/` directory to `./data/` directory.
19+
20+
## Data preprocessing
21+
22+
After downloading the dataset, run the folowing command to preprocess the data:
23+
24+
```
25+
python data_preprocess.py --config config.yaml
26+
```
27+
28+
## Hyperparameters
29+
30+
All the hyper parameters are saved in `config.yaml` file. So before training, you can open the `config.yaml` to modify the hyper parameters as you like.
31+
32+
## PGL Graph Engine Launching
33+
34+
Now we support distributed loading graph data using **PGL Graph Engine**. We also develop a simple tutorial to show how to launch a graph engine, please refer to [here](../../tutorials/working_with_distributed_graph_engine.ipynb).
35+
36+
To launch a distributed graph service, please follow the steps below.
37+
38+
### IP address setting
39+
40+
The first step is to set the IP list for each graph server. Each IP address with port represents a server. In `ip_list.txt` file, we set up 4 ip addresses as follow for demo:
41+
42+
```
43+
127.0.0.1:8553
44+
127.0.0.1:8554
45+
127.0.0.1:8555
46+
127.0.0.1:8556
47+
```
48+
49+
### Launching Graph Engine by OpenMPI
50+
51+
Before launching the graph engine, you should set up the below hyper-parameters in `config.yaml`:
52+
53+
```
54+
etype2files: "p2a:./graph_data/paper2author_edges.txt,p2c:./graph_data/paper2conf_edges.txt"
55+
ntype2files: "p:./graph_data/node_types.txt,a:./graph_data/node_types.txt,c:./graph_data/node_types.txt"
56+
symmetry: True
57+
shard_num: 100
58+
```
59+
60+
Then, we can launch the graph engine with the help of OpenMPI.
61+
62+
```
63+
mpirun -np 4 python -m pgl.distributed.launch --ip_config ./ip_list.txt --conf ./config.yaml --mode mpi --shard_num 100
64+
```
65+
66+
### Launching Graph Engine manually
67+
68+
If you didn't install OpenMPI, you can launch the graph engine manually.
69+
70+
Fox example, if we want to use 4 servers, we should run the following command separately on 4 terminals.
71+
72+
```
73+
# terminal 3
74+
python -m pgl.distributed.launch --ip_config ./ip_list.txt --conf ./config.yaml --shard_num 100 --server_id 3
75+
76+
# terminal 2
77+
python -m pgl.distributed.launch --ip_config ./ip_list.txt --conf ./config.yaml --shard_num 100 --server_id 2
78+
79+
# terminal 1
80+
python -m pgl.distributed.launch --ip_config ./ip_list.txt --conf ./config.yaml --shard_num 100 --server_id 1
81+
82+
# terminal 0
83+
python -m pgl.distributed.launch --ip_config ./ip_list.txt --conf ./config.yaml --shard_num 100 --server_id 0
84+
```
85+
86+
Note that the `server_id` of 0 should be the last one to be launched.
87+
88+
89+
## Training
90+
91+
After successfully launching the graph engine, you can run the below command to train the model.
92+
93+
```
94+
export CUDA_VISIBLE_DEVICES=0
95+
python train.py --config ./config.yaml --ip ./ip_list.txt
96+
```
97+
98+
Note that the trained model will be saved `./ckpt_custom/$task_name/`

examples/metapath2vec/config.yaml

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
task_name: distributed_metapath2vec
2+
3+
# ---------------------------数据配置-------------------------------------------------#
4+
# for data preprocessing
5+
data_path: ./data/net_aminer
6+
author_label_file: ./data/label/googlescholar.8area.author.label.txt
7+
venue_label_file: ./data/label/googlescholar.8area.venue.label.txt
8+
processed_path: ./graph_data
9+
10+
# for pgl graph engine
11+
etype2files: "p2a:./graph_data/paper2author_edges.txt,p2c:./graph_data/paper2conf_edges.txt"
12+
ntype2files: "p:./graph_data/node_types.txt,a:./graph_data/node_types.txt,c:./graph_data/node_types.txt"
13+
symmetry: True
14+
meta_path: "c2p-p2a-a2p-p2c"
15+
first_node_type: "c"
16+
17+
shard_num: 100
18+
19+
walk_len: 24
20+
win_size: 3
21+
neg_num: 5
22+
walk_times: 20
23+
24+
25+
# ---------------------------模型参数配置---------------------------------------------#
26+
model_type: SkipGramModel
27+
warm_start_from: null
28+
num_nodes: 5000000
29+
embed_size: 64
30+
sparse_embed: False
31+
32+
# ---------------------------训练参数配置---------------------------------------------#
33+
epochs: 1
34+
num_workers: 4
35+
lr: 0.001
36+
lazy_mode: False
37+
batch_node_size: 200
38+
batch_pair_size: 1000
39+
pair_stream_shuffle_size: 100000
40+
log_dir: ./logs
41+
output_dir: ./outputs
42+
save_dir: ./checkpoints
43+
log_steps: 1000
+157
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Data pre-processing for metapath2vec model.
16+
"""
17+
18+
import os
19+
import sys
20+
import tqdm
21+
import time
22+
import logging
23+
import random
24+
import argparse
25+
import numpy as np
26+
import pickle as pkl
27+
28+
from pgl.utils.logger import log
29+
from utils.config import prepare_config, make_dir
30+
31+
# name ID g_index
32+
33+
34+
def remapping_id(file_, start_index, node_type, separator="\t"):
35+
"""Mapp the ID and name of nodes to index.
36+
"""
37+
node_types = []
38+
id2index = {}
39+
name2index = {}
40+
index = start_index
41+
with open(file_, encoding="ISO-8859-1") as reader:
42+
for line in reader:
43+
tokens = line.strip().split(separator)
44+
id2index[tokens[0]] = str(index)
45+
if len(tokens) == 2:
46+
name2index[tokens[1]] = str(index)
47+
node_types.append((str(index), node_type))
48+
index += 1
49+
50+
return id2index, name2index, node_types
51+
52+
53+
def load_edges(file_, src2index, dst2index, symmetry=False):
54+
"""Load edges from file.
55+
"""
56+
edges = []
57+
with open(file_, 'r') as reader:
58+
for line in reader:
59+
items = line.strip().split()
60+
src, dst = src2index[items[0]], dst2index[items[1]]
61+
edges.append((src, dst))
62+
if symmetry:
63+
edges.append((dst, src))
64+
edges = list(set(edges))
65+
return edges
66+
67+
68+
def load_label(file_, name2index):
69+
index_label = []
70+
with open(file_, encoding="ISO-8859-1") as reader:
71+
for line in reader:
72+
tokens = line.strip().split(' ')
73+
name, label = tokens[0], int(tokens[1]) - 1
74+
if name in name2index:
75+
index_label.append((name2index[name], str(label)))
76+
77+
return index_label
78+
79+
80+
def main(config):
81+
conf_id2index, conf_name2index, conf_node_type = remapping_id(
82+
os.path.join(config.data_path, 'id_conf.txt'),
83+
start_index=0,
84+
node_type='c')
85+
log.info('%d venues have been loaded.' % (len(conf_id2index)))
86+
87+
author_id2index, author_name2index, author_node_type = remapping_id(
88+
os.path.join(config.data_path, 'id_author.txt'),
89+
start_index=len(conf_id2index),
90+
node_type='a')
91+
log.info('%d authors have been loaded.' % (len(author_id2index)))
92+
93+
paper_id2index, paper_name2index, paper_node_type = remapping_id(
94+
os.path.join(config.data_path, 'paper.txt'),
95+
start_index=(len(conf_id2index) + len(author_id2index)),
96+
node_type='p',
97+
separator='\t')
98+
log.info('%d papers have been loaded.' % (len(paper_id2index)))
99+
100+
node_types = conf_node_type + author_node_type + paper_node_type
101+
102+
paper2author_edges = load_edges(
103+
os.path.join(config.data_path, 'paper_author.txt'), paper_id2index,
104+
author_id2index)
105+
log.info('%d paper2author edges have been loaded.' %
106+
(len(paper2author_edges)))
107+
108+
paper2conf_edges = load_edges(
109+
os.path.join(config.data_path, 'paper_conf.txt'), paper_id2index,
110+
conf_id2index)
111+
log.info('%d paper2conf edges have been loaded.' % (len(paper2conf_edges)))
112+
113+
author_label = load_label(config.author_label_file, author_name2index)
114+
conf_label = load_label(config.venue_label_file, conf_name2index)
115+
116+
make_dir(config.processed_path)
117+
node_types_file = os.path.join(config.processed_path, 'node_types.txt')
118+
log.info("saving node_types to %s" % node_types_file)
119+
with open(node_types_file, 'w') as writer:
120+
for item in tqdm.tqdm(node_types):
121+
writer.write("%s\t%s\n" % (item[1], item[0]))
122+
123+
p2a_edges_file = os.path.join(config.processed_path,
124+
'paper2author_edges.txt')
125+
log.info("saving paper2author edges to %s" % p2a_edges_file)
126+
with open(p2a_edges_file, 'w') as writer:
127+
for item in tqdm.tqdm(paper2author_edges):
128+
writer.write("\t".join(item) + "\n")
129+
130+
p2c_edges_file = os.path.join(config.processed_path,
131+
'paper2conf_edges.txt')
132+
log.info("saving paper2conf edges to %s" % p2c_edges_file)
133+
with open(p2c_edges_file, 'w') as writer:
134+
for item in tqdm.tqdm(paper2conf_edges):
135+
writer.write("\t".join(item) + "\n")
136+
137+
author_label_file = os.path.join(config.processed_path, 'author_label.txt')
138+
log.info("saving author label to %s" % author_label_file)
139+
with open(author_label_file, 'w') as writer:
140+
for item in tqdm.tqdm(author_label):
141+
writer.write("\t".join(item) + "\n")
142+
143+
conf_label_file = os.path.join(config.processed_path, 'conf_label.txt')
144+
log.info("saving conf label to %s" % conf_label_file)
145+
with open(conf_label_file, 'w') as writer:
146+
for item in tqdm.tqdm(conf_label):
147+
writer.write("\t".join(item) + "\n")
148+
149+
150+
if __name__ == "__main__":
151+
parser = argparse.ArgumentParser(description='metapath2vec')
152+
parser.add_argument('--config', default="./config.yaml", type=str)
153+
args = parser.parse_args()
154+
155+
config = prepare_config(args.config)
156+
157+
main(config)
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from datasets import dataset
16+
from datasets.dataset import *
17+
18+
__all__ = []
19+
__all__ += dataset.__all__

0 commit comments

Comments
 (0)