Add iterator interface #745

Open · wants to merge 18 commits into master

20 changes: 20 additions & 0 deletions docs/src/user/minimization.md
@@ -235,3 +235,23 @@ line search errors if `initial_x` is a stationary point. Notice that this is only
a first order check. If `initial_x` is any type of stationary point, `g_converged`
will be true. This includes local minima, saddle points, and local maxima. If `iterations` is `0`
and `g_converged` is `true`, the user needs to keep this point in mind.

## Iterator interface
For multivariate optimization, an iterator interface is provided through the `Optim.optimizing`
function. With this interface, `optimize(args...; kwargs...)` is equivalent to

```jl
let istate
    iter = Optim.optimizing(args...; kwargs...)
    for istate′ in iter
        istate = istate′
    end
    Optim.OptimizationResults(iter, istate)
end
```

The iterator returned by `Optim.optimizing` yields an iteration state at each step.

Functions that can be called on the result object (e.g. `minimizer`, `iterations`; see
[Complete list of functions](@ref)) can also be called on the iteration state `istate`.
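
For example, the iterator makes it possible to monitor progress or stop early. The following
is a minimal sketch of that pattern; the Rosenbrock objective, the `BFGS` choice, and the
50-iteration cutoff are illustrative, not part of the interface:

```jl
using Optim

rosenbrock(x) = (1.0 - x[1])^2 + 100.0 * (x[2] - x[1]^2)^2

result = let istate
    iter = Optim.optimizing(rosenbrock, zeros(2), BFGS())
    for istate′ in iter
        istate = istate′
        # Accessors that work on results also work on iteration states:
        Optim.iterations(istate) >= 50 && break  # stop early
    end
    Optim.OptimizationResults(iter, istate)
end

Optim.minimizer(result)
```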
29 changes: 16 additions & 13 deletions src/api.jl
@@ -1,4 +1,7 @@
Base.summary(r::OptimizationResults) = summary(r.method) # might want to do more here than just return summary of the method used
_method(r::OptimizationResults) = r.method

Base.summary(r::Union{OptimizationResults, OptimIterator}) =
summary(_method(r)) # might want to do more here than just return summary of the method used
minimizer(r::OptimizationResults) = r.minimizer
minimum(r::OptimizationResults) = r.minimum
iterations(r::OptimizationResults) = r.iterations
@@ -35,24 +38,24 @@ end
x_upper_trace(r::MultivariateOptimizationResults) =
error("x_upper_trace is not implemented for $(summary(r)).")

function x_trace(r::MultivariateOptimizationResults)
function x_trace(r::Union{MultivariateOptimizationResults, IteratorState})
tr = trace(r)
if isa(r.method, NelderMead)
if isa(_method(r), NelderMead)
throw(
ArgumentError(
"Nelder Mead does not operate with a single x. Please use either centroid_trace(...) or simplex_trace(...) to extract the relevant points from the trace.",
),
)
end
!haskey(tr[1].metadata, "x") && error(
"Trace does not contain x. To get a trace of x, run optimize() with extended_trace = true",
"Trace does not contain x. To get a trace of x, run optimize() with extended_trace = true and make sure x is stored in the trace for the method of choice.",
)
[state.metadata["x"] for state in tr]
end

function centroid_trace(r::MultivariateOptimizationResults)
function centroid_trace(r::Union{MultivariateOptimizationResults, IteratorState})
tr = trace(r)
if !isa(r.method, NelderMead)

Review comment by @tkf (Contributor, Author), Sep 10, 2019:
I suppose `tr = trace(r)` should be added here? I think it'll throw `UndefVarError` otherwise. There are two more places where I made this change.
I'm including these changes (adding `tr = trace(r)`) although they are not directly related to this PR.

if !isa(_method(r), NelderMead)
throw(
ArgumentError(
"There is no centroid involved in optimization using $(r.method). Please use x_trace(...) to grab the points from the trace.",
@@ -64,9 +67,9 @@ function centroid_trace(r::MultivariateOptimizationResults)
)
[state.metadata["centroid"] for state in tr]
end
function simplex_trace(r::MultivariateOptimizationResults)
function simplex_trace(r::Union{MultivariateOptimizationResults, IteratorState})
tr = trace(r)
if !isa(r.method, NelderMead)
if !isa(_method(r), NelderMead)
throw(
ArgumentError(
"There is no simplex involved in optimization using $(r.method). Please use x_trace(...) to grab the points from the trace.",
@@ -78,9 +81,9 @@ function simplex_trace(r::MultivariateOptimizationResults)
)
[state.metadata["simplex"] for state in tr]
end
function simplex_value_trace(r::MultivariateOptimizationResults)
function simplex_value_trace(r::Union{MultivariateOptimizationResults, IteratorState})
tr = trace(r)
if !isa(r.method, NelderMead)
if !isa(_method(r), NelderMead)
throw(
ArgumentError(
"There are no simplex values involved in optimization using $(r.method). Please use f_trace(...) to grab the objective values from the trace.",
@@ -94,10 +97,10 @@ function simplex_value_trace(r::MultivariateOptimizationResults)
end


f_trace(r::OptimizationResults) = [state.value for state in trace(r)]
f_trace(r::Union{OptimizationResults, IteratorState}) = [state.value for state in trace(r)]
g_norm_trace(r::OptimizationResults) =
error("g_norm_trace is not implemented for $(summary(r)).")
g_norm_trace(r::MultivariateOptimizationResults) = [state.g_norm for state in trace(r)]
g_norm_trace(r::Union{MultivariateOptimizationResults, IteratorState}) = [state.g_norm for state in trace(r)]

f_calls(r::OptimizationResults) = r.f_calls
f_calls(d) = first(d.f_calls)
@@ -114,7 +117,7 @@ h_calls(d) = first(d.h_calls)
h_calls(d::TwiceDifferentiableHV) = first(d.hv_calls)

converged(r::UnivariateOptimizationResults) = r.stopped_by.converged
function converged(r::MultivariateOptimizationResults)
function converged(r::Union{MultivariateOptimizationResults, IteratorState})
conv_flags = r.stopped_by.x_converged || r.stopped_by.f_converged || r.stopped_by.g_converged
x_isfinite = isfinite(x_abschange(r)) || isnan(x_relchange(r))
f_isfinite = if r.stopped_by.iterations > 0
105 changes: 0 additions & 105 deletions src/deprecate.jl
@@ -1,105 +0,0 @@
Base.@deprecate method(x) summary(x)

const has_deprecated_fminbox = Ref(false)
function optimize(
    df::OnceDifferentiable,
    initial_x::Array{T},
    l::Array{T},
    u::Array{T},
    ::Type{Fminbox};
    x_tol::T = eps(T),
    f_tol::T = sqrt(eps(T)),
    g_tol::T = sqrt(eps(T)),
    allow_f_increases::Bool = true,
    iterations::Integer = 1_000,
    store_trace::Bool = false,
    show_trace::Bool = false,
    extended_trace::Bool = false,
    show_warnings::Bool = true,
    callback = nothing,
    show_every::Integer = 1,
    linesearch = LineSearches.HagerZhang{T}(),
    eta::Real = convert(T, 0.4),
    mu0::T = convert(T, NaN),
    mufactor::T = convert(T, 0.001),
    precondprep = (P, x, l, u, mu) -> precondprepbox!(P, x, l, u, mu),
    optimizer = ConjugateGradient,
    optimizer_o = Options(
        store_trace = store_trace,
        show_trace = show_trace,
        extended_trace = extended_trace,
        show_warnings = show_warnings,
    ),
    nargs...,
) where {T<:AbstractFloat}
    if !has_deprecated_fminbox[]
        @warn(
            "Fminbox with the optimizer keyword is deprecated, construct Fminbox{optimizer}() and pass it to optimize(...) instead."
        )
        has_deprecated_fminbox[] = true
    end
    optimize(
        df,
        initial_x,
        l,
        u,
        Fminbox{optimizer}();
        allow_f_increases = allow_f_increases,
        iterations = iterations,
        store_trace = store_trace,
        show_trace = show_trace,
        extended_trace = extended_trace,
        show_warnings = show_warnings,
        show_every = show_every,
        callback = callback,
        linesearch = linesearch,
        eta = eta,
        mu0 = mu0,
        mufactor = mufactor,
        precondprep = precondprep,
        optimizer_o = optimizer_o,
    )
end

function optimize(::AbstractObjective)
    throw(
        ErrorException(
            "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x)``",
        ),
    )
end
function optimize(::AbstractObjective, ::Method)
    throw(
        ErrorException(
            "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x, method)``",
        ),
    )
end
function optimize(::AbstractObjective, ::Method, ::Options)
    throw(
        ErrorException(
            "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x, method, options)``",
        ),
    )
end
function optimize(::AbstractObjective, ::Options)
    throw(
        ErrorException(
            "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x, options)``",
        ),
    )
end

function optimize(
    df::OnceDifferentiable,
    l::Array{T},
    u::Array{T},
    F::Fminbox{O};
    kwargs...,
) where {T<:AbstractFloat,O<:AbstractOptimizer}
    throw(
        ErrorException(
            "Optimizing an objective `obj` without providing an initial `x` has been deprecated without backwards compatability. Please explicitly provide an `x`: `optimize(obj, x, l, u, method, options)``",
        ),
    )
end
60 changes: 34 additions & 26 deletions src/multivariate/optimize/interface.jl
@@ -159,7 +159,7 @@
) = td

# if no method or options are present
function optimize(
function optimizing(
f,
initial_x::AbstractArray;
inplace = true,
@@ -173,9 +173,9 @@
add_default_opts!(checked_kwargs, method)

options = Options(; checked_kwargs...)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
function optimizing(
f,
g,
initial_x::AbstractArray;
@@ -190,9 +190,9 @@
add_default_opts!(checked_kwargs, method)

options = Options(; checked_kwargs...)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
function optimizing(
f,
g,
h,
@@ -208,19 +208,15 @@
add_default_opts!(checked_kwargs, method)

options = Options(; checked_kwargs...)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end

# no method supplied with objective
function optimize(
d::T,
initial_x::AbstractArray,
options::Options,
) where {T<:AbstractObjective}
optimize(d, initial_x, fallback_method(d), options)
function optimizing(d::AbstractObjective, initial_x::AbstractArray, options::Options)
optimizing(d, initial_x, fallback_method(d), options)
end
# no method supplied with inplace and autodiff keywords because objective is not supplied
function optimize(
function optimizing(
f,
initial_x::AbstractArray,
options::Options;
@@ -229,9 +225,9 @@
)
method = fallback_method(f)
d = promote_objtype(method, initial_x, autodiff, inplace, f)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
function optimizing(
f,
g,
initial_x::AbstractArray,
Expand All @@ -242,9 +238,9 @@

method = fallback_method(f, g)
d = promote_objtype(method, initial_x, autodiff, inplace, f, g)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
function optimizing(
f,
g,
h,
@@ -257,11 +253,11 @@
method = fallback_method(f, g, h)
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)

optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end

# potentially everything is supplied (besides caches)
function optimize(
function optimizing(
f,
initial_x::AbstractArray,
method::AbstractOptimizer,
@@ -271,7 +267,7 @@
)

d = promote_objtype(method, initial_x, autodiff, inplace, f)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
f,
@@ -286,7 +282,7 @@
d = promote_objtype(method, initial_x, autodiff, inplace, f)
optimize(d, c, initial_x, method, options)
end
function optimize(
function optimizing(
f,
g,
initial_x::AbstractArray,
@@ -298,9 +294,9 @@

d = promote_objtype(method, initial_x, autodiff, inplace, f, g)

optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end
function optimize(
function optimizing(
f,
g,
h,
@@ -313,17 +309,29 @@

d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)

optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end

function optimize(
function optimizing(
d::D,
initial_x::AbstractArray,
method::SecondOrderOptimizer,
options::Options = Options(; default_options(method)...);
autodiff = :finite,
inplace = true,
) where {D<:Union{NonDifferentiable,OnceDifferentiable}}

d = promote_objtype(method, initial_x, autodiff, inplace, d)
optimize(d, initial_x, method, options)
optimizing(d, initial_x, method, options)
end

function optimize(args...; kwargs...)
    local istate
    iter = optimizing(args...; kwargs...)
    for istate′ in iter
        istate = istate′
    end
    # We can safely assume that `istate` is defined at this point. That is to say,
    # `OptimIterator` guarantees that `iterate(::OptimIterator) !== nothing`.
    return OptimizationResults(iter, istate)
end
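
The wrapper above relies on the guarantee stated in its comment: the first `iterate` call on an
`OptimIterator` never returns `nothing`, so `istate` is always assigned. For illustration, the same
loop written against the raw iteration protocol would look like the sketch below; `optimize_stepwise`
is a hypothetical name, not part of this PR, and nothing beyond `Base.iterate` is assumed:

```jl
# Hypothetical manual-stepping variant of the wrapper, for illustration only.
function optimize_stepwise(args...; kwargs...)
    iter = optimizing(args...; kwargs...)
    istate, inner = iterate(iter)  # first call is guaranteed not to return `nothing`
    while (next = iterate(iter, inner)) !== nothing
        istate, inner = next
    end
    return OptimizationResults(iter, istate)
end
```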