Skip to content

Commit d38731d

Browse files
committed
graph_kernel: add doc
1 parent 4c3b994 commit d38731d

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

pgl/graph_kernel.pyx

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,22 @@ def sample_subset_with_eid(list nids, list eids, long long maxdegree, shuffle=Fa
253253

254254
@cython.boundscheck(False)
255255
@cython.wraparound(False)
256-
def skip_gram_gen_pair(vector[long long] walk, long win_size=5):
256+
def skip_gram_gen_pair(vector[long long] walk_path, long win_size=5):
257+
"""Return node paris generated by skip-gram algorithm.
258+
259+
This function will auto remove the pair which src node is the same
260+
as dst node.
261+
262+
Args:
263+
walk_path: List of nodes as a walk path.
264+
win_size: the windows size used in skip-gram.
265+
266+
Return:
267+
A tuple of (src node list, dst node list).
268+
"""
257269
cdef vector[long long] src
258270
cdef vector[long long] dst
259-
cdef long long l = len(walk)
271+
cdef long long l = len(walk_path)
260272
cdef long long real_win_size, left, right, i
261273
cdef np.ndarray[np.int64_t, ndim=1] rnd = np.random.randint(1, win_size+1,
262274
dtype=np.int64, size=l)
@@ -270,15 +282,23 @@ def skip_gram_gen_pair(vector[long long] walk, long win_size=5):
270282
if right >= l:
271283
right = l - 1
272284
for j in xrange(left, right+1):
273-
if walk[i] == walk[j]:
285+
if walk_path[i] == walk_path[j]:
274286
continue
275-
src.push_back(walk[i])
276-
dst.push_back(walk[j])
287+
src.push_back(walk_path[i])
288+
dst.push_back(walk_path[j])
277289
return src, dst
278290

279291
@cython.boundscheck(False)
280292
@cython.wraparound(False)
281293
def alias_sample_build_table(np.ndarray[np.float64_t, ndim=1] probs):
294+
"""Return the alias table and event table for alias sampling.
295+
296+
Args:
297+
porobs: A list of float numbers as the probability.
298+
299+
Return:
300+
A tuple of (alias table, event table).
301+
"""
282302
cdef long long l = len(probs)
283303
cdef np.ndarray[np.float64_t, ndim=1] alias = probs * l
284304
cdef np.ndarray[np.int64_t, ndim=1] events = np.zeros(l, dtype=np.int64)

0 commit comments

Comments
 (0)