Skip to content

Commit c1cd975

Browse files
committed
multi gpu update
1 parent e6a8f8c commit c1cd975

File tree

4 files changed

+4
-2
lines changed

4 files changed

+4
-2
lines changed

cuda/spspmm_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ static void init_cusparse() {
3030
std::tuple<at::Tensor, at::Tensor>
3131
spspmm_cuda(at::Tensor indexA, at::Tensor valueA, at::Tensor indexB,
3232
at::Tensor valueB, int m, int k, int n) {
33+
cudaSetDevice(indexA.get_device());
3334
init_cusparse();
3435

3536
indexA = indexA.contiguous();

cuda/unique_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ __global__ void unique_cuda_kernel(scalar_t *__restrict__ src, uint8_t *mask,
1616
}
1717

1818
std::tuple<at::Tensor, at::Tensor> unique_cuda(at::Tensor src) {
19+
cudaSetDevice(src.get_device());
1920
at::Tensor perm;
2021
std::tie(src, perm) = src.sort();
2122

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from setuptools import setup, find_packages
33
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
44

5-
__version__ = '0.2.3'
5+
__version__ = '0.2.4'
66
url = 'https://github.com/rusty1s/pytorch_sparse'
77

88
install_requires = ['scipy']

torch_sparse/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from .spmm import spmm
55
from .spspmm import spspmm
66

7-
__version__ = '0.2.3'
7+
__version__ = '0.2.4'
88

99
__all__ = [
1010
'__version__',

0 commit comments

Comments
 (0)