Skip to content

Commit df00674

Browse files
authored
Merge pull request #180 from KDr2/chainrules
Integrate ReverseDiff with ChainRules
2 parents dd6eb4a + 90a55ad commit df00674

File tree

6 files changed

+349
-2
lines changed

6 files changed

+349
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
name = "ReverseDiff"
22
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3-
version = "1.9.0"
3+
version = "1.10.0"
44

55
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
67
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
78
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
89
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -16,6 +17,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1617
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1718

1819
[compat]
20+
ChainRulesCore = "1"
1921
DiffResults = "1"
2022
DiffRules = "0.1, 1"
2123
ForwardDiff = "0.10"

src/ReverseDiff.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ using StaticArrays
1717

1818
using MacroTools
1919

20+
using ChainRulesCore
21+
2022
# Not all operations will be valid over all of these types, but that's okay; such cases
2123
# will simply error when they hit the original operation in the overloaded definition.
2224
const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix)

src/macros.jl

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,136 @@ macro grad(expr)
237237
end
238238
end |> esc
239239
end
240+
241+
"""
242+
_make_fwd_args(func, arg_list)
243+
244+
Function `_make_fwd_args` accepts a function name and an argument
245+
list, returns a tuple of argument lists whose elements are:
246+
1. the `arg_list` untouched, 2. a new argument list with the function
247+
as its first element and other elements in `arg_list` followed, 3. a
248+
new argument list for the definition of function `track`, 4. a new argument
249+
list with all kwargs removed, 5. the types of the arguments in the 4th
250+
element, 6. the kwargs name if any, otherwise an empty tuple expression. E.g.:
251+
252+
_make_fwd_args(:f, [:(a::String), :(b::TrackedReal), :(args...)])
253+
254+
returns
255+
256+
([:(a::String), :(b::TrackedReal), :(args...)],
257+
[:f, :(a::String), :(b::TrackedReal), :(args...)],
258+
[:(::typeof(f)), :(a::String), :(b::TrackedReal), :(args...)],
259+
[:(a::String), :(b::TrackedReal), :(args...)],
260+
[:String, :TrackedReal, :(Vararg{Any})],
261+
:(()))
262+
263+
It also deals with varargs and variable keyword arguments, and ensures
264+
that at least one of the arguments is tracked.
265+
266+
"""
267+
function _make_fwd_args(func, args_l)
268+
kwargs = :(())
269+
args_r = copy(args_l)
270+
args_track = copy(args_l)
271+
if Meta.isexpr(args_r[1], :parameters) # has kw args
272+
insert!(args_r, 2, func)
273+
insert!(args_track, 2, :(::typeof($func)))
274+
kwargs = gensym(:kwargs)
275+
args_track[1].args = [:($(kwargs)...)]
276+
else
277+
insert!(args_r, 1, func)
278+
insert!(args_track, 1, :(::typeof($func)))
279+
end
280+
281+
args_fixed = filter(copy(args_l)) do arg
282+
!Meta.isexpr(arg, :parameters)
283+
end
284+
285+
arg_types = map(args_fixed) do arg
286+
if Meta.isexpr(arg, :(...))
287+
Meta.isexpr(arg.args[1], :(::)) ? :(Vararg{$(arg.args[1].args[end])}) : :(Vararg{Any})
288+
elseif Meta.isexpr(arg, :(::))
289+
arg.args[end]
290+
else
291+
:Any
292+
end
293+
end
294+
295+
return args_l, args_r, args_track, args_fixed, arg_types, kwargs
296+
end
297+
298+
"""
299+
@grad_from_chainrules f(args...; kwargs...)
300+
301+
The `@grad_from_chainrules` macro provides a way to import
302+
adjoints (`rrule`s) defined in ChainRules to ReverseDiff. One must provide
303+
a method signature to import the corresponding `rrule`. In the
304+
provided method signature, one should replace the types of arguments
305+
with respect to which one wants to take derivatives with
306+
`ReverseDiff.TrackedReal` and `ReverseDiff.TrackedArray`
307+
respectively. For example, we can import `rrule` of `f(x::Real,
308+
y::Array)` like below:
309+
310+
```julia
311+
ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::TrackedArray)
312+
ReverseDiff.@grad_from_chainrules f(x::TrackedReal, y::Array)
313+
ReverseDiff.@grad_from_chainrules f(x::Real, y::TrackedArray)
314+
```
315+
"""
316+
macro grad_from_chainrules(fcall)
317+
Meta.isexpr(fcall, :call) && length(fcall.args) >= 2 ||
318+
error("`@grad_from_chainrules` has to be applied to a function signature")
319+
f = esc(fcall.args[1])
320+
xs = fcall.args[2:end]
321+
args_l, args_r, args_track, args_fixed, arg_types, kwargs = _make_fwd_args(f, xs)
322+
323+
return quote
324+
$f($(args_l...)) = ReverseDiff.track($(args_r...))
325+
function ReverseDiff.track($(args_track...))
326+
args = ($(args_fixed...),)
327+
tp = ReverseDiff.tape(args...)
328+
output_value, back = ChainRulesCore.rrule($f, map(ReverseDiff.value, args)...; $kwargs...)
329+
output = ReverseDiff.track(output_value, tp)
330+
closure(cls_args...; cls_kwargs...) = ChainRulesCore.rrule($f, map(ReverseDiff.value, cls_args)...; cls_kwargs...)
331+
ReverseDiff.record!(
332+
tp,
333+
ReverseDiff.SpecialInstruction,
334+
$f,
335+
args,
336+
output,
337+
(back, closure, $kwargs),
338+
)
339+
return output
340+
end
341+
342+
@noinline function ReverseDiff.special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f), <:Tuple{$(arg_types...)}})
343+
output = instruction.output
344+
input = instruction.input
345+
back = instruction.cache[1]
346+
back_output = back(ReverseDiff.deriv(output))
347+
input_derivs = back_output[2:end]
348+
@assert input_derivs isa Tuple
349+
ReverseDiff._add_to_deriv!.(input, input_derivs)
350+
ReverseDiff.unseed!(output)
351+
return nothing
352+
end
353+
354+
@noinline function ReverseDiff.special_forward_exec!(instruction::ReverseDiff.SpecialInstruction{typeof($f), <:Tuple{$(arg_types...)}})
355+
output, input = instruction.output, instruction.input
356+
ReverseDiff.pull_value!.(input)
357+
pullback = instruction.cache[2]
358+
kwargs = instruction.cache[3]
359+
out_value = pullback(input...; kwargs...)[1]
360+
ReverseDiff.value!(output, out_value)
361+
return nothing
362+
end
363+
end
364+
end
365+
240366
_add_to_deriv!(d1, d2) = nothing
367+
function _add_to_deriv!(d1::Union{TrackedReal, AbstractArray{<:TrackedReal}}, d2::AbstractThunk)
368+
increment_deriv!(d1, unthunk(d2))
369+
end
241370
function _add_to_deriv!(d1::Union{TrackedReal, AbstractArray{<:TrackedReal}}, d2)
242371
increment_deriv!(d1, d2)
243372
end

test/ChainRulesTests.jl

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
module ChainRulesTest
2+
3+
using LinearAlgebra
4+
using ChainRulesCore
5+
using DiffResults
6+
using ReverseDiff
7+
using Test
8+
9+
f(x) = sum(4x .+ 1)
10+
11+
function ChainRulesCore.rrule(::typeof(f), x)
12+
r = f(x)
13+
function back(d)
14+
#=
15+
The proper derivative of `f` is 4, but in order to
16+
check if `ChainRulesCore.rrule` had taken over the computation,
17+
we define a rrule that returns 3 as `f`'s derivative.
18+
19+
After importing this rrule into ReverseDiff, if we get 3
20+
rather than 4 when we compute the derivative of `f`, it means
21+
the importing mechanism works.
22+
=#
23+
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
24+
end
25+
return r, back
26+
end
27+
28+
ReverseDiff.@grad_from_chainrules f(x::ReverseDiff.TrackedArray)
29+
30+
31+
g(x, y) = sum(4x .+ 4y)
32+
33+
function ChainRulesCore.rrule(::typeof(g), x, y)
34+
r = g(x, y)
35+
function back(d)
36+
# same as above, use 3 and 5 as the derivatives
37+
return ChainRulesCore.NoTangent(), fill(3 * d, size(x)), fill(5 * d, size(x))
38+
end
39+
return r, back
40+
end
41+
42+
ReverseDiff.@grad_from_chainrules g(x::ReverseDiff.TrackedArray, y)
43+
ReverseDiff.@grad_from_chainrules g(x, y::ReverseDiff.TrackedArray)
44+
ReverseDiff.@grad_from_chainrules g(x::ReverseDiff.TrackedArray, y::ReverseDiff.TrackedArray)
45+
46+
@testset "rrule in ChainRules and ReverseDiff" begin
47+
## ChainRules
48+
# function f
49+
input = rand(3, 3)
50+
output, back = ChainRulesCore.rrule(f, input);
51+
_, d = back(1)
52+
@test output == f(input)
53+
@test d == fill(3, size(input))
54+
# function g
55+
inputs = rand(3, 3), rand(3, 3)
56+
output, back = ChainRulesCore.rrule(g, inputs...);
57+
_, d1, d2 = back(1)
58+
@test output == g(inputs...)
59+
@test d1 == fill(3, size(inputs[1]))
60+
@test d2 == fill(5, size(inputs[2]))
61+
62+
63+
## ReverseDiff
64+
#function f
65+
inputs = (rand(3, 3), )
66+
67+
results = (similar(inputs[1]),)
68+
f_tape = ReverseDiff.GradientTape(x -> f(x) + 2, (rand(3, 3),))
69+
ReverseDiff.gradient!(results, f_tape, inputs)
70+
71+
@test results[1] == fill(3, size(inputs[1]))
72+
73+
results = (similar(inputs[1]),)
74+
compiled_tape = ReverseDiff.CompiledTape(f_tape)
75+
ReverseDiff.gradient!(results, compiled_tape, inputs)
76+
@test results[1] == fill(3, size(inputs[1]))
77+
78+
# function g
79+
inputs = rand(3, 3), rand(3, 3)
80+
81+
results = (similar(inputs[1]), similar(inputs[2]))
82+
f_tape = ReverseDiff.GradientTape((x, y) -> g(x, y) + 2, (rand(3, 3), rand(3, 3)))
83+
ReverseDiff.gradient!(results, f_tape, inputs)
84+
85+
@test results[1] == fill(3, size(inputs[1]))
86+
@test results[2] == fill(5, size(inputs[2]))
87+
88+
results = (similar(inputs[1]), similar(inputs[2]),)
89+
compiled_tape = ReverseDiff.CompiledTape(f_tape)
90+
ReverseDiff.gradient!(results, compiled_tape, inputs)
91+
@test results[1] == fill(3, size(inputs[1]))
92+
@test results[2] == fill(5, size(inputs[2]))
93+
94+
end
95+
96+
### Tape test
97+
@testset "Tape test: Ensure ordinary call is not tracked" begin
98+
tp = ReverseDiff.InstructionTape()
99+
100+
f(x) = sum(2x .+ g([1, 2], [3, 4]))
101+
x = rand(3, 3)
102+
xt = ReverseDiff.track(copy(x), tp)
103+
# record
104+
yt = f(xt)
105+
@test length(tp) == 3 # sum, broadcast+, broadcast*, but not `g`
106+
end
107+
108+
### Functions with varargs and kwargs
109+
# Varargs
110+
f_vararg(x, args...) = sum(4x .+ sum(args))
111+
112+
function ChainRulesCore.rrule(::typeof(f_vararg), x, args...)
113+
r = f_vararg(x, args...)
114+
function back(d)
115+
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
116+
end
117+
return r, back
118+
end
119+
120+
ReverseDiff.@grad_from_chainrules f_vararg(x::ReverseDiff.TrackedArray, args...)
121+
122+
@testset "Function with Varargs" begin
123+
inputs = (rand(3, 3), )
124+
125+
results = (similar(inputs[1]),)
126+
f_tape = ReverseDiff.GradientTape(x -> f_vararg(x, 1, 2, 3) + 2, (rand(3, 3),))
127+
ReverseDiff.gradient!(results, f_tape, inputs)
128+
129+
@test results[1] == fill(3, size(inputs[1]))
130+
end
131+
132+
133+
# Varargs and kwargs
134+
f_kw(x, args...; k=1, kwargs...) = sum(4x .+ sum(args) .+ (k + kwargs[:j]))
135+
136+
function ChainRulesCore.rrule(::typeof(f_kw), x, args...; k=1, kwargs...)
137+
r = f_kw(x, args...; k=k, kwargs...)
138+
function back(d)
139+
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
140+
end
141+
return r, back
142+
end
143+
144+
ReverseDiff.@grad_from_chainrules f_kw(x::ReverseDiff.TrackedArray, args...; k=1, kwargs...)
145+
146+
@testset "Function with Varargs and kwargs" begin
147+
inputs = (rand(3, 3), )
148+
149+
results = (similar(inputs[1]),)
150+
f_tape = ReverseDiff.GradientTape(x -> f_kw(x, 1, 2, 3; k=2, j=3) + 2, (rand(3, 3),))
151+
ReverseDiff.gradient!(results, f_tape, inputs)
152+
153+
@test results[1] == fill(3, size(inputs[1]))
154+
end
155+
156+
### Mix @grad and @grad_from_chainrules
157+
158+
h(x) = 10x
159+
h(x::ReverseDiff.TrackedArray) = ReverseDiff.track(h, x)
160+
ReverseDiff.@grad function h(x)
161+
xv = ReverseDiff.value(x)
162+
return h(xv), Δ -> (Δ * 7,) # use 7 as its derivative
163+
end
164+
165+
@testset "ReverseDiff and ChainRules Mixed" begin
166+
t(x) = g(x, h(x))
167+
inputs = (rand(3, 3), )
168+
results = (similar(inputs[1]),)
169+
170+
g_tape = ReverseDiff.GradientTape(t, (rand(3, 3),))
171+
ReverseDiff.gradient!(results, g_tape, inputs)
172+
@test results[1] == fill(38, size(inputs[1])) # 38 = 3 + 5 * 7
173+
end
174+
175+
### Isolated Scope
176+
module IsolatedModuleForTestingScoping
177+
using ChainRulesCore
178+
using ReverseDiff: @grad_from_chainrules
179+
180+
f(x) = sum(4x .+ 1)
181+
182+
function ChainRulesCore.rrule(::typeof(f), x)
183+
r = f(x)
184+
function back(d)
185+
# return a distinguishable but improper grad
186+
return ChainRulesCore.NoTangent(), fill(3 * d, size(x))
187+
end
188+
return r, back
189+
end
190+
191+
@grad_from_chainrules f(x::TrackedArray)
192+
193+
module SubModule
194+
using Test
195+
using ReverseDiff: TrackedArray, GradientTape, gradient!
196+
using ..IsolatedModuleForTestingScoping: f
197+
@testset "rrule in Isolated Scope" begin
198+
inputs = (rand(3, 3), )
199+
200+
results = (similar(inputs[1]),)
201+
f_tape = GradientTape(x -> f(x) + 2, (rand(3, 3),))
202+
gradient!(results, f_tape, inputs)
203+
204+
@test results[1] == fill(3, size(inputs[1]))
205+
end
206+
207+
end # end of SubModule
208+
end # end of IsolatedModuleForTestingScoping
209+
210+
end

test/derivatives/LinAlgTests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function test_arr2arr(f, a, b, tp)
7979
ReverseDiff.value!(at, a2)
8080
ReverseDiff.forward_pass!(tp)
8181
@test value(ct) == f(a2, b)
82-
82+
8383
ReverseDiff.value!(at, a)
8484
empty!(tp)
8585

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ println("running MacrosTests...")
1414
t = @elapsed include(joinpath(TESTDIR, "MacrosTests.jl"))
1515
println("done (took $t seconds).")
1616

17+
println("running ChainRulesTests...")
18+
t = @elapsed include(joinpath(TESTDIR, "ChainRulesTests.jl"))
19+
println("done (took $t seconds).")
20+
1721
println("running ScalarTests...")
1822
t = @elapsed include(joinpath(TESTDIR, "derivatives/ScalarTests.jl"))
1923
println("done (took $t seconds).")

0 commit comments

Comments
 (0)