Skip to content

Commit c902d61

Browse files
committed
fix nested Duals
1 parent e13012f commit c902d61

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

ext/LinearSolveForwardDiffExt.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
4848
# Cached intermediate values for calculations
4949
rhs_list
5050
dual_u0_cache
51+
primal_u_cache
5152
primal_b_cache
5253

5354
dual_A
@@ -60,6 +61,7 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
6061
cache.dual_u0_cache .= cache.linear_cache.u
6162
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
6263

64+
cache.primal_u_cache .= cache.linear_cache.u
6365
cache.primal_b_cache .= cache.linear_cache.b
6466
uu = sol.u
6567

@@ -77,12 +79,13 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
7779
cache.linear_cache.u .= cache.dual_u0_cache
7880
# We can reuse the linear cache, because the same factorization will work for the partials.
7981
for i in eachindex(rhs_list)
80-
cache.linear_cache.b .= rhs_list[i]
82+
cache.linear_cache.b = copy(rhs_list[i])
8183
rhs_list[i] .= solve!(cache.linear_cache, alg, args...; kwargs...).u
8284
end
8385

8486
# Reset to the original `b` and `u`, users will expect that `b` doesn't change if they don't tell it to
8587
cache.linear_cache.b .= cache.primal_b_cache
88+
cache.linear_cache.u .= cache.primal_u_cache
8689

8790
return primal_sol
8891
end
@@ -240,6 +243,7 @@ function __dual_init(
240243
rhs_list,
241244
similar(new_b),
242245
similar(new_b),
246+
similar(new_b),
243247
A,
244248
b,
245249
zeros(dual_type, length(b))

0 commit comments

Comments
 (0)