@@ -55,21 +55,16 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
55
55
dual_u
56
56
end
57
57
58
- function linearsolve_forwarddiff_solve (cache:: DualLinearCache , alg, args... ; kwargs... )
58
+ function linearsolve_forwarddiff_solve! (cache:: DualLinearCache , alg, args... ; kwargs... )
59
59
# Solve the primal problem
60
60
cache. dual_u0_cache .= cache. linear_cache. u
61
61
sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
62
62
63
63
cache. primal_b_cache .= cache. linear_cache. b
64
64
uu = sol. u
65
65
66
- primal_sol = (;
67
- u = recursivecopy (sol. u),
68
- resid = recursivecopy (sol. resid),
69
- retcode = recursivecopy (sol. retcode),
70
- iters = recursivecopy (sol. iters),
71
- stats = recursivecopy (sol. stats)
72
- )
66
+ # Store solution metadata without copying - we'll return this
67
+ primal_sol = sol
73
68
74
69
# Solves Dual partials separately
75
70
∂_A = cache. partials_A
@@ -89,9 +84,7 @@ function linearsolve_forwarddiff_solve(cache::DualLinearCache, alg, args...; kwa
89
84
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
90
85
cache. linear_cache. b .= cache. primal_b_cache
91
86
92
- partial_sols = rhs_list
93
-
94
- primal_sol, partial_sols
87
+ return primal_sol
95
88
end
96
89
97
90
function xp_linsolve_rhs! (uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
153
146
function linearsolve_dual_solution (u:: AbstractArray , partials,
154
147
cache:: DualLinearCache{DT} ) where {T, V, N, DT <: Dual{T,V,N} }
155
148
# Optimized in-place version that reuses cache.dual_u
156
- linearsolve_dual_solution! (cache. dual_u, u, partials)
157
- return cache. dual_u
149
+ linearsolve_dual_solution! (getfield ( cache, : dual_u) , u, partials)
150
+ return getfield ( cache, : dual_u)
158
151
end
159
152
160
153
function linearsolve_dual_solution! (dual_u:: AbstractArray{DT} , u:: AbstractArray , partials) where {T, V, N, DT <: Dual{T,V,N} }
@@ -254,23 +247,22 @@ function __dual_init(
254
247
end
255
248
256
249
function SciMLBase. solve! (cache:: DualLinearCache , args... ; kwargs... )
257
- solve! (cache, cache. alg, args... ; kwargs... )
250
+ solve! (cache, getfield ( cache, :linear_cache ) . alg, args... ; kwargs... )
258
251
end
259
252
260
253
function SciMLBase. solve! (
261
254
cache:: DualLinearCache{DT} , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... ) where {DT <: ForwardDiff.Dual }
262
- sol,
263
- partials = linearsolve_forwarddiff_solve (
264
- cache:: DualLinearCache , cache. alg, args... ; kwargs... )
265
- dual_sol = linearsolve_dual_solution (sol. u, partials, cache)
255
+ primal_sol = linearsolve_forwarddiff_solve! (
256
+ cache:: DualLinearCache , getfield (cache, :linear_cache ). alg, args... ; kwargs... )
257
+ dual_sol = linearsolve_dual_solution (getfield (cache,:linear_cache ). u, getfield (cache, :rhs_list ), cache)
266
258
267
259
# For scalars, we still need to assign since cache.dual_u might not be pre-allocated
268
- if ! (cache. dual_u isa AbstractArray)
269
- cache. dual_u = dual_sol
260
+ if ! (getfield ( cache, : dual_u) isa AbstractArray)
261
+ setfield! ( cache, : dual_u, dual_sol)
270
262
end
271
263
272
264
return SciMLBase. build_linear_solution (
273
- cache. alg, cache. dual_u, sol . resid, cache; sol . retcode, sol . iters, sol . stats
265
+ getfield ( cache, :linear_cache ) . alg, getfield ( cache, : dual_u), primal_sol . resid, cache; primal_sol . retcode, primal_sol . iters, primal_sol . stats
274
266
)
275
267
end
276
268
0 commit comments