
Commit 0712183

adding the object oriented interface
1 parent e165da8 commit 0712183

7 files changed: +194 -16 lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 MIT License
 
-Copyright (c) 2019 coursekevin
+Copyright (c) 2023 Kevin Course
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

README.md

Lines changed: 38 additions & 10 deletions
@@ -12,14 +12,14 @@
 ## What is GradOpTorch?
 
 GradOpTorch is a suite of classical gradient-based optimization tools for
-PyTorch. The toolkit includes conjugate gradients, BFGS, and some
+PyTorch. The toolkit includes conjugate gradients, BFGS, and some
 methods for line-search.
 
 ## Why not [torch.optim](https://pytorch.org/docs/stable/optim.html)?
 
-Not every problem is high-dimensional with noisy gradients.
-For such problems, classical optimization techniques
-can be more appropriate.
+Not every problem is high-dimensional, nonlinear, with noisy gradients.
+For such problems, classical optimization techniques
+can be more efficient.
 
 ## Installation
 
@@ -31,6 +31,40 @@ pip install gradoptorch
 
 ## Usage
 
+There are two primary interfaces for making use of the library.
+
+1. The standard PyTorch object-oriented interface:
+
+```python
+from gradoptorch import optimize_module
+from torch import nn
+
+class MyModel(nn.Module):
+    ...
+
+model = MyModel()
+
+def loss_fn(model):
+    ...
+
+hist = optimize_module(model, loss_fn, opt_method="bfgs", ls_method="back_tracking")
+```
+
+2. The functional interface:
+
+```python
+from gradoptorch import optimizer
+
+def f(x):
+    ...
+
+x_guess = ...
+
+x_opt, hist = optimizer(f, x_guess, opt_method="conj_grad_pr", ls_method="quad_search")
+```
+
+Newton's method is only available in the functional interface.
+
 ### Included optimizers:
 
     'grad_exact' : exact gradient optimization
@@ -44,9 +78,3 @@ pip install gradoptorch
     'back_tracking' : back tracking based line-search
     'quad_search' : quadratic line-search
     'constant' : no line search, constant step size used
-
-## Setup
-
-```bash
-pip install git+https://github.com/coursekevin/gradoptorch.git
-```
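Since Newton's method is only exposed through the functional interface, here is a minimal sketch of calling it that way. The objective and starting point mirror the Rosenbrock function used in this commit's examples, and the `"newton_exact"` method name is taken from the guard added in `gradoptorch/module.py`, so treat the exact call as illustrative rather than canonical.

```python
import torch
from gradoptorch import optimizer


# Rosenbrock test objective, as in the examples added by this commit
def f(x):
    a, b = 1.0, 100.0
    return (a - x[0]).pow(2) + b * (x[1] - x[0].pow(2)).pow(2)


# starting point matching examples/module_example.py
x_guess = torch.tensor([-0.5, 3.0])

# "newton_exact" is accepted only by the functional `optimizer`,
# not by `optimize_module` (see the NotImplementedError guard below)
x_opt, hist = optimizer(f, x_guess, opt_method="newton_exact", ls_method="back_tracking")
```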

examples/functional_example.py

Lines changed: 3 additions & 4 deletions
@@ -27,10 +27,6 @@ def main():
         x_opt, hist = optimizer(f, x, opt_method=method, ls_method="quad_search")
         histories.append(hist)
 
-    X, Y = torch.meshgrid(
-        torch.linspace(-2.5, 2.5, 100), torch.linspace(-1.5, 3.5, 100), indexing="ij"
-    )
-
     # ---------------------------------------------------------------------------------
     # making some plots
     _, axs = plt.subplots(1, 2)
@@ -44,6 +40,9 @@ def main():
         ax2.plot(x_hist[:, 0], x_hist[:, 1], "x-")
     ax1.legend(opt_method)
 
+    X, Y = torch.meshgrid(
+        torch.linspace(-2.5, 2.5, 100), torch.linspace(-1.5, 3.5, 100), indexing="ij"
+    )
     ax2.contourf(X, Y, f(torch.stack([X, Y], dim=0)), 50)
 
     plt.show()

examples/module_example.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+import torch
+from torch import nn
+from gradoptorch import optimize_module
+
+import matplotlib.pyplot as plt  # type: ignore
+
+dim = 2
+
+torch.set_default_dtype(torch.float64)
+torch.manual_seed(42)
+
+
+class SomeModule(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.x = nn.Parameter(torch.tensor(-0.5))
+        self.y = nn.Parameter(torch.tensor(3.0))
+
+
+def loss_fn(model):
+    a = 1.0
+    b = 100.0
+    return (a - model.x).pow(2) + b * (model.y - model.x.pow(2)).pow(2)
+
+
+def main():
+    # ---------------------------------------------------------------------------------
+    # optimizing using all of the different methods
+    opt_method = ["conj_grad_pr", "conj_grad_fr", "grad_exact", "bfgs"]
+    histories = []
+    for method in opt_method:
+        # using quadratic line search
+        model = SomeModule()
+        hist = optimize_module(
+            model, loss_fn, opt_method=method, ls_method="quad_search"
+        )
+        for n, p in model.named_parameters():
+            print(n, p)
+        histories.append(hist)
+
+    # ---------------------------------------------------------------------------------
+    # making some plots
+    _, axs = plt.subplots(1, 2)
+
+    ax1 = axs[0]
+    ax2 = axs[1]
+
+    for hist in histories:
+        ax1.plot(torch.tensor(hist.f_hist).log10().detach())
+        x_hist = torch.stack(hist.x_hist, dim=0).detach()
+        ax2.plot(x_hist[:, 0], x_hist[:, 1], "x-")
+    ax1.legend(opt_method)
+
+    X, Y = torch.meshgrid(
+        torch.linspace(-2.5, 2.5, 100), torch.linspace(-1.5, 3.5, 100), indexing="ij"
+    )
+
+    # Define test objective function
+    def f(x):
+        a = 1.0
+        b = 100.0
+        return (a - x[0]).pow(2) + b * (x[1] - x[0].pow(2)).pow(2)
+
+    ax2.contourf(X, Y, f(torch.stack([X, Y], dim=0)), 50)
+
+    plt.show()
+
+
+if __name__ == "__main__":
+    main()

gradoptorch/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 from .gradoptorch import optimizer, default_opt_settings, default_ls_settings
+from .module import optimize_module

gradoptorch/gradoptorch.py

Lines changed: 1 addition & 1 deletion
@@ -48,7 +48,7 @@ def grad_fn(x: Float[Tensor, " d"]) -> Float[Tensor, "d 1"]:
 def optimizer(
     f: Callable[[Float[Tensor, " d"]], Float[Tensor, ""]],
     x_guess: Float[Tensor, " d"],
-    g: Optional[Callable[[Float[Tensor, " d"]], Float[Tensor, ""]]] = None,
+    g: Optional[Callable[[Float[Tensor, " d"]], Float[Tensor, " d"]]] = None,
     opt_method: str = "conj_grad_pr",
     opt_params: dict[str, Any] = default_opt_settings,
     ls_method: str = "back_tracking",
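The annotation fix above reflects that the optional gradient callable `g` returns a vector with the same shape as `x_guess`, not a scalar. A minimal sketch of supplying an analytic gradient through the functional interface, using a made-up quadratic objective for illustration:

```python
import torch
from gradoptorch import optimizer


def f(x):
    # f(x) = 0.5 * ||x||^2
    return 0.5 * x.pow(2).sum()


def g(x):
    # analytic gradient of f: a " d"-shaped tensor like x, per the updated annotation
    return x


x_guess = torch.ones(3)
x_opt, hist = optimizer(f, x_guess, g=g, opt_method="grad_exact", ls_method="back_tracking")
```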

gradoptorch/module.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+from typing import Callable, Any
+
+import torch
+from torch import nn, Tensor
+from jaxtyping import Float
+
+from .gradoptorch import optimizer, OptimLog, default_opt_settings, default_ls_settings
+
+
+def update_model_params(model: nn.Module, new_params: Tensor) -> None:
+    pointer = 0
+    for param in model.parameters():
+        num_param = param.numel()
+        param.data = new_params[pointer : pointer + num_param].view_as(param).data
+        pointer += num_param
+        param.requires_grad = True
+
+
+def optimize_module(
+    model: nn.Module,
+    f: Callable[[nn.Module], Float[Tensor, ""]],
+    opt_method: str = "conj_grad_pr",
+    opt_params: dict[str, Any] = default_opt_settings,
+    ls_method: str = "back_tracking",
+    ls_params: dict[str, Any] = default_ls_settings,
+) -> OptimLog:
+    """
+    Optimizes the parameters of a given nn.Module using a classical optimizer.
+
+    See the `optimizer` function for more details on the optimization methods
+    available.
+
+    INPUTS:
+        model < nn.Module > : The torch.nn model to be optimized.
+        f < Callable[[nn.Module], Float[Tensor, ""]] > : The loss function to be minimized.
+        opt_method < str > : The optimization method to be used.
+        opt_params < dict[str, Any] > : The parameters to be used for the optimization method.
+        ls_method < str > : The line search method to be used.
+        ls_params < dict[str, Any] > : The parameters to be used for the line search method.
+
+    OptimLog: the log of the optimization process.
+    """
+    if opt_method == "newton_exact":
+        raise NotImplementedError(
+            "Exact Newton's method is not implemented for optimize_module."
+        )
+
+    # Flatten the model parameters and use them as the initial guess
+    params = torch.cat([param.view(-1) for param in model.parameters()])
+
+    def f_wrapper(params: Float[Tensor, " d"]) -> Float[Tensor, ""]:
+        update_model_params(model, params)
+        return f(model)
+
+    def grad_wrapper(params: Float[Tensor, " d"]) -> Float[Tensor, " d"]:
+        update_model_params(model, params)
+        model.zero_grad()
+        with torch.enable_grad():
+            loss = f(model)
+            loss.backward()
+        return torch.cat(
+            [
+                param.grad.view(-1)
+                if param.grad is not None
+                else torch.zeros_like(param).view(-1)
+                for param in model.parameters()
+            ]
+        )
+
+    final_params, hist = optimizer(
+        f=f_wrapper,
+        x_guess=params,
+        g=grad_wrapper,
+        opt_method=opt_method,
+        opt_params=opt_params,
+        ls_method=ls_method,
+        ls_params=ls_params,
+    )
+    update_model_params(model, final_params)
+    return hist
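For reference, a minimal sketch of the new object-oriented interface on a small least-squares fit; the data and model below are made up for illustration and are not part of the commit:

```python
import torch
from torch import nn
from gradoptorch import optimize_module

torch.manual_seed(0)

# hypothetical data for a 1-D linear fit
X = torch.linspace(-1.0, 1.0, 20).unsqueeze(-1)
y = 3.0 * X - 0.5

model = nn.Linear(1, 1)


def loss_fn(model):
    # optimize_module passes the whole module to the loss function
    return (model(X) - y).pow(2).mean()


# parameters are updated in place via update_model_params; the returned
# value is the OptimLog history of the run
hist = optimize_module(model, loss_fn, opt_method="bfgs", ls_method="quad_search")
print([p.detach() for p in model.parameters()])
```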
