Skip to content

Commit 15ff09d

Browse files
committed
removed load library calls
1 parent 60a2946 commit 15ff09d

File tree

6 files changed

+191
-106
lines changed

6 files changed

+191
-106
lines changed

torch_sparse/__init__.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# flake8: noqa
2-
31
import importlib
42
import os.path as osp
53

@@ -9,8 +7,9 @@
97
expected_torch_version = (1, 4)
108

119
try:
12-
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
13-
'_version', [osp.dirname(__file__)]).origin)
10+
for library in ['_version', '_convert', '_diag', '_spmm', '_spspmm']:
11+
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
12+
library, [osp.dirname(__file__)]).origin)
1413
except OSError as e:
1514
if 'undefined symbol' in str(e):
1615
major, minor = [int(x) for x in torch.__version__.split('.')[:2]]
@@ -40,26 +39,27 @@
4039
f'{major}.{minor}. Please reinstall the torch_sparse that '
4140
f'matches your PyTorch install.')
4241

43-
from .storage import SparseStorage
44-
from .tensor import SparseTensor
45-
from .transpose import t
46-
from .narrow import narrow, __narrow_diag__
47-
from .select import select
48-
from .index_select import index_select, index_select_nnz
49-
from .masked_select import masked_select, masked_select_nnz
50-
from .diag import remove_diag, set_diag, fill_diag
51-
from .add import add, add_, add_nnz, add_nnz_
52-
from .mul import mul, mul_, mul_nnz, mul_nnz_
53-
from .reduce import sum, mean, min, max
54-
from .matmul import matmul
55-
from .cat import cat, cat_diag
42+
from .storage import SparseStorage # noqa: E4402
43+
from .tensor import SparseTensor # noqa: E4402
44+
from .transpose import t # noqa: E4402
45+
from .narrow import narrow, __narrow_diag__ # noqa: E4402
46+
from .select import select # noqa: E4402
47+
from .index_select import index_select, index_select_nnz # noqa: E4402
48+
from .masked_select import masked_select, masked_select_nnz # noqa: E4402
49+
from .diag import remove_diag, set_diag, fill_diag # noqa: E4402
50+
from .add import add, add_, add_nnz, add_nnz_ # noqa: E4402
51+
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa: E4402
52+
from .reduce import sum, mean, min, max # noqa: E4402
53+
from .matmul import matmul # noqa: E4402
54+
from .cat import cat, cat_diag # noqa: E4402
5655

57-
from .convert import to_torch_sparse, from_torch_sparse, to_scipy, from_scipy
58-
from .coalesce import coalesce
59-
from .transpose import transpose
60-
from .eye import eye
61-
from .spmm import spmm
62-
from .spspmm import spspmm
56+
from .convert import to_torch_sparse, from_torch_sparse # noqa: E4402
57+
from .convert import to_scipy, from_scipy # noqa: E4402
58+
from .coalesce import coalesce # noqa: E4402
59+
from .transpose import transpose # noqa: E4402
60+
from .eye import eye # noqa: E4402
61+
from .spmm import spmm # noqa: E4402
62+
from .spspmm import spspmm # noqa: E4402
6363

6464
__all__ = [
6565
'SparseStorage',

torch_sparse/cat.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import List
22

33
import torch
44
from torch_sparse.storage import SparseStorage
@@ -63,10 +63,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
6363
if len(rowcounts) == len(tensors):
6464
rowcount = torch.cat(rowcounts, dim=0)
6565

66-
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
67-
sparse_sizes=sparse_sizes, rowcount=rowcount,
68-
colptr=None, colcount=None, csr2csc=None,
69-
csc2csr=None, is_sorted=True)
66+
storage = SparseStorage(
67+
row=row,
68+
rowptr=rowptr,
69+
col=col,
70+
value=value,
71+
sparse_sizes=sparse_sizes,
72+
rowcount=rowcount,
73+
colptr=None,
74+
colcount=None,
75+
csr2csc=None,
76+
csc2csr=None,
77+
is_sorted=True)
7078
return tensors[0].from_storage(storage)
7179

7280
elif dim == 1:
@@ -118,10 +126,18 @@ def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
118126
if len(colcounts) == len(tensors):
119127
colcount = torch.cat(colcounts, dim=0)
120128

121-
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
122-
sparse_sizes=sparse_sizes, rowcount=None,
123-
colptr=colptr, colcount=colcount, csr2csc=None,
124-
csc2csr=None, is_sorted=False)
129+
storage = SparseStorage(
130+
row=row,
131+
rowptr=None,
132+
col=col,
133+
value=value,
134+
sparse_sizes=sparse_sizes,
135+
rowcount=None,
136+
colptr=colptr,
137+
colcount=colcount,
138+
csr2csc=None,
139+
csc2csr=None,
140+
is_sorted=False)
125141
return tensors[0].from_storage(storage)
126142

127143
elif dim > 1 and dim < tensors[0].dim():
@@ -235,8 +251,16 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
235251
if len(csc2csrs) == len(tensors):
236252
csc2csr = torch.cat(csc2csrs, dim=0)
237253

238-
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
239-
sparse_sizes=sparse_sizes, rowcount=rowcount,
240-
colptr=colptr, colcount=colcount, csr2csc=csr2csc,
241-
csc2csr=csc2csr, is_sorted=True)
254+
storage = SparseStorage(
255+
row=row,
256+
rowptr=rowptr,
257+
col=col,
258+
value=value,
259+
sparse_sizes=sparse_sizes,
260+
rowcount=rowcount,
261+
colptr=colptr,
262+
colcount=colcount,
263+
csr2csc=csr2csc,
264+
csc2csr=csc2csr,
265+
is_sorted=True)
242266
return tensors[0].from_storage(storage)

torch_sparse/diag.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
1-
import importlib
2-
import os.path as osp
31
from typing import Optional
42

53
import torch
64
from torch_sparse.storage import SparseStorage
75
from torch_sparse.tensor import SparseTensor
86

9-
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
10-
'_diag', [osp.dirname(__file__)]).origin)
11-
127

138
@torch.jit.script
149
def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
@@ -30,15 +25,24 @@ def remove_diag(src: SparseTensor, k: int = 0) -> SparseTensor:
3025
colcount = colcount.clone()
3126
colcount[col[mask]] -= 1
3227

33-
storage = SparseStorage(row=new_row, rowptr=None, col=new_col, value=value,
34-
sparse_sizes=src.sparse_sizes(), rowcount=rowcount,
35-
colptr=None, colcount=colcount, csr2csc=None,
36-
csc2csr=None, is_sorted=True)
28+
storage = SparseStorage(
29+
row=new_row,
30+
rowptr=None,
31+
col=new_col,
32+
value=value,
33+
sparse_sizes=src.sparse_sizes(),
34+
rowcount=rowcount,
35+
colptr=None,
36+
colcount=colcount,
37+
csr2csc=None,
38+
csc2csr=None,
39+
is_sorted=True)
3740
return src.from_storage(storage)
3841

3942

4043
@torch.jit.script
41-
def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
44+
def set_diag(src: SparseTensor,
45+
values: Optional[torch.Tensor] = None,
4246
k: int = 0) -> SparseTensor:
4347
src = remove_diag(src, k=k)
4448
row, col, value = src.coo()
@@ -65,7 +69,8 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
6569
if values is not None:
6670
new_value[inv_mask] = values
6771
else:
68-
new_value[inv_mask] = torch.ones((num_diag, ), dtype=value.dtype,
72+
new_value[inv_mask] = torch.ones((num_diag, ),
73+
dtype=value.dtype,
6974
device=value.device)
7075

7176
rowcount = src.storage._rowcount
@@ -78,10 +83,18 @@ def set_diag(src: SparseTensor, values: Optional[torch.Tensor] = None,
7883
colcount = colcount.clone()
7984
colcount[start + k:start + num_diag + k] += 1
8085

81-
storage = SparseStorage(row=new_row, rowptr=None, col=new_col,
82-
value=new_value, sparse_sizes=src.sparse_sizes(),
83-
rowcount=rowcount, colptr=None, colcount=colcount,
84-
csr2csc=None, csc2csr=None, is_sorted=True)
86+
storage = SparseStorage(
87+
row=new_row,
88+
rowptr=None,
89+
col=new_col,
90+
value=new_value,
91+
sparse_sizes=src.sparse_sizes(),
92+
rowcount=rowcount,
93+
colptr=None,
94+
colcount=colcount,
95+
csr2csc=None,
96+
csc2csr=None,
97+
is_sorted=True)
8598
return src.from_storage(storage)
8699

87100

torch_sparse/intersection.py

Whitespace-only changes.

torch_sparse/matmul.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,8 @@
1-
import importlib
2-
import os.path as osp
31
from typing import Union, Tuple
42

53
import torch
64
from torch_sparse.tensor import SparseTensor
75

8-
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
9-
'_spmm', [osp.dirname(__file__)]).origin)
10-
torch.ops.load_library(importlib.machinery.PathFinder().find_spec(
11-
'_spspmm', [osp.dirname(__file__)]).origin)
12-
136

147
@torch.jit.script
158
def spmm_sum(src: SparseTensor, other: torch.Tensor) -> torch.Tensor:
@@ -95,8 +88,13 @@ def spspmm_sum(src: SparseTensor, other: SparseTensor) -> SparseTensor:
9588
M, K = src.sparse_size(0), other.sparse_size(1)
9689
rowptrC, colC, valueC = torch.ops.torch_sparse.spspmm_sum(
9790
rowptrA, colA, valueA, rowptrB, colB, valueB, K)
98-
return SparseTensor(row=None, rowptr=rowptrC, col=colC, value=valueC,
99-
sparse_sizes=torch.Size([M, K]), is_sorted=True)
91+
return SparseTensor(
92+
row=None,
93+
rowptr=rowptrC,
94+
col=colC,
95+
value=valueC,
96+
sparse_sizes=torch.Size([M, K]),
97+
is_sorted=True)
10098

10199

102100
@torch.jit.script
@@ -115,7 +113,8 @@ def spspmm(src: SparseTensor, other: SparseTensor,
115113
raise ValueError
116114

117115

118-
def matmul(src: SparseTensor, other: Union[torch.Tensor, SparseTensor],
116+
def matmul(src: SparseTensor,
117+
other: Union[torch.Tensor, SparseTensor],
119118
reduce: str = "sum"):
120119
if torch.is_tensor(other):
121120
return spmm(src, other, reduce)

0 commit comments

Comments
 (0)