Skip to content

Commit 996f7a9

Browse files
authored
Merge pull request #3 from PaddlePaddle/main
update
2 parents 2c9ead8 + 8743242 commit 996f7a9

File tree

5 files changed

+144
-18
lines changed

5 files changed

+144
-18
lines changed

examples/SAGPool/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,7 @@ def main(args, train_dataset, val_dataset, test_dataset):
124124
break
125125

126126
correct = 0.
127-
new_test_program = fluid.Program()
128-
fluid.load(new_test_program, "./save/%s/%s" \
127+
fluid.load(test_program, "./save/%s/%s" \
129128
% (args.dataset_name, args.save_model), exe)
130129
for feed_dict in test_loader:
131130
correct_ = exe.run(test_program,

ogb_examples/nodeproppred/ogbn-arxiv/dataloader/ogbn_arxiv_dataloader.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
ssl._create_default_https_context = ssl._create_unverified_context
2525

2626
from pgl.contrib.ogb.nodeproppred.dataset_pgl import PglNodePropPredDataset
27-
#from pgl.sample import graph_saint_random_walk_sample
27+
from pgl.sample import graph_saint_random_walk_sample
2828
from ogb.nodeproppred import Evaluator
2929
import tqdm
3030
from collections import namedtuple
@@ -78,10 +78,10 @@ def k_hop_sampler(graph, samples, batch_nodes):
7878
return subgraph, sub_node_index
7979

8080

81-
#def graph_saint_randomwalk_sampler(graph, batch_nodes, max_depth=3):
82-
# subgraph = graph_saint_random_walk_sample(graph, batch_nodes, max_depth)
83-
# sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
84-
# return subgraph, sub_node_index
81+
def graph_saint_randomwalk_sampler(graph, batch_nodes, max_depth=3):
82+
subgraph = graph_saint_random_walk_sample(graph, batch_nodes, max_depth)
83+
sub_node_index = subgraph.reindex_from_parrent_nodes(batch_nodes)
84+
return subgraph, sub_node_index
8585

8686

8787
class ArxivDataGenerator(BaseDataGenerator):

pgl/graph_kernel.pyx

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,3 +321,43 @@ def alias_sample_build_table(np.ndarray[np.float64_t, ndim=1] probs):
321321
if alias[l_i] < 1:
322322
smaller_num.push_back(l_i)
323323
return alias, events
324+
325+
@cython.boundscheck(False)
326+
@cython.wraparound(False)
327+
def extract_edges_from_nodes(
328+
np.ndarray[np.int64_t, ndim=1] adj_indptr,
329+
np.ndarray[np.int64_t, ndim=1] sorted_v,
330+
np.ndarray[np.int64_t, ndim=1] sorted_eid,
331+
vector[long long] sampled_nodes,
332+
):
333+
"""
334+
Extract all eids of given sampled_nodes for the origin graph.
335+
ret_edge_index: edge ids between sampled_nodes.
336+
337+
Refers: https://github.com/GraphSAINT/GraphSAINT
338+
"""
339+
cdef long long i, v, j
340+
cdef long long num_v_orig, num_v_sub
341+
cdef long long start_neigh, end_neigh
342+
cdef vector[int] _arr_bit
343+
cdef vector[long long] ret_edge_index
344+
num_v_orig = adj_indptr.size-1
345+
_arr_bit = vector[int](num_v_orig,-1)
346+
num_v_sub = sampled_nodes.size()
347+
i = 0
348+
with nogil:
349+
while i < num_v_sub:
350+
_arr_bit[sampled_nodes[i]] = i
351+
i = i + 1
352+
i = 0
353+
while i < num_v_sub:
354+
v = sampled_nodes[i]
355+
start_neigh = adj_indptr[v]
356+
end_neigh = adj_indptr[v+1]
357+
j = start_neigh
358+
while j < end_neigh:
359+
if _arr_bit[sorted_v[j]] > -1:
360+
ret_edge_index.push_back(sorted_eid[j])
361+
j = j + 1
362+
i = i + 1
363+
return ret_edge_index

pgl/sample.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
__all__ = [
2626
'graphsage_sample', 'node2vec_sample', 'deepwalk_sample',
27-
'metapath_randomwalk', 'pinsage_sample'
27+
'metapath_randomwalk', 'pinsage_sample', 'graph_saint_random_walk_sample'
2828
]
2929

3030

@@ -55,15 +55,15 @@ def edge_hash(src, dst):
5555

5656
def graphsage_sample(graph, nodes, samples, ignore_edges=[]):
5757
"""Implement of graphsage sample.
58-
58+
5959
Reference paper: https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf.
6060
6161
Args:
6262
graph: A pgl graph instance
6363
nodes: Sample starting from nodes
6464
samples: A list, number of neighbors in each layer
6565
ignore_edges: list of edge(src, dst) will be ignored.
66-
66+
6767
Return:
6868
A list of subgraphs
6969
"""
@@ -129,7 +129,7 @@ def alias_sample(size, alias, events):
129129
size: Output shape.
130130
alias: The alias table build by `alias_sample_build_table`.
131131
events: The events table build by `alias_sample_build_table`.
132-
132+
133133
Return:
134134
samples: The generated random samples.
135135
"""
@@ -283,13 +283,13 @@ def metapath_randomwalk(graph,
283283
Args:
284284
graph: instance of pgl heterogeneous graph
285285
start_nodes: start nodes to generate walk
286-
metapath: meta path for sample nodes.
286+
metapath: meta path for sample nodes.
287287
e.g: "c2p-p2a-a2p-p2c"
288288
walk_length: the walk length
289289
290290
Return:
291-
a list of metapath walks.
292-
291+
a list of metapath walks.
292+
293293
"""
294294

295295
edge_types = metapath.split('-')
@@ -390,18 +390,18 @@ def pinsage_sample(graph,
390390
norm_bais=1.0,
391391
ignore_edges=set()):
392392
"""Implement of graphsage sample.
393-
393+
394394
Reference paper: .
395395
396396
Args:
397397
graph: A pgl graph instance
398398
nodes: Sample starting from nodes
399399
samples: A list, number of neighbors in each layer
400-
top_k: select the top_k visit count nodes to construct the edges
401-
proba: the probability to return the origin node
400+
top_k: select the top_k visit count nodes to construct the edges
401+
proba: the probability to return the origin node
402402
norm_bais: the normlization for the visit count
403403
ignore_edges: list of edge(src, dst) will be ignored.
404-
404+
405405
Return:
406406
A list of subgraphs
407407
"""
@@ -476,3 +476,43 @@ def pinsage_sample(graph,
476476
layer_nodes[0], dtype="int64")
477477

478478
return subgraphs
479+
480+
481+
def extract_edges_from_nodes(graph, sample_nodes):
482+
eids = graph_kernel.extract_edges_from_nodes(
483+
graph.adj_src_index._indptr, graph.adj_src_index._sorted_v,
484+
graph.adj_src_index._sorted_eid, sample_nodes)
485+
return eids
486+
487+
488+
def graph_saint_random_walk_sample(graph,
489+
nodes,
490+
max_depth,
491+
alias_name=None,
492+
events_name=None):
493+
"""Implement of graph saint random walk sample.
494+
495+
First, this function will get random walks path for given nodes and depth.
496+
Then, it will create subgraph from all sampled nodes.
497+
498+
Reference Paper: https://arxiv.org/abs/1907.04931
499+
500+
Args:
501+
graph: A pgl graph instance
502+
nodes: Walk starting from nodes
503+
max_depth: Max walking depth
504+
505+
Return:
506+
a subgraph of sampled nodes.
507+
"""
508+
graph.outdegree()
509+
walks = deepwalk_sample(graph, nodes, max_depth, alias_name, events_name)
510+
sample_nodes = []
511+
for walk in walks:
512+
sample_nodes.extend(walk)
513+
sample_nodes = np.unique(sample_nodes)
514+
eids = extract_edges_from_nodes(graph, sample_nodes)
515+
subgraph = graph.subgraph(
516+
nodes=sample_nodes, eid=eids, with_node_feat=True, with_edge_feat=True)
517+
subgraph.node_feat["index"] = np.array(sample_nodes, dtype="int64")
518+
return subgraph

pgl/tests/test_graph_saint_sample.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""graph saint sample test
15+
"""
16+
from __future__ import division
17+
from __future__ import absolute_import
18+
from __future__ import print_function
19+
from __future__ import unicode_literals
20+
import unittest
21+
import numpy as np
22+
23+
import pgl
24+
import paddle.fluid as fluid
25+
from pgl.sample import graph_saint_random_walk_sample
26+
27+
28+
class GraphSaintSampleTest(unittest.TestCase):
29+
"""GraphSaintSampleTest"""
30+
31+
def test_randomwalk_sampler(self):
32+
"""test_randomwalk_sampler"""
33+
g = pgl.graph.Graph(
34+
num_nodes=8,
35+
edges=[(1, 2), (2, 3), (0, 2), (0, 1), (6, 7), (4, 5), (6, 4),
36+
(7, 4), (3, 4)])
37+
subgraph = graph_saint_random_walk_sample(g, [6, 7], 2)
38+
print('reindex', subgraph._from_reindex)
39+
print('subedges', subgraph.edges)
40+
assert len(subgraph.nodes) == 4
41+
assert len(subgraph.edges) == 4
42+
true_edges = np.array([[0, 1], [2, 3], [2, 0], [3, 0]])
43+
assert "{}".format(subgraph.edges) == "{}".format(true_edges)
44+
45+
46+
if __name__ == '__main__':
47+
unittest.main()

0 commit comments

Comments
 (0)