
Commit e165da8

many updates
cleaning up the code, adding some type hints, simplifying the example
1 parent 9f571a1 commit e165da8

File tree

7 files changed: +263 −394 lines


README.md

Lines changed: 36 additions & 5 deletions
@@ -1,21 +1,52 @@
-# gradoptorch
-Gradient based optimizers for python with a PyTorch backend. This package intends to allow more fine-grain control of optimizers.
+<div align="center">
 
-Note that this toolbox is meant for prototyping and built-in PyTorch optimizers will almost certainly have better performance.
+# GradOpTorch
+
+#### Classical gradient based optimization in PyTorch.
+
+</div>
+
+- [Installation](#installation)
+- [Usage](#usage)
+
+## What is GradOpTorch?
+
+GradOpTorch is a suite of classical gradient-based optimization tools for
+PyTorch. The toolkit includes conjugate gradients, BFGS, and some
+methods for line-search.
+
+## Why not [torch.optim](https://pytorch.org/docs/stable/optim.html)?
+
+Not every problem is high-dimensional with noisy gradients.
+For such problems, classical optimization techniques
+can be more appropriate.
+
+## Installation
+
+GradOpTorch can be installed from PyPI:
+
+```bash
+pip install gradoptorch
+```
+
+## Usage
+
+### Included optimizers:
 
-## Included optimizers:
 'grad_exact' : exact gradient optimization
 'conj_grad_fr' : conjugate gradient descent using Fletcher-Reeves search direction
 'conj_grad_pr' : conjugate gradient descent using Polak-Ribiere search direction
 'newton_exact' : exact newton optimization
 'bfgs' : approximate newton optimization using bfgs
 
-## Included line-search methods:
+### Included line-search methods:
+
 'back_tracking' : backing tracking based line-search
 'quad_search' : quadratic line-search
 'constant' : no line search, constant step size used
 
 ## Setup
+
 ```bash
 pip install git+https://github.com/coursekevin/gradoptorch.git
 ```
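
The new README lists the optimizer and line-search names but stops short of a call signature. Below is a minimal sketch of how the `optimizer` entry point is invoked; the `(x_opt, hist)` return and the `opt_method`/`ls_method` keywords are taken from `examples/functional_example.py` in this commit, while the quadratic objective is only an illustration.

```python
import torch

from gradoptorch import optimizer


# A simple quadratic bowl with its minimum at (1.0, -2.0), used only for illustration.
def f(x):
    return (x[0] - 1.0).pow(2) + (x[1] + 2.0).pow(2)


x0 = torch.tensor([0.0, 0.0])

# Optimizer and line-search names come from the README lists above.
x_opt, hist = optimizer(f, x0, opt_method="bfgs", ls_method="back_tracking")

print(x_opt)            # optimized parameters
print(hist.f_hist[-1])  # last objective value recorded in the history
```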

examples/functional_example.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+import torch
+from gradoptorch import optimizer
+
+import matplotlib.pyplot as plt  # type: ignore
+
+dim = 2
+
+torch.set_default_dtype(torch.float64)
+torch.manual_seed(42)
+
+
+# Define test objective function
+def f(x):
+    a = 1.0
+    b = 100.0
+    return (a - x[0]).pow(2) + b * (x[1] - x[0].pow(2)).pow(2)
+
+
+def main():
+    # ---------------------------------------------------------------------------------
+    # optimizing using all of the different methods
+    opt_method = ["conj_grad_pr", "conj_grad_fr", "grad_exact", "newton_exact", "bfgs"]
+    histories = []
+    for method in opt_method:
+        x = torch.tensor([-0.5, 3.0])
+        # using quadratic line search
+        x_opt, hist = optimizer(f, x, opt_method=method, ls_method="quad_search")
+        histories.append(hist)
+
+    X, Y = torch.meshgrid(
+        torch.linspace(-2.5, 2.5, 100), torch.linspace(-1.5, 3.5, 100), indexing="ij"
+    )
+
+    # ---------------------------------------------------------------------------------
+    # making some plots
+    _, axs = plt.subplots(1, 2)
+
+    ax1 = axs[0]
+    ax2 = axs[1]
+
+    for hist in histories:
+        ax1.plot(torch.tensor(hist.f_hist).log10())
+        x_hist = torch.stack(hist.x_hist, dim=0)
+        ax2.plot(x_hist[:, 0], x_hist[:, 1], "x-")
+    ax1.legend(opt_method)
+
+    ax2.contourf(X, Y, f(torch.stack([X, Y], dim=0)), 50)
+
+    plt.show()
+
+
+if __name__ == "__main__":
+    main()
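
For reference, the objective `f` in this example is the two-dimensional Rosenbrock function with a = 1 and b = 100, whose unique minimum sits at (1, 1). A quick sanity check on a single run might look like the sketch below; the starting point and `quad_search` line search mirror the example above, and how close the result lands to the minimum depends on the optimizer's convergence.

```python
import torch

from gradoptorch import optimizer


def f(x):  # same Rosenbrock objective as in the example above
    return (1.0 - x[0]).pow(2) + 100.0 * (x[1] - x[0].pow(2)).pow(2)


x_opt, _ = optimizer(f, torch.tensor([-0.5, 3.0]), opt_method="bfgs", ls_method="quad_search")

# The Rosenbrock minimum is at (1, 1); on convergence x_opt should land nearby.
print(x_opt)
print((x_opt - torch.tensor([1.0, 1.0])).norm())
```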

gradoptorch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-from .gradoptorch import gradoptorch
+from .gradoptorch import optimizer, default_opt_settings, default_ls_settings
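
The package's public surface now exposes `default_opt_settings` and `default_ls_settings` alongside `optimizer`. Their structure is not shown in this diff, so the sketch below only imports and prints them to see which optimizer and line-search defaults are available; anything beyond that is an assumption.

```python
from gradoptorch import default_ls_settings, default_opt_settings

# The diff only shows that these names are exported; printing them is the
# safe first step for discovering the available default settings.
print(default_opt_settings)
print(default_ls_settings)
```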
