Skip to content

Commit e13012f

Browse files
committed
use getfield for DualCache
1 parent 357579a commit e13012f

File tree

1 file changed

+13
-21
lines changed

1 file changed

+13
-21
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,16 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
5555
dual_u
5656
end
5757

58-
function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwargs...)
58+
function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kwargs...)
5959
# Solve the primal problem
6060
cache.dual_u0_cache .= cache.linear_cache.u
6161
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
6262

6363
cache.primal_b_cache .= cache.linear_cache.b
6464
uu = sol.u
6565

66-
primal_sol = (;
67-
u = recursivecopy(sol.u),
68-
resid = recursivecopy(sol.resid),
69-
retcode = recursivecopy(sol.retcode),
70-
iters = recursivecopy(sol.iters),
71-
stats = recursivecopy(sol.stats)
72-
)
66+
# Store solution metadata without copying - we'll return this
67+
primal_sol = sol
7368

7469
# Solves Dual partials separately
7570
∂_A = cache.partials_A
@@ -89,9 +84,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
8984
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
9085
cache.linear_cache.b .= cache.primal_b_cache
9186

92-
partial_sols = rhs_list
93-
94-
primal_sol, partial_sols
87+
return primal_sol
9588
end
9689

9790
function xp_linsolve_rhs!(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
@@ -153,8 +146,8 @@ end
153146
function linearsolve_dual_solution(u::AbstractArray, partials,
154147
cache::DualLinearCache{DT}) where {T, V, N, DT <: Dual{T,V,N}}
155148
# Optimized in-place version that reuses cache.dual_u
156-
linearsolve_dual_solution!(cache.dual_u, u, partials)
157-
return cache.dual_u
149+
linearsolve_dual_solution!(getfield(cache, :dual_u), u, partials)
150+
return getfield(cache, :dual_u)
158151
end
159152

160153
function linearsolve_dual_solution!(dual_u::AbstractArray{DT}, u::AbstractArray, partials) where {T, V, N, DT <: Dual{T,V,N}}
@@ -254,23 +247,22 @@ function __dual_init(
254247
end
255248

256249
function SciMLBase.solve!(cache::DualLinearCache, args...; kwargs...)
257-
solve!(cache, cache.alg, args...; kwargs...)
250+
solve!(cache, getfield(cache, :linear_cache).alg, args...; kwargs...)
258251
end
259252

260253
function SciMLBase.solve!(
261254
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <: ForwardDiff.Dual}
262-
sol,
263-
partials = linearsolve_forwarddiff_solve(
264-
cache::DualLinearCache, cache.alg, args...; kwargs...)
265-
dual_sol = linearsolve_dual_solution(sol.u, partials, cache)
255+
primal_sol = linearsolve_forwarddiff_solve!(
256+
cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...)
257+
dual_sol = linearsolve_dual_solution(getfield(cache,:linear_cache).u, getfield(cache, :rhs_list), cache)
266258

267259
# For scalars, we still need to assign since cache.dual_u might not be pre-allocated
268-
if !(cache.dual_u isa AbstractArray)
269-
cache.dual_u = dual_sol
260+
if !(getfield(cache, :dual_u) isa AbstractArray)
261+
setfield!(cache, :dual_u, dual_sol)
270262
end
271263

272264
return SciMLBase.build_linear_solution(
273-
cache.alg, cache.dual_u, sol.resid, cache; sol.retcode, sol.iters, sol.stats
265+
getfield(cache, :linear_cache).alg, getfield(cache, :dual_u), primal_sol.resid, cache; primal_sol.retcode, primal_sol.iters, primal_sol.stats
274266
)
275267
end
276268

0 commit comments

Comments
 (0)