Skip to content

Commit 872938a

Browse files
committed
overload for cat
1 parent 468aea5 commit 872938a

File tree

3 files changed

+140
-120
lines changed

3 files changed

+140
-120
lines changed

test/test_cat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33
from torch_sparse.tensor import SparseTensor
4-
from torch_sparse.cat import cat, cat_diag
4+
from torch_sparse.cat import cat
55

66
from .utils import devices, tensor
77

@@ -31,7 +31,7 @@ def test_cat(device):
3131
assert not out.storage.has_rowptr()
3232
assert out.storage.num_cached_keys() == 2
3333

34-
out = cat_diag([mat1, mat2])
34+
out = cat([mat1, mat2], dim=(0, 1))
3535
assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0],
3636
[0, 0, 0, 1, 1], [0, 0, 0, 0, 1],
3737
[0, 0, 0, 1, 0]]

torch_sparse/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
from .mul import mul, mul_, mul_nnz, mul_nnz_ # noqa
4545
from .reduce import sum, mean, min, max # noqa
4646
from .matmul import matmul # noqa
47-
from .cat import cat, cat_diag # noqa
47+
from .cat import cat # noqa
4848
from .rw import random_walk # noqa
4949
from .metis import partition # noqa
5050
from .bandwidth import reverse_cuthill_mckee # noqa
@@ -89,7 +89,6 @@
8989
'max',
9090
'matmul',
9191
'cat',
92-
'cat_diag',
9392
'random_walk',
9493
'partition',
9594
'reverse_cuthill_mckee',

torch_sparse/cat.py

Lines changed: 137 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -1,144 +1,168 @@
1-
from typing import Optional, List
1+
from typing import Optional, List, Tuple
22

33
import torch
44
from torch_sparse.storage import SparseStorage
55
from torch_sparse.tensor import SparseTensor
66

77

8-
def cat(tensors: List[SparseTensor], dim: int) -> SparseTensor:
8+
@torch.jit._overload # noqa: F811
9+
def cat(tensors, dim): # noqa: F811
10+
# type: (List[SparseTensor], int) -> SparseTensor
11+
pass
12+
13+
14+
@torch.jit._overload # noqa: F811
15+
def cat(tensors, dim): # noqa: F811
16+
# type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
17+
pass
18+
19+
20+
@torch.jit._overload # noqa: F811
21+
def cat(tensors, dim): # noqa: F811
22+
# type: (List[SparseTensor], List[int]) -> SparseTensor
23+
pass
24+
25+
26+
def cat(tensors, dim): # noqa: F811
927
assert len(tensors) > 0
10-
if dim < 0:
11-
dim = tensors[0].dim() + dim
12-
13-
if dim == 0:
14-
rows: List[torch.Tensor] = []
15-
rowptrs: List[torch.Tensor] = []
16-
cols: List[torch.Tensor] = []
17-
values: List[torch.Tensor] = []
18-
sparse_sizes: List[int] = [0, 0]
19-
rowcounts: List[torch.Tensor] = []
20-
21-
nnz: int = 0
22-
for tensor in tensors:
23-
row = tensor.storage._row
24-
if row is not None:
25-
rows.append(row + sparse_sizes[0])
26-
27-
rowptr = tensor.storage._rowptr
28-
if rowptr is not None:
29-
if len(rowptrs) > 0:
30-
rowptr = rowptr[1:]
31-
rowptrs.append(rowptr + nnz)
32-
33-
cols.append(tensor.storage._col)
34-
35-
value = tensor.storage._value
36-
if value is not None:
28+
29+
if isinstance(dim, int):
30+
dim = tensors[0].dim() + dim if dim < 0 else dim
31+
32+
if dim == 0:
33+
return cat_first(tensors)
34+
35+
elif dim == 1:
36+
return cat_second(tensors)
37+
pass
38+
39+
elif dim > 1 and dim < tensors[0].dim():
40+
values = []
41+
for tensor in tensors:
42+
value = tensor.storage.value()
43+
assert value is not None
3744
values.append(value)
45+
value = torch.cat(values, dim=dim - 1)
46+
return tensors[0].set_value(value, layout='coo')
3847

39-
rowcount = tensor.storage._rowcount
40-
if rowcount is not None:
41-
rowcounts.append(rowcount)
48+
else:
49+
raise IndexError(
50+
(f'Dimension out of range: Expected to be in range of '
51+
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got '
52+
f'{dim}.'))
53+
else:
54+
assert isinstance(dim, (tuple, list))
55+
assert len(dim) == 2
56+
assert sorted(dim) == [0, 1]
57+
return cat_diag(tensors)
4258

43-
sparse_sizes[0] += tensor.sparse_size(0)
44-
sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
45-
nnz += tensor.nnz()
4659

47-
row: Optional[torch.Tensor] = None
48-
if len(rows) == len(tensors):
49-
row = torch.cat(rows, dim=0)
60+
def cat_first(tensors: List[SparseTensor]) -> SparseTensor:
61+
rows: List[torch.Tensor] = []
62+
rowptrs: List[torch.Tensor] = []
63+
cols: List[torch.Tensor] = []
64+
values: List[torch.Tensor] = []
65+
sparse_sizes: List[int] = [0, 0]
66+
rowcounts: List[torch.Tensor] = []
5067

51-
rowptr: Optional[torch.Tensor] = None
52-
if len(rowptrs) == len(tensors):
53-
rowptr = torch.cat(rowptrs, dim=0)
68+
nnz: int = 0
69+
for tensor in tensors:
70+
row = tensor.storage._row
71+
if row is not None:
72+
rows.append(row + sparse_sizes[0])
5473

55-
col = torch.cat(cols, dim=0)
74+
rowptr = tensor.storage._rowptr
75+
if rowptr is not None:
76+
rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
5677

57-
value: Optional[torch.Tensor] = None
58-
if len(values) == len(tensors):
59-
value = torch.cat(values, dim=0)
78+
cols.append(tensor.storage._col)
6079

61-
rowcount: Optional[torch.Tensor] = None
62-
if len(rowcounts) == len(tensors):
63-
rowcount = torch.cat(rowcounts, dim=0)
80+
value = tensor.storage._value
81+
if value is not None:
82+
values.append(value)
6483

65-
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
66-
sparse_sizes=sparse_sizes, rowcount=rowcount,
67-
colptr=None, colcount=None, csr2csc=None,
68-
csc2csr=None, is_sorted=True)
69-
return tensors[0].from_storage(storage)
84+
rowcount = tensor.storage._rowcount
85+
if rowcount is not None:
86+
rowcounts.append(rowcount)
7087

71-
elif dim == 1:
72-
rows: List[torch.Tensor] = []
73-
cols: List[torch.Tensor] = []
74-
values: List[torch.Tensor] = []
75-
sparse_sizes: List[int] = [0, 0]
76-
colptrs: List[torch.Tensor] = []
77-
colcounts: List[torch.Tensor] = []
88+
sparse_sizes[0] += tensor.sparse_size(0)
89+
sparse_sizes[1] = max(sparse_sizes[1], tensor.sparse_size(1))
90+
nnz += tensor.nnz()
7891

79-
nnz: int = 0
80-
for tensor in tensors:
81-
row, col, value = tensor.coo()
92+
row: Optional[torch.Tensor] = None
93+
if len(rows) == len(tensors):
94+
row = torch.cat(rows, dim=0)
8295

83-
rows.append(row)
96+
rowptr: Optional[torch.Tensor] = None
97+
if len(rowptrs) == len(tensors):
98+
rowptr = torch.cat(rowptrs, dim=0)
8499

85-
cols.append(tensor.storage._col + sparse_sizes[1])
100+
col = torch.cat(cols, dim=0)
86101

87-
if value is not None:
88-
values.append(value)
102+
value: Optional[torch.Tensor] = None
103+
if len(values) == len(tensors):
104+
value = torch.cat(values, dim=0)
89105

90-
colptr = tensor.storage._colptr
91-
if colptr is not None:
92-
if len(colptrs) > 0:
93-
colptr = colptr[1:]
94-
colptrs.append(colptr + nnz)
106+
rowcount: Optional[torch.Tensor] = None
107+
if len(rowcounts) == len(tensors):
108+
rowcount = torch.cat(rowcounts, dim=0)
95109

96-
colcount = tensor.storage._colcount
97-
if colcount is not None:
98-
colcounts.append(colcount)
110+
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
111+
sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
112+
rowcount=rowcount, colptr=None, colcount=None,
113+
csr2csc=None, csc2csr=None, is_sorted=True)
114+
return tensors[0].from_storage(storage)
99115

100-
sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
101-
sparse_sizes[1] += tensor.sparse_size(1)
102-
nnz += tensor.nnz()
103116

104-
row = torch.cat(rows, dim=0)
117+
def cat_second(tensors: List[SparseTensor]) -> SparseTensor:
118+
rows: List[torch.Tensor] = []
119+
cols: List[torch.Tensor] = []
120+
values: List[torch.Tensor] = []
121+
sparse_sizes: List[int] = [0, 0]
122+
colptrs: List[torch.Tensor] = []
123+
colcounts: List[torch.Tensor] = []
124+
125+
nnz: int = 0
126+
for tensor in tensors:
127+
row, col, value = tensor.coo()
128+
rows.append(row)
129+
cols.append(tensor.storage._col + sparse_sizes[1])
130+
131+
if value is not None:
132+
values.append(value)
105133

106-
col = torch.cat(cols, dim=0)
134+
colptr = tensor.storage._colptr
135+
if colptr is not None:
136+
colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
107137

108-
value: Optional[torch.Tensor] = None
109-
if len(values) == len(tensors):
110-
value = torch.cat(values, dim=0)
138+
colcount = tensor.storage._colcount
139+
if colcount is not None:
140+
colcounts.append(colcount)
111141

112-
colptr: Optional[torch.Tensor] = None
113-
if len(colptrs) == len(tensors):
114-
colptr = torch.cat(colptrs, dim=0)
142+
sparse_sizes[0] = max(sparse_sizes[0], tensor.sparse_size(0))
143+
sparse_sizes[1] += tensor.sparse_size(1)
144+
nnz += tensor.nnz()
115145

116-
colcount: Optional[torch.Tensor] = None
117-
if len(colcounts) == len(tensors):
118-
colcount = torch.cat(colcounts, dim=0)
146+
row = torch.cat(rows, dim=0)
147+
col = torch.cat(cols, dim=0)
119148

120-
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
121-
sparse_sizes=sparse_sizes, rowcount=None,
122-
colptr=colptr, colcount=colcount, csr2csc=None,
123-
csc2csr=None, is_sorted=False)
124-
return tensors[0].from_storage(storage)
149+
value: Optional[torch.Tensor] = None
150+
if len(values) == len(tensors):
151+
value = torch.cat(values, dim=0)
125152

126-
elif dim > 1 and dim < tensors[0].dim():
127-
values: List[torch.Tensor] = []
128-
for tensor in tensors:
129-
value = tensor.storage.value()
130-
if value is not None:
131-
values.append(value)
153+
colptr: Optional[torch.Tensor] = None
154+
if len(colptrs) == len(tensors):
155+
colptr = torch.cat(colptrs, dim=0)
132156

133-
value: Optional[torch.Tensor] = None
134-
if len(values) == len(tensors):
135-
value = torch.cat(values, dim=dim - 1)
157+
colcount: Optional[torch.Tensor] = None
158+
if len(colcounts) == len(tensors):
159+
colcount = torch.cat(colcounts, dim=0)
136160

137-
return tensors[0].set_value(value, layout='coo')
138-
else:
139-
raise IndexError(
140-
(f'Dimension out of range: Expected to be in range of '
141-
f'[{-tensors[0].dim()}, {tensors[0].dim() - 1}], but got {dim}.'))
161+
storage = SparseStorage(row=row, rowptr=None, col=col, value=value,
162+
sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
163+
rowcount=None, colptr=colptr, colcount=colcount,
164+
csr2csc=None, csc2csr=None, is_sorted=False)
165+
return tensors[0].from_storage(storage)
142166

143167

144168
def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
163187

164188
rowptr = tensor.storage._rowptr
165189
if rowptr is not None:
166-
if len(rowptrs) > 0:
167-
rowptr = rowptr[1:]
168-
rowptrs.append(rowptr + nnz)
190+
rowptrs.append(rowptr[1:] + nnz if len(rowptrs) > 0 else rowptr)
169191

170192
cols.append(tensor.storage._col + sparse_sizes[1])
171193

@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
179201

180202
colptr = tensor.storage._colptr
181203
if colptr is not None:
182-
if len(colptrs) > 0:
183-
colptr = colptr[1:]
184-
colptrs.append(colptr + nnz)
204+
colptrs.append(colptr[1:] + nnz if len(colptrs) > 0 else colptr)
185205

186206
colcount = tensor.storage._colcount
187207
if colcount is not None:
@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
234254
csc2csr = torch.cat(csc2csrs, dim=0)
235255

236256
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
237-
sparse_sizes=sparse_sizes, rowcount=rowcount,
238-
colptr=colptr, colcount=colcount, csr2csc=csr2csc,
257+
sparse_sizes=(sparse_sizes[0], sparse_sizes[1]),
258+
rowcount=rowcount, colptr=colptr,
259+
colcount=colcount, csr2csc=csr2csc,
239260
csc2csr=csc2csr, is_sorted=True)
240261
return tensors[0].from_storage(storage)

0 commit comments

Comments
 (0)