Skip to content

Commit cb1580b

Browse files
committed
fix caching issues
1 parent e13012f commit cb1580b

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

debug_dual.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using LinearSolve
2+
using ForwardDiff
3+
using Test
4+
5+
function h(p)
6+
(A = [p[1] p[2]+1 p[2]^3;
7+
3*p[1] p[1]+5 p[2] * p[1]-4;
8+
p[2]^2 9*p[1] p[2]],
9+
b = [p[1] + 1, p[2] * 2, p[1]^2])
10+
end
11+
12+
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
13+
14+
println("A:")
15+
display(A)
16+
println("\nb:")
17+
display(b)
18+
19+
prob = LinearProblem(A, b)
20+
overload_x_p = solve(prob)
21+
backslash_x_p = A \ b
22+
23+
println("\nExpected result (backslash):")
24+
display(backslash_x_p)
25+
println("\nLinearSolve result:")
26+
display(overload_x_p)
27+
28+
println("\nDifference:")
29+
display(overload_x_p - backslash_x_p)

ext/LinearSolveForwardDiffExt.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ LinearSolve.@concrete mutable struct DualLinearCache{DT}
4848
# Cached intermediate values for calculations
4949
rhs_list
5050
dual_u0_cache
51+
primal_u_cache
5152
primal_b_cache
5253

5354
dual_A
@@ -58,19 +59,19 @@ end
5859
function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kwargs...)
5960
# Solve the primal problem
6061
cache.dual_u0_cache .= cache.linear_cache.u
61-
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
62+
sol = solve!(cache.linear_cache, alg, args...; kwargs...)
6263

6364
cache.primal_b_cache .= cache.linear_cache.b
64-
uu = sol.u
65+
cache.primal_u_cache .= cache.linear_cache.u
6566

6667
# Store solution metadata without copying - we'll return this
6768
primal_sol = sol
6869

69-
# Solves Dual partials separately
70+
# Solves Dual partials separately
7071
∂_A = cache.partials_A
7172
∂_b = cache.partials_b
7273

73-
xp_linsolve_rhs!(uu, ∂_A, ∂_b, cache)
74+
xp_linsolve_rhs!(∂_A, ∂_b, cache)
7475

7576
rhs_list = cache.rhs_list
7677

@@ -83,11 +84,13 @@ function linearsolve_forwarddiff_solve!(cache::DualLinearCache, alg, args...; kw
8384

8485
# Restore `b` and `u` to the values they held right after the primal solve; users expect `b` to stay unchanged unless they modify it themselves
8586
cache.linear_cache.b .= cache.primal_b_cache
87+
cache.linear_cache.u .= cache.primal_u_cache
8688

8789
return primal_sol
8890
end
8991

90-
function xp_linsolve_rhs!(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
92+
function xp_linsolve_rhs!(
93+
∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
9194
∂_b::Union{<:Partials, <:AbstractArray{<:Partials}}, cache::DualLinearCache)
9295

9396
# Update cached partials lists
@@ -100,14 +103,14 @@ function xp_linsolve_rhs!(uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partial
100103
# Compute rhs = b - A*u, where u is cache.primal_u_cache, using five-argument mul!
101104
for i in eachindex(b_list)
102105
cache.rhs_list[i] .= b_list[i]
103-
mul!(cache.rhs_list[i], A_list[i], uu, -1, 1)
106+
mul!(cache.rhs_list[i], A_list[i], cache.primal_u_cache, -1, 1)
104107
end
105108

106109
return cache.rhs_list
107110
end
108111

109112
function xp_linsolve_rhs!(
110-
uu, ∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
113+
∂_A::Union{<:Partials, <:AbstractArray{<:Partials}},
111114
∂_b::Nothing, cache::DualLinearCache)
112115

113116
# Update cached partials list for A
@@ -116,14 +119,14 @@ function xp_linsolve_rhs!(
116119

117120
# Compute rhs = -A*u, where u is cache.primal_u_cache, using five-argument mul!
118121
for i in eachindex(A_list)
119-
mul!(cache.rhs_list[i], A_list[i], uu, -1, 0)
122+
mul!(cache.rhs_list[i], A_list[i], cache.primal_u_cache, -1, 0)
120123
end
121124

122125
return cache.rhs_list
123126
end
124127

125128
function xp_linsolve_rhs!(
126-
uu, ∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}},
129+
∂_A::Nothing, ∂_b::Union{<:Partials, <:AbstractArray{<:Partials}},
127130
cache::DualLinearCache)
128131

129132
# Update cached partials list for b
@@ -219,7 +222,6 @@ function __dual_init(
219222
partials_b_list = !isnothing(∂_b) ? partials_to_list(∂_b) : nothing
220223

221224
# Determine size and type for rhs_list
222-
n_partials = 0
223225
if !isnothing(partials_A_list)
224226
n_partials = length(partials_A_list)
225227
rhs_list = [similar(non_partial_cache.b) for _ in 1:n_partials]
@@ -240,6 +242,7 @@ function __dual_init(
240242
rhs_list,
241243
similar(new_b),
242244
similar(new_b),
245+
similar(new_b),
243246
A,
244247
b,
245248
zeros(dual_type, length(b))
@@ -254,7 +257,7 @@ function SciMLBase.solve!(
254257
cache::DualLinearCache{DT}, alg::SciMLLinearSolveAlgorithm, args...; kwargs...) where {DT <: ForwardDiff.Dual}
255258
primal_sol = linearsolve_forwarddiff_solve!(
256259
cache::DualLinearCache, getfield(cache, :linear_cache).alg, args...; kwargs...)
257-
dual_sol = linearsolve_dual_solution(getfield(cache,:linear_cache).u, getfield(cache, :rhs_list), cache)
260+
dual_sol = linearsolve_dual_solution(primal_sol.u, getfield(cache, :rhs_list), cache)
258261

259262
# For scalars, we still need to assign since cache.dual_u might not be pre-allocated
260263
if !(getfield(cache, :dual_u) isa AbstractArray)

0 commit comments

Comments
 (0)