1
- from typing import Optional , List
1
+ from typing import Optional , List , Tuple
2
2
3
3
import torch
4
4
from torch_sparse .storage import SparseStorage
5
5
from torch_sparse .tensor import SparseTensor
6
6
7
7
8
- def cat (tensors : List [SparseTensor ], dim : int ) -> SparseTensor :
8
+ @torch .jit ._overload # noqa: F811
9
+ def cat (tensors , dim ): # noqa: F811
10
+ # type: (List[SparseTensor], int) -> SparseTensor
11
+ pass
12
+
13
+
14
+ @torch .jit ._overload # noqa: F811
15
+ def cat (tensors , dim ): # noqa: F811
16
+ # type: (List[SparseTensor], Tuple[int, int]) -> SparseTensor
17
+ pass
18
+
19
+
20
+ @torch .jit ._overload # noqa: F811
21
+ def cat (tensors , dim ): # noqa: F811
22
+ # type: (List[SparseTensor], List[int]) -> SparseTensor
23
+ pass
24
+
25
+
26
+ def cat (tensors , dim ): # noqa: F811
9
27
assert len (tensors ) > 0
10
- if dim < 0 :
11
- dim = tensors [0 ].dim () + dim
12
-
13
- if dim == 0 :
14
- rows : List [torch .Tensor ] = []
15
- rowptrs : List [torch .Tensor ] = []
16
- cols : List [torch .Tensor ] = []
17
- values : List [torch .Tensor ] = []
18
- sparse_sizes : List [int ] = [0 , 0 ]
19
- rowcounts : List [torch .Tensor ] = []
20
-
21
- nnz : int = 0
22
- for tensor in tensors :
23
- row = tensor .storage ._row
24
- if row is not None :
25
- rows .append (row + sparse_sizes [0 ])
26
-
27
- rowptr = tensor .storage ._rowptr
28
- if rowptr is not None :
29
- if len (rowptrs ) > 0 :
30
- rowptr = rowptr [1 :]
31
- rowptrs .append (rowptr + nnz )
32
-
33
- cols .append (tensor .storage ._col )
34
-
35
- value = tensor .storage ._value
36
- if value is not None :
28
+
29
+ if isinstance (dim , int ):
30
+ dim = tensors [0 ].dim () + dim if dim < 0 else dim
31
+
32
+ if dim == 0 :
33
+ return cat_first (tensors )
34
+
35
+ elif dim == 1 :
36
+ return cat_second (tensors )
37
+ pass
38
+
39
+ elif dim > 1 and dim < tensors [0 ].dim ():
40
+ values = []
41
+ for tensor in tensors :
42
+ value = tensor .storage .value ()
43
+ assert value is not None
37
44
values .append (value )
45
+ value = torch .cat (values , dim = dim - 1 )
46
+ return tensors [0 ].set_value (value , layout = 'coo' )
38
47
39
- rowcount = tensor .storage ._rowcount
40
- if rowcount is not None :
41
- rowcounts .append (rowcount )
48
+ else :
49
+ raise IndexError (
50
+ (f'Dimension out of range: Expected to be in range of '
51
+ f'[{ - tensors [0 ].dim ()} , { tensors [0 ].dim () - 1 } ], but got '
52
+ f'{ dim } .' ))
53
+ else :
54
+ assert isinstance (dim , (tuple , list ))
55
+ assert len (dim ) == 2
56
+ assert sorted (dim ) == [0 , 1 ]
57
+ return cat_diag (tensors )
42
58
43
- sparse_sizes [0 ] += tensor .sparse_size (0 )
44
- sparse_sizes [1 ] = max (sparse_sizes [1 ], tensor .sparse_size (1 ))
45
- nnz += tensor .nnz ()
46
59
47
- row : Optional [torch .Tensor ] = None
48
- if len (rows ) == len (tensors ):
49
- row = torch .cat (rows , dim = 0 )
60
+ def cat_first (tensors : List [SparseTensor ]) -> SparseTensor :
61
+ rows : List [torch .Tensor ] = []
62
+ rowptrs : List [torch .Tensor ] = []
63
+ cols : List [torch .Tensor ] = []
64
+ values : List [torch .Tensor ] = []
65
+ sparse_sizes : List [int ] = [0 , 0 ]
66
+ rowcounts : List [torch .Tensor ] = []
50
67
51
- rowptr : Optional [torch .Tensor ] = None
52
- if len (rowptrs ) == len (tensors ):
53
- rowptr = torch .cat (rowptrs , dim = 0 )
68
+ nnz : int = 0
69
+ for tensor in tensors :
70
+ row = tensor .storage ._row
71
+ if row is not None :
72
+ rows .append (row + sparse_sizes [0 ])
54
73
55
- col = torch .cat (cols , dim = 0 )
74
+ rowptr = tensor .storage ._rowptr
75
+ if rowptr is not None :
76
+ rowptrs .append (rowptr [1 :] + nnz if len (rowptrs ) > 0 else rowptr )
56
77
57
- value : Optional [torch .Tensor ] = None
58
- if len (values ) == len (tensors ):
59
- value = torch .cat (values , dim = 0 )
78
+ cols .append (tensor .storage ._col )
60
79
61
- rowcount : Optional [ torch . Tensor ] = None
62
- if len ( rowcounts ) == len ( tensors ) :
63
- rowcount = torch . cat ( rowcounts , dim = 0 )
80
+ value = tensor . storage . _value
81
+ if value is not None :
82
+ values . append ( value )
64
83
65
- storage = SparseStorage (row = row , rowptr = rowptr , col = col , value = value ,
66
- sparse_sizes = sparse_sizes , rowcount = rowcount ,
67
- colptr = None , colcount = None , csr2csc = None ,
68
- csc2csr = None , is_sorted = True )
69
- return tensors [0 ].from_storage (storage )
84
+ rowcount = tensor .storage ._rowcount
85
+ if rowcount is not None :
86
+ rowcounts .append (rowcount )
70
87
71
- elif dim == 1 :
72
- rows : List [torch .Tensor ] = []
73
- cols : List [torch .Tensor ] = []
74
- values : List [torch .Tensor ] = []
75
- sparse_sizes : List [int ] = [0 , 0 ]
76
- colptrs : List [torch .Tensor ] = []
77
- colcounts : List [torch .Tensor ] = []
88
+ sparse_sizes [0 ] += tensor .sparse_size (0 )
89
+ sparse_sizes [1 ] = max (sparse_sizes [1 ], tensor .sparse_size (1 ))
90
+ nnz += tensor .nnz ()
78
91
79
- nnz : int = 0
80
- for tensor in tensors :
81
- row , col , value = tensor . coo ( )
92
+ row : Optional [ torch . Tensor ] = None
93
+ if len ( rows ) == len ( tensors ) :
94
+ row = torch . cat ( rows , dim = 0 )
82
95
83
- rows .append (row )
96
+ rowptr : Optional [torch .Tensor ] = None
97
+ if len (rowptrs ) == len (tensors ):
98
+ rowptr = torch .cat (rowptrs , dim = 0 )
84
99
85
- cols . append ( tensor . storage . _col + sparse_sizes [ 1 ] )
100
+ col = torch . cat ( cols , dim = 0 )
86
101
87
- if value is not None :
88
- values .append (value )
102
+ value : Optional [torch .Tensor ] = None
103
+ if len (values ) == len (tensors ):
104
+ value = torch .cat (values , dim = 0 )
89
105
90
- colptr = tensor .storage ._colptr
91
- if colptr is not None :
92
- if len (colptrs ) > 0 :
93
- colptr = colptr [1 :]
94
- colptrs .append (colptr + nnz )
106
+ rowcount : Optional [torch .Tensor ] = None
107
+ if len (rowcounts ) == len (tensors ):
108
+ rowcount = torch .cat (rowcounts , dim = 0 )
95
109
96
- colcount = tensor .storage ._colcount
97
- if colcount is not None :
98
- colcounts .append (colcount )
110
+ storage = SparseStorage (row = row , rowptr = rowptr , col = col , value = value ,
111
+ sparse_sizes = (sparse_sizes [0 ], sparse_sizes [1 ]),
112
+ rowcount = rowcount , colptr = None , colcount = None ,
113
+ csr2csc = None , csc2csr = None , is_sorted = True )
114
+ return tensors [0 ].from_storage (storage )
99
115
100
- sparse_sizes [0 ] = max (sparse_sizes [0 ], tensor .sparse_size (0 ))
101
- sparse_sizes [1 ] += tensor .sparse_size (1 )
102
- nnz += tensor .nnz ()
103
116
104
- row = torch .cat (rows , dim = 0 )
117
+ def cat_second (tensors : List [SparseTensor ]) -> SparseTensor :
118
+ rows : List [torch .Tensor ] = []
119
+ cols : List [torch .Tensor ] = []
120
+ values : List [torch .Tensor ] = []
121
+ sparse_sizes : List [int ] = [0 , 0 ]
122
+ colptrs : List [torch .Tensor ] = []
123
+ colcounts : List [torch .Tensor ] = []
124
+
125
+ nnz : int = 0
126
+ for tensor in tensors :
127
+ row , col , value = tensor .coo ()
128
+ rows .append (row )
129
+ cols .append (tensor .storage ._col + sparse_sizes [1 ])
130
+
131
+ if value is not None :
132
+ values .append (value )
105
133
106
- col = torch .cat (cols , dim = 0 )
134
+ colptr = tensor .storage ._colptr
135
+ if colptr is not None :
136
+ colptrs .append (colptr [1 :] + nnz if len (colptrs ) > 0 else colptr )
107
137
108
- value : Optional [ torch . Tensor ] = None
109
- if len ( values ) == len ( tensors ) :
110
- value = torch . cat ( values , dim = 0 )
138
+ colcount = tensor . storage . _colcount
139
+ if colcount is not None :
140
+ colcounts . append ( colcount )
111
141
112
- colptr : Optional [ torch . Tensor ] = None
113
- if len ( colptrs ) == len ( tensors ):
114
- colptr = torch . cat ( colptrs , dim = 0 )
142
+ sparse_sizes [ 0 ] = max ( sparse_sizes [ 0 ], tensor . sparse_size ( 0 ))
143
+ sparse_sizes [ 1 ] += tensor . sparse_size ( 1 )
144
+ nnz += tensor . nnz ( )
115
145
116
- colcount : Optional [torch .Tensor ] = None
117
- if len (colcounts ) == len (tensors ):
118
- colcount = torch .cat (colcounts , dim = 0 )
146
+ row = torch .cat (rows , dim = 0 )
147
+ col = torch .cat (cols , dim = 0 )
119
148
120
- storage = SparseStorage (row = row , rowptr = None , col = col , value = value ,
121
- sparse_sizes = sparse_sizes , rowcount = None ,
122
- colptr = colptr , colcount = colcount , csr2csc = None ,
123
- csc2csr = None , is_sorted = False )
124
- return tensors [0 ].from_storage (storage )
149
+ value : Optional [torch .Tensor ] = None
150
+ if len (values ) == len (tensors ):
151
+ value = torch .cat (values , dim = 0 )
125
152
126
- elif dim > 1 and dim < tensors [0 ].dim ():
127
- values : List [torch .Tensor ] = []
128
- for tensor in tensors :
129
- value = tensor .storage .value ()
130
- if value is not None :
131
- values .append (value )
153
+ colptr : Optional [torch .Tensor ] = None
154
+ if len (colptrs ) == len (tensors ):
155
+ colptr = torch .cat (colptrs , dim = 0 )
132
156
133
- value : Optional [torch .Tensor ] = None
134
- if len (values ) == len (tensors ):
135
- value = torch .cat (values , dim = dim - 1 )
157
+ colcount : Optional [torch .Tensor ] = None
158
+ if len (colcounts ) == len (tensors ):
159
+ colcount = torch .cat (colcounts , dim = 0 )
136
160
137
- return tensors [ 0 ]. set_value ( value , layout = 'coo' )
138
- else :
139
- raise IndexError (
140
- ( f'Dimension out of range: Expected to be in range of '
141
- f'[ { - tensors [0 ].dim () } , { tensors [ 0 ]. dim () - 1 } ], but got { dim } .' ) )
161
+ storage = SparseStorage ( row = row , rowptr = None , col = col , value = value ,
162
+ sparse_sizes = ( sparse_sizes [ 0 ], sparse_sizes [ 1 ]),
163
+ rowcount = None , colptr = colptr , colcount = colcount ,
164
+ csr2csc = None , csc2csr = None , is_sorted = False )
165
+ return tensors [0 ].from_storage ( storage )
142
166
143
167
144
168
def cat_diag (tensors : List [SparseTensor ]) -> SparseTensor :
@@ -163,9 +187,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
163
187
164
188
rowptr = tensor .storage ._rowptr
165
189
if rowptr is not None :
166
- if len (rowptrs ) > 0 :
167
- rowptr = rowptr [1 :]
168
- rowptrs .append (rowptr + nnz )
190
+ rowptrs .append (rowptr [1 :] + nnz if len (rowptrs ) > 0 else rowptr )
169
191
170
192
cols .append (tensor .storage ._col + sparse_sizes [1 ])
171
193
@@ -179,9 +201,7 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
179
201
180
202
colptr = tensor .storage ._colptr
181
203
if colptr is not None :
182
- if len (colptrs ) > 0 :
183
- colptr = colptr [1 :]
184
- colptrs .append (colptr + nnz )
204
+ colptrs .append (colptr [1 :] + nnz if len (colptrs ) > 0 else colptr )
185
205
186
206
colcount = tensor .storage ._colcount
187
207
if colcount is not None :
@@ -234,7 +254,8 @@ def cat_diag(tensors: List[SparseTensor]) -> SparseTensor:
234
254
csc2csr = torch .cat (csc2csrs , dim = 0 )
235
255
236
256
storage = SparseStorage (row = row , rowptr = rowptr , col = col , value = value ,
237
- sparse_sizes = sparse_sizes , rowcount = rowcount ,
238
- colptr = colptr , colcount = colcount , csr2csc = csr2csc ,
257
+ sparse_sizes = (sparse_sizes [0 ], sparse_sizes [1 ]),
258
+ rowcount = rowcount , colptr = colptr ,
259
+ colcount = colcount , csr2csc = csr2csc ,
239
260
csc2csr = csc2csr , is_sorted = True )
240
261
return tensors [0 ].from_storage (storage )
0 commit comments