Skip to content

Commit 6b4d97f

Browse files
authored
Merge pull request #48 from rusty1s/graph_saint
Graph saint
2 parents a1ae903 + a597b82 commit 6b4d97f

File tree

13 files changed

+378
-51
lines changed

13 files changed

+378
-51
lines changed

csrc/cpu/rw_cpu.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include "rw_cpu.h"
2+
3+
#include "utils.h"
4+
5+
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
6+
torch::Tensor start, int64_t walk_length) {
7+
CHECK_CPU(rowptr);
8+
CHECK_CPU(col);
9+
CHECK_CPU(start);
10+
11+
CHECK_INPUT(rowptr.dim() == 1);
12+
CHECK_INPUT(col.dim() == 1);
13+
CHECK_INPUT(start.dim() == 1);
14+
15+
auto rand = torch::rand({start.size(0), walk_length},
16+
start.options().dtype(torch::kFloat));
17+
18+
auto L = walk_length + 1;
19+
auto out = torch::full({start.size(0), L}, -1, start.options());
20+
21+
auto rowptr_data = rowptr.data_ptr<int64_t>();
22+
auto col_data = col.data_ptr<int64_t>();
23+
auto start_data = start.data_ptr<int64_t>();
24+
auto rand_data = rand.data_ptr<float>();
25+
auto out_data = out.data_ptr<int64_t>();
26+
27+
for (auto n = 0; n < start.size(0); n++) {
28+
auto cur = start_data[n];
29+
out_data[n * L] = cur;
30+
31+
int64_t row_start, row_end;
32+
for (auto l = 0; l < walk_length; l++) {
33+
row_start = rowptr_data[cur];
34+
row_end = rowptr_data[cur + 1];
35+
36+
cur = col_data[row_start + int64_t(rand_data[n * walk_length + l] *
37+
(row_end - row_start))];
38+
out_data[n * L + l + 1] = cur;
39+
}
40+
}
41+
42+
return out;
43+
}

csrc/cpu/rw_cpu.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
3+
#include <torch/extension.h>
4+
5+
torch::Tensor random_walk_cpu(torch::Tensor rowptr, torch::Tensor col,
6+
torch::Tensor start, int64_t walk_length);

csrc/cpu/saint_cpu.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "saint_cpu.h"
2+
3+
#include "utils.h"
4+
5+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
6+
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
7+
torch::Tensor col) {
8+
CHECK_CPU(idx);
9+
CHECK_CPU(rowptr);
10+
CHECK_CPU(col);
11+
12+
CHECK_INPUT(idx.dim() == 1);
13+
CHECK_INPUT(rowptr.dim() == 1);
14+
CHECK_INPUT(col.dim() == 1);
15+
16+
auto assoc = torch::full({rowptr.size(0) - 1}, -1, idx.options());
17+
assoc.index_copy_(0, idx, torch::arange(idx.size(0), idx.options()));
18+
19+
auto idx_data = idx.data_ptr<int64_t>();
20+
auto rowptr_data = rowptr.data_ptr<int64_t>();
21+
auto col_data = col.data_ptr<int64_t>();
22+
auto assoc_data = assoc.data_ptr<int64_t>();
23+
24+
std::vector<int64_t> rows, cols, indices;
25+
26+
int64_t v, w, w_new, row_start, row_end;
27+
for (int64_t v_new = 0; v_new < idx.size(0); v_new++) {
28+
v = idx_data[v_new];
29+
row_start = rowptr_data[v];
30+
row_end = rowptr_data[v + 1];
31+
32+
for (int64_t j = row_start; j < row_end; j++) {
33+
w = col_data[j];
34+
w_new = assoc_data[w];
35+
if (w_new > -1) {
36+
rows.push_back(v_new);
37+
cols.push_back(w_new);
38+
indices.push_back(j);
39+
}
40+
}
41+
}
42+
43+
int64_t length = rows.size();
44+
row = torch::from_blob(rows.data(), {length}, row.options()).clone();
45+
col = torch::from_blob(cols.data(), {length}, row.options()).clone();
46+
idx = torch::from_blob(indices.data(), {length}, row.options()).clone();
47+
48+
return std::make_tuple(row, col, idx);
49+
}

csrc/cpu/saint_cpu.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#pragma once
2+
3+
#include <torch/extension.h>
4+
5+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
6+
subgraph_cpu(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
7+
torch::Tensor col);

csrc/cuda/rw_cuda.cu

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
#include "rw_cuda.h"
2+
3+
#include <ATen/cuda/CUDAContext.h>
4+
5+
#include "utils.cuh"
6+
7+
#define THREADS 1024
8+
#define BLOCKS(N) (N + THREADS - 1) / THREADS
9+
10+
__global__ void uniform_random_walk_kernel(const int64_t *rowptr,
11+
const int64_t *col,
12+
const int64_t *start,
13+
const float *rand, int64_t *out,
14+
int64_t walk_length, int64_t numel) {
15+
const int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
16+
17+
if (thread_idx < numel) {
18+
int64_t cur = start[thread_idx];
19+
out[thread_idx] = cur;
20+
21+
int64_t row_start, row_end;
22+
for (int64_t l = 0; l < walk_length; l++) {
23+
row_start = rowptr[cur], row_end = rowptr[cur + 1];
24+
cur = col[row_start +
25+
int64_t(rand[l * numel + thread_idx] * (row_end - row_start))];
26+
out[(l + 1) * numel + thread_idx] = cur;
27+
}
28+
}
29+
}
30+
31+
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
32+
torch::Tensor start, int64_t walk_length) {
33+
CHECK_CUDA(rowptr);
34+
CHECK_CUDA(col);
35+
CHECK_CUDA(start);
36+
cudaSetDevice(rowptr.get_device());
37+
38+
CHECK_INPUT(rowptr.dim() == 1);
39+
CHECK_INPUT(col.dim() == 1);
40+
CHECK_INPUT(start.dim() == 1);
41+
42+
auto rand = torch::rand({walk_length, start.size(0)},
43+
start.options().dtype(torch::kFloat));
44+
auto out = torch::full({walk_length + 1, start.size(0)}, -1, start.options());
45+
46+
auto stream = at::cuda::getCurrentCUDAStream();
47+
uniform_random_walk_kernel<<<BLOCKS(start.numel()), THREADS, 0, stream>>>(
48+
rowptr.data_ptr<int64_t>(), col.data_ptr<int64_t>(),
49+
start.data_ptr<int64_t>(), rand.data_ptr<float>(),
50+
out.data_ptr<int64_t>(), walk_length, start.numel());
51+
52+
return out.t().contiguous();
53+
}

csrc/cuda/rw_cuda.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
#pragma once
2+
3+
#include <torch/extension.h>
4+
5+
torch::Tensor random_walk_cuda(torch::Tensor rowptr, torch::Tensor col,
6+
torch::Tensor start, int64_t walk_length);

csrc/rw.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#include <Python.h>
2+
#include <torch/script.h>
3+
4+
#include "cpu/rw_cpu.h"
5+
6+
#ifdef WITH_CUDA
7+
#include "cuda/rw_cuda.h"
8+
#endif
9+
10+
#ifdef _WIN32
11+
PyMODINIT_FUNC PyInit__rw(void) { return NULL; }
12+
#endif
13+
14+
torch::Tensor random_walk(torch::Tensor rowptr, torch::Tensor col,
15+
torch::Tensor start, int64_t walk_length) {
16+
if (rowptr.device().is_cuda()) {
17+
#ifdef WITH_CUDA
18+
return random_walk_cuda(rowptr, col, start, walk_length);
19+
#else
20+
AT_ERROR("Not compiled with CUDA support");
21+
#endif
22+
} else {
23+
return random_walk_cpu(rowptr, col, start, walk_length);
24+
}
25+
}
26+
27+
static auto registry =
28+
torch::RegisterOperators().op("torch_sparse::random_walk", &random_walk);

csrc/saint.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#include <Python.h>
2+
#include <torch/script.h>
3+
4+
#include "cpu/saint_cpu.h"
5+
6+
#ifdef _WIN32
7+
PyMODINIT_FUNC PyInit__saint(void) { return NULL; }
8+
#endif
9+
10+
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
11+
subgraph(torch::Tensor idx, torch::Tensor rowptr, torch::Tensor row,
12+
torch::Tensor col) {
13+
if (idx.device().is_cuda()) {
14+
#ifdef WITH_CUDA
15+
AT_ERROR("No CUDA version supported");
16+
#else
17+
AT_ERROR("Not compiled with CUDA support");
18+
#endif
19+
} else {
20+
return subgraph_cpu(idx, rowptr, row, col);
21+
}
22+
}
23+
24+
static auto registry =
25+
torch::RegisterOperators().op("torch_sparse::saint_subgraph", &subgraph);

test/test_saint.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import torch
2+
from torch_sparse.tensor import SparseTensor
3+
4+
5+
def test_saint_subgraph():
6+
row = torch.tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4])
7+
col = torch.tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3])
8+
adj = SparseTensor(row=row, col=col)
9+
node_idx = torch.tensor([0, 1, 2])
10+
11+
adj, edge_index = adj.saint_subgraph(node_idx)

torch_sparse/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99
try:
1010
for library in [
11-
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis'
11+
'_version', '_convert', '_diag', '_spmm', '_spspmm', '_metis',
12+
'_rw', '_saint'
1213
]:
1314
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
1415
library, [osp.dirname(__file__)]).origin)
@@ -54,7 +55,9 @@
5455
from .reduce import sum, mean, min, max # noqa
5556
from .matmul import matmul # noqa
5657
from .cat import cat, cat_diag # noqa
58+
from .rw import random_walk # noqa
5759
from .metis import partition # noqa
60+
from .saint import saint_subgraph # noqa
5861

5962
from .convert import to_torch_sparse, from_torch_sparse # noqa
6063
from .convert import to_scipy, from_scipy # noqa
@@ -94,7 +97,9 @@
9497
'matmul',
9598
'cat',
9699
'cat_diag',
100+
'random_walk',
97101
'partition',
102+
'saint_subgraph',
98103
'to_torch_sparse',
99104
'from_torch_sparse',
100105
'to_scipy',

0 commit comments

Comments
 (0)