@@ -48,6 +48,7 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
48
48
# Cached intermediate values for calculations
49
49
rhs_list
50
50
dual_u0_cache
51
+ primal_u_cache
51
52
primal_b_cache
52
53
53
54
dual_A
@@ -60,6 +61,7 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
60
61
cache. dual_u0_cache .= cache. linear_cache. u
61
62
sol = solve! (cache. linear_cache, alg, args... ; kwargs... )
62
63
64
+ cache. primal_u_cache .= cache. linear_cache. u
63
65
cache. primal_b_cache .= cache. linear_cache. b
64
66
uu = sol. u
65
67
@@ -77,12 +79,13 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
77
79
cache. linear_cache. u .= cache. dual_u0_cache
78
80
# We can reuse the linear cache, because the same factorization will work for the partials.
79
81
for i in eachindex (rhs_list)
80
- cache. linear_cache. b . = rhs_list[i]
82
+ cache. linear_cache. b = copy ( rhs_list[i])
81
83
rhs_list[i] .= solve! (cache. linear_cache, alg, args... ; kwargs... ). u
82
84
end
83
85
84
86
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
85
87
cache. linear_cache. b .= cache. primal_b_cache
88
+ cache. linear_cache. u .= cache. primal_u_cache
86
89
87
90
return primal_sol
88
91
end
@@ -240,6 +243,7 @@ function __dual_init(
240
243
rhs_list,
241
244
similar (new_b),
242
245
similar (new_b),
246
+ similar (new_b),
243
247
A,
244
248
b,
245
249
zeros (dual_type, length (b))
0 commit comments