from typing import Callable, Any

import torch
from torch import nn, Tensor
from jaxtyping import Float

from .gradoptorch import optimizer, OptimLog, default_opt_settings, default_ls_settings


def update_model_params(model: nn.Module, new_params: Tensor) -> None:
    """
    Writes a flat parameter vector back into `model` in place.

    `new_params` is split and reshaped following the ordering of
    `model.parameters()`, and `requires_grad` is re-enabled on every parameter.
    """
    pointer = 0
    for param in model.parameters():
        num_param = param.numel()
        param.data = new_params[pointer : pointer + num_param].view_as(param).data
        pointer += num_param
        param.requires_grad = True
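
# A quick sanity check of the flattening convention above (an illustrative
# sketch, not part of this module; `model` stands for any nn.Module):
# `torch.nn.utils.parameters_to_vector` walks `model.parameters()` in the same
# order, so a round trip through `update_model_params` leaves the model unchanged.
#
#     from torch.nn.utils import parameters_to_vector
#     vec = parameters_to_vector(model.parameters())
#     update_model_params(model, vec.clone())
#     assert torch.equal(parameters_to_vector(model.parameters()), vec)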


def optimize_module(
    model: nn.Module,
    f: Callable[[nn.Module], Float[Tensor, ""]],
    opt_method: str = "conj_grad_pr",
    opt_params: dict[str, Any] = default_opt_settings,
    ls_method: str = "back_tracking",
    ls_params: dict[str, Any] = default_ls_settings,
) -> OptimLog:
    """
    Optimizes the parameters of a given nn.Module using a classical optimizer.

    See the `optimizer` function for more details on the optimization methods
    available.

    INPUTS:
        model < nn.Module > : The torch.nn model to be optimized.
        f < Callable[[nn.Module], Float[Tensor, ""]] > : The loss function to be minimized.
        opt_method < str > : The optimization method to be used.
        opt_params < dict[str, Any] > : The parameters to be used for the optimization method.
        ls_method < str > : The line search method to be used.
        ls_params < dict[str, Any] > : The parameters to be used for the line search method.

    OUTPUTS:
        OptimLog : The log of the optimization process.
    """
    if opt_method == "newton_exact":
        raise NotImplementedError(
            "Exact Newton's method is not implemented for optimize_module."
        )

    # Flatten the current model parameters to use as the initial guess,
    # detached so the flat copy is not tracked by autograd.
    params = torch.cat([param.detach().view(-1) for param in model.parameters()])

    def f_wrapper(params: Float[Tensor, " d"]) -> Float[Tensor, ""]:
        # Evaluate the loss at a candidate flat parameter vector.
        update_model_params(model, params)
        return f(model)

    def grad_wrapper(params: Float[Tensor, " d"]) -> Float[Tensor, " d"]:
        # Evaluate the loss gradient at a candidate flat parameter vector via
        # backpropagation, returned in the same flattened ordering; parameters
        # that receive no gradient contribute zeros.
        update_model_params(model, params)
        model.zero_grad()
        with torch.enable_grad():
            loss = f(model)
            loss.backward()
        return torch.cat(
            [
                param.grad.view(-1)
                if param.grad is not None
                else torch.zeros_like(param).view(-1)
                for param in model.parameters()
            ]
        )

    final_params, hist = optimizer(
        f=f_wrapper,
        x_guess=params,
        g=grad_wrapper,
        opt_method=opt_method,
        opt_params=opt_params,
        ls_method=ls_method,
        ls_params=ls_params,
    )
    update_model_params(model, final_params)
    return hist
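

# Example usage (a minimal sketch; the model, data, and loss closure below are
# illustrative and not part of this module): fit a small linear model by
# minimizing a mean-squared-error loss with the default `conj_grad_pr` method.
#
#     model = nn.Linear(3, 1)
#     x = torch.randn(64, 3)
#     y = x @ torch.tensor([[1.0], [-2.0], [0.5]]) + 0.1
#
#     def mse_loss(m: nn.Module) -> Float[Tensor, ""]:
#         return torch.mean((m(x) - y) ** 2)
#
#     hist = optimize_module(model, mse_loss)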