@@ -2,64 +2,80 @@ module LinearSolveEnzymeExt
2
2
3
3
using LinearSolve
4
4
using LinearSolve. LinearAlgebra
5
- isdefined (Base, :get_extension ) ? (import Enzyme) : (import .. Enzyme)
6
-
7
- using Enzyme
8
-
9
5
using EnzymeCore
6
+ using EnzymeCore: EnzymeRules
10
7
11
- function EnzymeCore . EnzymeRules. forward (
8
+ function EnzymeRules. forward (config :: EnzymeRules.FwdConfigWidth{1} ,
12
9
func:: Const{typeof(LinearSolve.init)} , :: Type{RT} , prob:: EnzymeCore.Annotation{LP} ,
13
10
alg:: Const ; kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
14
11
@assert ! (prob isa Const)
15
12
res = func. val (prob. val, alg. val; kwargs... )
16
13
if RT <: Const
17
- return res
14
+ if EnzymeRules. needs_primal (config)
15
+ return res
16
+ else
17
+ return nothing
18
+ end
18
19
end
20
+
19
21
dres = func. val (prob. dval, alg. val; kwargs... )
20
- dres. b .= res. b == dres. b ? zero (dres. b) : dres. b
21
- dres. A .= res. A == dres. A ? zero (dres. A) : dres. A
22
- if RT <: DuplicatedNoNeed
23
- return dres
24
- elseif RT <: Duplicated
22
+
23
+ if dres. b == res. b
24
+ dres. b .= false
25
+ end
26
+ if dres. A == res. A
27
+ dres. A .= false
28
+ end
29
+
30
+ if EnzymeRules. needs_primal (config) && EnzymeRules. needs_shadow (config)
25
31
return Duplicated (res, dres)
32
+ elseif EnzymeRules. needs_shadow (config)
33
+ return dres
34
+ elseif EnzymeRules. needs_primal (config)
35
+ return res
36
+ else
37
+ return nothing
26
38
end
27
- error (" Unsupported return type $RT " )
28
39
end
29
40
30
- function EnzymeCore. EnzymeRules. forward (func:: Const{typeof(LinearSolve.solve!)} ,
41
+ function EnzymeRules. forward (
42
+ config:: EnzymeRules.FwdConfigWidth{1} , func:: Const{typeof(LinearSolve.solve!)} ,
31
43
:: Type{RT} , linsolve:: EnzymeCore.Annotation{LP} ;
32
44
kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
33
45
@assert ! (linsolve isa Const)
34
46
35
47
res = func. val (linsolve. val; kwargs... )
36
48
37
49
if RT <: Const
38
- return res
50
+ if EnzymeRules. needs_primal (config)
51
+ return res
52
+ else
53
+ return nothing
54
+ end
39
55
end
40
56
if linsolve. val. alg isa LinearSolve. AbstractKrylovSubspaceMethod
41
57
error (" Algorithm $(_linsolve. alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling" )
42
58
end
43
- b = deepcopy (linsolve. val. b)
44
59
45
- db = linsolve. dval. b
46
- dA = linsolve. dval. A
60
+ res = deepcopy (res) # Without this copy, the next solve will end up mutating the result
47
61
48
- linsolve. val. b = db - dA * res. u
62
+ b = linsolve. val. b
63
+ linsolve. val. b = linsolve. dval. b - linsolve. dval. A * res. u
49
64
dres = func. val (linsolve. val; kwargs... )
50
-
51
65
linsolve. val. b = b
52
66
53
- if RT <: DuplicatedNoNeed
54
- return dres
55
- elseif RT <: Duplicated
67
+ if EnzymeRules. needs_primal (config) && EnzymeRules. needs_shadow (config)
56
68
return Duplicated (res, dres)
69
+ elseif EnzymeRules. needs_shadow (config)
70
+ return dres
71
+ elseif EnzymeRules. needs_primal (config)
72
+ return res
73
+ else
74
+ return nothing
57
75
end
58
-
59
- return Duplicated (res, dres)
60
76
end
61
77
62
- function EnzymeCore . EnzymeRules. augmented_primal (
78
+ function EnzymeRules. augmented_primal (
63
79
config, func:: Const{typeof(LinearSolve.init)} ,
64
80
:: Type{RT} , prob:: EnzymeCore.Annotation{LP} , alg:: Const ;
65
81
kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
@@ -94,10 +110,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
94
110
(dval. b for dval in prob. dval)
95
111
end
96
112
97
- return EnzymeCore . EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
113
+ return EnzymeRules. AugmentedReturn (res, dres, (d_A, d_b, prob_d_A, prob_d_b))
98
114
end
99
115
100
- function EnzymeCore . EnzymeRules. reverse (
116
+ function EnzymeRules. reverse (
101
117
config, func:: Const{typeof(LinearSolve.init)} , :: Type{RT} ,
102
118
cache, prob:: EnzymeCore.Annotation{LP} , alg:: Const ;
103
119
kwargs... ) where {RT, LP <: LinearSolve.LinearProblem }
131
147
# y=inv(A) B
132
148
# dA −= z y^T
133
149
# dB += z, where z = inv(A^T) dy
134
- function EnzymeCore . EnzymeRules. augmented_primal (
150
+ function EnzymeRules. augmented_primal (
135
151
config, func:: Const{typeof(LinearSolve.solve!)} ,
136
152
:: Type{RT} , linsolve:: EnzymeCore.Annotation{LP} ;
137
153
kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
@@ -184,10 +200,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(
184
200
cachesolve = deepcopy (linsolve. val)
185
201
186
202
cache = (copy (res. u), resvals, cachesolve, dAs, dbs)
187
- return EnzymeCore . EnzymeRules. AugmentedReturn (res, dres, cache)
203
+ return EnzymeRules. AugmentedReturn (res, dres, cache)
188
204
end
189
205
190
- function EnzymeCore . EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} ,
206
+ function EnzymeRules. reverse (config, func:: Const{typeof(LinearSolve.solve!)} ,
191
207
:: Type{RT} , cache, linsolve:: EnzymeCore.Annotation{LP} ;
192
208
kwargs... ) where {RT, LP <: LinearSolve.LinearCache }
193
209
y, dys, _linsolve, dAs, dbs = cache
0 commit comments