Skip to content

Commit b1853b6

Browse files
authored
Merge pull request #61 from mariogeiger/view
view
2 parents 93540a3 + 0e2ddfa commit b1853b6

File tree

3 files changed

+50
-0
lines changed

3 files changed

+50
-0
lines changed

test/test_storage.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,24 @@ def test_coalesce(dtype, device):
122122
assert storage.row().tolist() == [0, 0, 1, 1]
123123
assert storage.col().tolist() == [0, 1, 0, 1]
124124
assert storage.value().tolist() == [1, 2, 3, 4]
125+
126+
127+
@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
128+
def test_sparse_reshape(dtype, device):
129+
row, col = tensor([[0, 1, 2, 3], [0, 1, 2, 3]], torch.long, device)
130+
storage = SparseStorage(row=row, col=col)
131+
132+
storage = storage.sparse_reshape(2, 8)
133+
assert storage.sparse_sizes() == (2, 8)
134+
assert storage.row().tolist() == [0, 0, 1, 1]
135+
assert storage.col().tolist() == [0, 5, 2, 7]
136+
137+
storage = storage.sparse_reshape(-1, 4)
138+
assert storage.sparse_sizes() == (4, 4)
139+
assert storage.row().tolist() == [0, 1, 2, 3]
140+
assert storage.col().tolist() == [0, 1, 2, 3]
141+
142+
storage = storage.sparse_reshape(2, -1)
143+
assert storage.sparse_sizes() == (2, 8)
144+
assert storage.row().tolist() == [0, 0, 1, 1]
145+
assert storage.col().tolist() == [0, 5, 2, 7]

torch_sparse/storage.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,31 @@ def sparse_resize(self, sparse_sizes: Tuple[int, int]):
260260
colcount=colcount, csr2csc=self._csr2csc,
261261
csc2csr=self._csc2csr, is_sorted=True)
262262

263+
def sparse_reshape(self, num_rows: int, num_cols: int):
264+
assert num_rows > 0 or num_rows == -1
265+
assert num_cols > 0 or num_cols == -1
266+
assert num_rows > 0 or num_cols > 0
267+
268+
total = self.sparse_size(0) * self.sparse_size(1)
269+
270+
if num_rows == -1:
271+
num_rows = total // num_cols
272+
273+
if num_cols == -1:
274+
num_cols = total // num_rows
275+
276+
assert num_rows * num_cols == total
277+
278+
idx = self.sparse_size(1) * self.row() + self.col()
279+
280+
row = idx / num_cols
281+
col = idx % num_cols
282+
283+
return SparseStorage(row=row, rowptr=None, col=col, value=self._value,
284+
sparse_sizes=(num_rows, num_cols), rowcount=None,
285+
colptr=None, colcount=None, csr2csc=None,
286+
csc2csr=None, is_sorted=True)
287+
263288
def has_rowcount(self) -> bool:
264289
return self._rowcount is not None
265290

torch_sparse/tensor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ def sparse_size(self, dim: int) -> int:
171171
def sparse_resize(self, sparse_sizes: Tuple[int, int]):
172172
return self.from_storage(self.storage.sparse_resize(sparse_sizes))
173173

174+
def sparse_reshape(self, num_rows: int, num_cols: int):
175+
return self.from_storage(
176+
self.storage.sparse_reshape(num_rows, num_cols))
177+
174178
def is_coalesced(self) -> bool:
175179
return self.storage.is_coalesced()
176180

0 commit comments

Comments
 (0)