@@ -48,6 +48,7 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
48
48
# Cached intermediate values for calculations
49
49
rhs_list
50
50
dual_u0_cache
51
+ primal_u_cache
51
52
primal_b_cache
52
53
53
54
dual_A
58
59
function linearsolve_forwarddiff_solve! (cache:: DualLinearCache , alg, args... ; kwargs... )
59
60
# Solve the primal problem
60
61
cache. dual_u0_cache .= cache. linear_cache. u
61
- sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
62
+ sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
62
63
63
64
cache. primal_b_cache .= cache. linear_cache. b
64
- uu = sol . u
65
+ cache . primal_u_cache . = cache . linear_cache . u
65
66
66
67
# Store solution metadata without copying - we'll return this
67
68
primal_sol = sol
68
69
69
- # Solves Dual partials separately
70
+ # Solves Dual partials separately
70
71
∂_A = cache. partials_A
71
72
∂_b = cache. partials_b
72
73
73
- xp_linsolve_rhs! (uu, ∂_A, ∂_b, cache)
74
+ xp_linsolve_rhs! (∂_A, ∂_b, cache)
74
75
75
76
rhs_list = cache. rhs_list
76
77
@@ -83,11 +84,13 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
83
84
84
85
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
85
86
cache. linear_cache. b .= cache. primal_b_cache
87
+ cache. linear_cache. u .= cache. primal_u_cache
86
88
87
89
return primal_sol
88
90
end
89
91
90
- function xp_linsolve_rhs! (uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
92
+ function xp_linsolve_rhs! (
93
+ ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
91
94
∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} , cache:: DualLinearCache )
92
95
93
96
# Update cached partials lists
@@ -100,14 +103,14 @@ function xp_linsolve_rhs!(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partial
100
103
# Compute rhs = b - A*uu using five-argument mul!
101
104
for i in eachindex (b_list)
102
105
cache. rhs_list[i] .= b_list[i]
103
- mul! (cache. rhs_list[i], A_list[i], uu , - 1 , 1 )
106
+ mul! (cache. rhs_list[i], A_list[i], cache . primal_u_cache , - 1 , 1 )
104
107
end
105
108
106
109
return cache. rhs_list
107
110
end
108
111
109
112
function xp_linsolve_rhs! (
110
- uu, ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
113
+ ∂_A:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
111
114
∂_b:: Nothing , cache:: DualLinearCache )
112
115
113
116
# Update cached partials list for A
@@ -116,14 +119,14 @@ function xp_linsolve_rhs!(
116
119
117
120
# Compute rhs = -A*uu using five-argument mul!
118
121
for i in eachindex (A_list)
119
- mul! (cache. rhs_list[i], A_list[i], uu , - 1 , 0 )
122
+ mul! (cache. rhs_list[i], A_list[i], cache . primal_u_cache , - 1 , 0 )
120
123
end
121
124
122
125
return cache. rhs_list
123
126
end
124
127
125
128
function xp_linsolve_rhs! (
126
- uu, ∂_A:: Nothing , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
129
+ ∂_A:: Nothing , ∂_b:: Union{<:Partials, <:AbstractArray{<:Partials}} ,
127
130
cache:: DualLinearCache )
128
131
129
132
# Update cached partials list for b
@@ -219,7 +222,6 @@ function __dual_init(
219
222
partials_b_list = ! isnothing (∂_b) ? partials_to_list (∂_b) : nothing
220
223
221
224
# Determine size and type for rhs_list
222
- n_partials = 0
223
225
if ! isnothing (partials_A_list)
224
226
n_partials = length (partials_A_list)
225
227
rhs_list = [similar (non_partial_cache. b) for _ in 1 : n_partials]
@@ -240,6 +242,7 @@ function __dual_init(
240
242
rhs_list,
241
243
similar (new_b),
242
244
similar (new_b),
245
+ similar (new_b),
243
246
A,
244
247
b,
245
248
zeros (dual_type, length (b))
@@ -254,7 +257,7 @@ function SciMLBase.solve!(
254
257
cache:: DualLinearCache{DT} , alg:: SciMLLinearSolveAlgorithm , args... ; kwargs... ) where {DT <: ForwardDiff.Dual }
255
258
primal_sol = linearsolve_forwarddiff_solve! (
256
259
cache:: DualLinearCache , getfield (cache, :linear_cache ). alg, args... ; kwargs... )
257
- dual_sol = linearsolve_dual_solution (getfield (cache, :linear_cache ) . u, getfield (cache, :rhs_list ), cache)
260
+ dual_sol = linearsolve_dual_solution (primal_sol . u, getfield (cache, :rhs_list ), cache)
258
261
259
262
# For scalars, we still need to assign since cache.dual_u might not be pre-allocated
260
263
if ! (getfield (cache, :dual_u ) isa AbstractArray)
0 commit comments