Skip to content

Commit d522508

Browse files
authored
Fix deprecations in DiffRules 1.4 (#191)
* Fix deprecations in DiffRules 1.4 * Apply suggestions
1 parent df00674 commit d522508

File tree

8 files changed

+74
-24
lines changed

8 files changed

+74
-24
lines changed

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ReverseDiff"
22
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
3-
version = "1.10.0"
3+
version = "1.11.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -9,6 +9,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
1111
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1213
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1314
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1415
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -19,9 +20,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1920
[compat]
2021
ChainRulesCore = "1"
2122
DiffResults = "1"
22-
DiffRules = "0.1, 1"
23+
DiffRules = "1.4"
2324
ForwardDiff = "0.10"
2425
FunctionWrappers = "1"
26+
LogExpFunctions = "0.3"
2527
MacroTools = "0.5"
2628
NaNMath = "0.3"
2729
SpecialFunctions = "0.8, 0.9, 0.10, 1.0"

src/ReverseDiff.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ using ForwardDiff
1515
using ForwardDiff: Dual, Partials
1616
using StaticArrays
1717

18+
using LogExpFunctions: LogExpFunctions
19+
1820
using MacroTools
1921

2022
using ChainRulesCore

src/derivatives/broadcast.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,11 @@ end
264264

265265
@inline _materialize(f, args) = broadcast(f, args...)
266266

267-
for (M, f, arity) in DiffRules.diffrules()
268-
isdefined(ReverseDiff, M) || continue
267+
for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
268+
if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
269+
@warn "$M.$f is not available and hence rule for it can not be defined"
270+
continue # Skip rules for methods not defined in the current scope
271+
end
269272
if arity == 1
270273
@eval @inline materialize(bc::RDBroadcasted{typeof($M.$f), <:Tuple{TrackedArray}}) = _materialize(bc.f, bc.args)
271274
elseif arity == 2

src/derivatives/elementwise.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@ end
6767
# dispatch #
6868
#----------#
6969

70-
for g! in (:map!, :broadcast!), (M, f, arity) in DiffRules.diffrules()
70+
for g! in (:map!, :broadcast!), (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
71+
if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
72+
@warn "$M.$f is not available and hence rule for it can not be defined"
73+
continue # Skip rules for methods not defined in the current scope
74+
end
7175
if arity == 1
7276
@eval @inline Base.$(g!)(f::typeof($M.$f), out::TrackedArray, t::TrackedArray) = $(g!)(ForwardOptimize(f), out, t)
7377
elseif arity == 2
@@ -154,7 +158,11 @@ end
154158
# dispatch #
155159
#----------#
156160

157-
for g in (:map, :broadcast), (M, f, arity) in DiffRules.diffrules()
161+
for g in (:map, :broadcast), (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
162+
if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
163+
@warn "$M.$f is not available and hence rule for it can not be defined"
164+
continue # Skip rules for methods not defined in the current scope
165+
end
158166
if arity == 1
159167
@eval @inline Base.$(g)(f::typeof($M.$f), t::TrackedArray) = $(g)(ForwardOptimize(f), t)
160168
elseif arity == 2

src/derivatives/scalars.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
# ForwardOptimize #
33
###################
44

5-
for (M, f, arity) in DiffRules.diffrules()
5+
for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
6+
if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
7+
@warn "$M.$f is not available and hence rule for it can not be defined"
8+
continue # Skip rules for methods not defined in the current scope
9+
end
610
if arity == 1
711
@eval @inline $M.$(f)(t::TrackedReal) = ForwardOptimize($M.$(f))(t)
812
elseif arity == 2

test/derivatives/ElementwiseTests.jl

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
module ElementwiseTests
22

3-
using ReverseDiff, ForwardDiff, Test, DiffRules, SpecialFunctions, NaNMath, DiffTests
3+
using ReverseDiff
4+
5+
using DiffRules
6+
using DiffTests
7+
using ForwardDiff
8+
using LogExpFunctions
9+
using NaNMath
10+
using SpecialFunctions
11+
12+
using Test
413

514
include(joinpath(dirname(@__FILE__), "../utils.jl"))
615

@@ -379,16 +388,17 @@ for f in DiffTests.NUMBER_TO_NUMBER_FUNCS
379388
test_elementwise(f, ReverseDiff.@forward(f), a, tp)
380389
end
381390

382-
DOMAIN_ERR_FUNCS = (:asec, :acsc, :asecd, :acscd, :acoth, :acosh)
383-
384-
for (M, fsym, arity) in DiffRules.diffrules()
391+
for (M, fsym, arity) in DiffRules.diffrules(; filter_modules=nothing)
392+
# ensure that all rules can be tested
393+
if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), fsym))
394+
error("$M.$fsym is not available")
395+
end
385396
fsym === :rem2pi && continue
386397
if arity == 1
387398
f = eval(:($M.$fsym))
388-
is_domain_err_func = in(fsym, DOMAIN_ERR_FUNCS)
389399
test_println("forward-mode unary scalar functions", f)
390-
test_elementwise(f, f, is_domain_err_func ? x .+ 1 : x, tp)
391-
test_elementwise(f, f, is_domain_err_func ? a .+ 1 : a, tp)
400+
test_elementwise(f, f, modify_input(fsym, x), tp)
401+
test_elementwise(f, f, modify_input(fsym, a), tp)
392402
elseif arity == 2
393403
in(fsym, SKIPPED_BINARY_SCALAR_TESTS) && continue
394404
f = eval(:($M.$fsym))

test/derivatives/ScalarTests.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,22 @@
11
module ScalarTests
22

3-
using ReverseDiff, ForwardDiff, Test, DiffRules, SpecialFunctions, NaNMath
3+
using ReverseDiff
4+
5+
using DiffRules
6+
using ForwardDiff
7+
using LogExpFunctions
8+
using NaNMath
9+
using SpecialFunctions
10+
11+
using Test
412

513
include(joinpath(dirname(@__FILE__), "../utils.jl"))
614

715
x, a, b = rand(3)
816
tp = InstructionTape()
917
int_range = 1:10
1018

11-
function test_forward(f, x, tp::InstructionTape, is_domain_err_func::Bool)
19+
function test_forward(f, x, tp::InstructionTape, fsym::Symbol)
1220
xt = ReverseDiff.TrackedReal(x, zero(x), tp)
1321
y = f(x)
1422

@@ -23,7 +31,7 @@ function test_forward(f, x, tp::InstructionTape, is_domain_err_func::Bool)
2331
@test deriv(xt) == ForwardDiff.derivative(f, x)
2432

2533
# forward
26-
x2 = is_domain_err_func ? rand() + 1 : rand()
34+
x2 = modify_input(fsym, rand())
2735
ReverseDiff.value!(xt, x2)
2836
ReverseDiff.forward_pass!(tp)
2937
@test value(yt) == f(x2)
@@ -133,15 +141,15 @@ function test_skip(f, a, b, tp)
133141
@test isempty(tp)
134142
end
135143

136-
DOMAIN_ERR_FUNCS = (:asec, :acsc, :asecd, :acscd, :acoth, :acosh)
137-
138-
for (M, f, arity) in DiffRules.diffrules()
144+
for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
145+
# ensure that function is defined
146+
if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
147+
error("$M.$f is not available")
148+
end
139149
f === :rem2pi && continue
140150
if arity == 1
141151
test_println("forward-mode unary scalar functions", string(M, ".", f))
142-
is_domain_err_func = in(f, DOMAIN_ERR_FUNCS)
143-
n = is_domain_err_func ? x + 1 : x
144-
test_forward(eval(:($M.$f)), n, tp, is_domain_err_func)
152+
test_forward(eval(:($M.$f)), modify_input(f, x), tp, f)
145153
elseif arity == 2
146154
in(f, SKIPPED_BINARY_SCALAR_TESTS) && continue
147155
test_println("forward-mode binary scalar functions", f)
@@ -153,7 +161,7 @@ INT_ONLY_FUNCS = (:iseven, :isodd)
153161

154162
for f in ReverseDiff.SKIPPED_UNARY_SCALAR_FUNCS
155163
test_println("SKIPPED_UNARY_SCALAR_FUNCS", f)
156-
n = in(f, DOMAIN_ERR_FUNCS) ? x + 1 : x
164+
n = modify_input(f, x)
157165
n = in(f, INT_ONLY_FUNCS) ? ceil(Int, n) : n
158166
test_skip(eval(f), n, tp)
159167
end

test/utils.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,16 @@ test_println(kind, f, pad = " ") = println(pad, "testing $(kind): `$(f)`...")
2323
tracked_is(a, b) = value(a) === value(b) && deriv(a) === deriv(b) && tape(a) === tape(b)
2424
tracked_is(a::AbstractArray, b::AbstractArray) = all(map(tracked_is, a, b))
2525
tracked_is(a::Tuple, b::Tuple) = all(map(tracked_is, a, b))
26+
27+
# ensure that input is in domain of function
28+
# here `x` is a scalar or array generated with `rand(dims...)`, i.e., values of `x`
29+
# are between 0 and 1
30+
function modify_input(f, x)
31+
return if in(f, (:asec, :acsc, :asecd, :acscd, :acosh, :acoth))
32+
x .+ one(eltype(x))
33+
elseif f === :log1mexp || f === :log2mexp
34+
x .- one(eltype(x))
35+
else
36+
x
37+
end
38+
end

0 commit comments

Comments
 (0)