
Commit c232153

docs

committed · 1 parent 53abd36 · commit c232153

File tree: 5 files changed (+66, -10 lines)


README.md

Lines changed: 53 additions & 3 deletions
@@ -16,9 +16,9 @@
This package consists of a small extension library of optimized sparse matrix operations for use in [PyTorch](http://pytorch.org/), which are missing or lack autograd support in the main package.
This package currently consists of the following methods:

-* **[Autograd Sparse Tensor Creation](#Autograd Sparse Tensor Creation)**
-* **[Autograd Sparse Tensor Value Extraction](#Autograd Sparse Tensor Value Extraction)**
-* **[Sparse Sparse Matrix Multiplication](#Sparse Sparse Matrix Multiplication)**
+* **[Autograd Sparse Tensor Creation](#autograd-sparse-tensor-creation)**
+* **[Autograd Sparse Tensor Value Extraction](#autograd-sparse-tensor-value-extraction)**
+* **[Sparse Sparse Matrix Multiplication](#sparse-sparse-matrix-multiplication)**

All included operations work on varying data types and are implemented both for CPU and GPU.

@@ -47,10 +47,60 @@ If you are running into any installation problems, please follow these [instructions]

## Autograd Sparse Tensor Creation

```
torch_sparse.sparse_coo_tensor(torch.LongTensor, torch.Tensor, torch.Size) -> torch.SparseTensor
```

Constructs a [`torch.SparseTensor`](https://pytorch.org/docs/stable/sparse.html) with autograd capabilities w.r.t. `value`.

```python
import torch
from torch_sparse import sparse_coo_tensor

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)
A = sparse_coo_tensor(i, v, torch.Size([2, 3]))
```

This method may become obsolete in future PyTorch releases (>= 0.4.1), as reported in this [issue](https://github.com/pytorch/pytorch/issues/9674).
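For orientation, a short continuation of the snippet above; this is only a sketch and assumes the same `i`, `v` and `A` as well as PyTorch >= 0.4:

```python
# i selects positions (0, 2), (1, 0) and (1, 2) of a 2 x 3 matrix,
# so A encodes the dense matrix [[0, 0, 3], [4, 0, 5]].
print(A.detach().to_dense())
```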
## Autograd Sparse Tensor Value Extraction

```
torch_sparse.to_value(SparseTensor) -> Tensor
```

Wrapper method to support autograd on the values of sparse tensors.

```python
import torch
from torch_sparse import to_value

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)
A = torch.sparse_coo_tensor(i, v, torch.Size([2, 3]), requires_grad=True)
value = to_value(A)
```

This method may become obsolete in future PyTorch releases (>= 0.4.1), as reported in this [issue](https://github.com/pytorch/pytorch/issues/9674).
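A round trip through creation and extraction, as a sketch: it assumes PyTorch >= 0.4, reuses the made-up `i`/`v` data from above, and uses the package's own `sparse_coo_tensor` so that the gradient can reach `v`.

```python
import torch
from torch_sparse import sparse_coo_tensor, to_value

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)

A = sparse_coo_tensor(i, v, torch.Size([2, 3]))
loss = (to_value(A) ** 2).sum()  # scalar objective built from the stored values
loss.backward()
print(v.grad)                    # expected: 2 * v, i.e. tensor([ 6.,  8., 10.])
```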
## Sparse Sparse Matrix Multiplication

```
torch_sparse.spspmm(SparseTensor, SparseTensor) -> SparseTensor
```

Sparse matrix product of two sparse tensors with autograd support.

```python
import torch
from torch_sparse import spspmm

A = torch.sparse_coo_tensor(..., requires_grad=True)
B = torch.sparse_coo_tensor(..., requires_grad=True)

C = spspmm(A, B)
```
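Since the operands above are elided, here is a self-contained sketch with made-up 2 x 3 and 3 x 2 operands; it assumes PyTorch >= 0.4 and builds the inputs with the package's `sparse_coo_tensor` so that gradients can reach the value tensors:

```python
import torch
from torch_sparse import sparse_coo_tensor, spspmm, to_value

# Made-up COO data; any operands with a matching inner dimension work.
indexA = torch.tensor([[0, 1, 1], [2, 0, 2]])
valueA = torch.tensor([3.0, 4.0, 5.0], requires_grad=True)
A = sparse_coo_tensor(indexA, valueA, torch.Size([2, 3]))

indexB = torch.tensor([[0, 2, 2], [0, 0, 1]])
valueB = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
B = sparse_coo_tensor(indexB, valueB, torch.Size([3, 2]))

C = spspmm(A, B)  # sparse 2 x 2 result

# C.detach().to_dense() should match the dense product of A and B, and a scalar
# derived from C's values should backpropagate to valueA and valueB.
to_value(C).sum().backward()
```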
## Running tests

test/test_matmul.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@

import pytest
import torch
-from torch_sparse import SparseTensor, spspmm, to_value
+from torch_sparse import sparse_coo_tensor, spspmm, to_value

from .utils import dtypes, devices, tensor

@@ -30,13 +30,13 @@ def test_spspmm(test, dtype, device):
    indexA = torch.tensor(test['indexA'], device=device)
    valueA = tensor(test['valueA'], dtype, device, requires_grad=True)
    sizeA = torch.Size(test['sizeA'])
-    A = SparseTensor(indexA, valueA, sizeA)
+    A = sparse_coo_tensor(indexA, valueA, sizeA)
    denseA = A.detach().to_dense().requires_grad_()

    indexB = torch.tensor(test['indexB'], device=device)
    valueB = tensor(test['valueB'], dtype, device, requires_grad=True)
    sizeB = torch.Size(test['sizeB'])
-    B = SparseTensor(indexB, valueB, sizeB)
+    B = sparse_coo_tensor(indexB, valueB, sizeB)
    denseB = B.detach().to_dense().requires_grad_()

    C = spspmm(A, B)

torch_sparse/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
-from .sparse import SparseTensor, to_value
+from .sparse import sparse_coo_tensor, to_value
from .matmul import spspmm

__all__ = [
-    'SparseTensor',
+    'sparse_coo_tensor',
    'to_value',
    'spspmm',
]

torch_sparse/matmul.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@


class SpSpMM(torch.autograd.Function):
+    """Sparse matrix product of two sparse tensors with autograd support."""
+
    @staticmethod
    def forward(ctx, A, B):
        ctx.save_for_backward(A, B)

torch_sparse/sparse.py

Lines changed: 6 additions & 2 deletions
@@ -1,7 +1,9 @@
import torch


-class _SparseTensor(torch.autograd.Function):
+class SparseCooTensor(torch.autograd.Function):
+    """Constructs a sparse matrix with autograd capabilities w.r.t. `value`."""
+
    @staticmethod
    def forward(ctx, index, value, size):
        ctx.size = size

@@ -26,10 +28,12 @@ def backward(ctx, grad_out):
        return None, grad_in, None


-SparseTensor = _SparseTensor.apply
+sparse_coo_tensor = SparseCooTensor.apply


class ToValue(torch.autograd.Function):
+    """Extract values of sparse tensors with autograd support."""
+
    @staticmethod
    def forward(ctx, A):
        ctx.save_for_backward(A)
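For context on the pattern used in `sparse.py` and `matmul.py`: a custom `torch.autograd.Function` pairs a static `forward` with a static `backward` and is exposed through `.apply`, exactly as `sparse_coo_tensor = SparseCooTensor.apply` does above. A toy sketch of that pattern (not the package's actual backward logic):

```python
import torch


class Scale(torch.autograd.Function):
    """Toy Function/apply example mirroring SparseCooTensor and ToValue."""

    @staticmethod
    def forward(ctx, value, factor):
        ctx.factor = factor  # stash non-tensor state on the context
        return value * factor

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient per forward argument; non-differentiable args get None.
        return grad_out * ctx.factor, None


scale = Scale.apply

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
scale(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```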
