1
1
module ScalarTests
2
2
3
- using ReverseDiff, ForwardDiff, Test, DiffRules, SpecialFunctions, NaNMath
3
+ using ReverseDiff
4
+
5
+ using DiffRules
6
+ using ForwardDiff
7
+ using LogExpFunctions
8
+ using NaNMath
9
+ using SpecialFunctions
10
+
11
+ using Test
4
12
5
13
include (joinpath (dirname (@__FILE__ ), " ../utils.jl" ))
6
14
7
15
x, a, b = rand (3 )
8
16
tp = InstructionTape ()
9
17
int_range = 1 : 10
10
18
11
- function test_forward (f, x, tp:: InstructionTape , is_domain_err_func :: Bool )
19
+ function test_forward (f, x, tp:: InstructionTape , fsym :: Symbol )
12
20
xt = ReverseDiff. TrackedReal (x, zero (x), tp)
13
21
y = f (x)
14
22
@@ -23,7 +31,7 @@ function test_forward(f, x, tp::InstructionTape, is_domain_err_func::Bool)
23
31
@test deriv (xt) == ForwardDiff. derivative (f, x)
24
32
25
33
# forward
26
- x2 = is_domain_err_func ? rand () + 1 : rand ()
34
+ x2 = modify_input (fsym, rand () )
27
35
ReverseDiff. value! (xt, x2)
28
36
ReverseDiff. forward_pass! (tp)
29
37
@test value (yt) == f (x2)
@@ -133,15 +141,15 @@ function test_skip(f, a, b, tp)
133
141
@test isempty (tp)
134
142
end
135
143
136
- DOMAIN_ERR_FUNCS = (:asec , :acsc , :asecd , :acscd , :acoth , :acosh )
137
-
138
- for (M, f, arity) in DiffRules. diffrules ()
144
+ for (M, f, arity) in DiffRules. diffrules (; filter_modules= nothing )
145
+ # ensure that function is defined
146
+ if ! (isdefined (@__MODULE__ , M) && isdefined (getfield (@__MODULE__ , M), f))
147
+ error (" $M .$f is not available" )
148
+ end
139
149
f === :rem2pi && continue
140
150
if arity == 1
141
151
test_println (" forward-mode unary scalar functions" , string (M, " ." , f))
142
- is_domain_err_func = in (f, DOMAIN_ERR_FUNCS)
143
- n = is_domain_err_func ? x + 1 : x
144
- test_forward (eval (:($ M.$ f)), n, tp, is_domain_err_func)
152
+ test_forward (eval (:($ M.$ f)), modify_input (f, x), tp, f)
145
153
elseif arity == 2
146
154
in (f, SKIPPED_BINARY_SCALAR_TESTS) && continue
147
155
test_println (" forward-mode binary scalar functions" , f)
@@ -153,7 +161,7 @@ INT_ONLY_FUNCS = (:iseven, :isodd)
153
161
154
162
for f in ReverseDiff. SKIPPED_UNARY_SCALAR_FUNCS
155
163
test_println (" SKIPPED_UNARY_SCALAR_FUNCS" , f)
156
- n = in (f, DOMAIN_ERR_FUNCS) ? x + 1 : x
164
+ n = modify_input (f, x)
157
165
n = in (f, INT_ONLY_FUNCS) ? ceil (Int, n) : n
158
166
test_skip (eval (f), n, tp)
159
167
end
0 commit comments