Note that this issue occurs on the branch of #417, not on master, because on master this code fails for a different reason, which #417 fixes.
The following code

```julia
using Gen

@gen (static) function foo()
    @param b::Vector{Float32}
    n = length(b)
    x = zeros(n)
    a ~ normal(sum(x), 1.0)
    return nothing
end

@load_generated_functions()

init_parameter!((foo, :b), [0.0, 0.0])
trace = simulate(foo, ())
accumulate_param_gradients!(trace)
```
produces the error:

```
ERROR: LoadError: MethodError: no method matching zeros(::ReverseDiff.TrackedReal{Int64, Int64, Nothing})
Closest candidates are:
  zeros(::Union{Integer, AbstractUnitRange}...) at array.jl:498
  zeros(::Tuple{Vararg{Union{Integer, AbstractUnitRange}, N} where N}) at array.jl:500
  zeros(::Type{StaticArrays.MVector{N, T} where T}) where N at /home/marcoct/.julia/packages/StaticArrays/xV8rq/src/MVector.jl:25
  ...
Stacktrace:
 [1] (::var"#2#7")(n::ReverseDiff.TrackedReal{Int64, Int64, Nothing})
   @ Main ./none:0
 [2] macro expansion
   @ ~/.julia/packages/Gen/3mYgc/src/static_ir/backprop.jl:0 [inlined]
 [3] accumulate_param_gradients!(trace::var"##StaticIRTrace_foo#270", retval_grad::Nothing, scale_factor::Float64)
   @ Main ~/.julia/packages/Gen/3mYgc/src/static_ir/backprop.jl:549
 [4] accumulate_param_gradients!(trace::var"##StaticIRTrace_foo#270")
   @ Gen ~/.julia/packages/Gen/3mYgc/src/gen_fn_interface.jl:403
 [5] top-level scope
   @ ~/dev/GenExamples.jl/test/test.jl:15
in expression starting at /home/marcoct/dev/GenExamples.jl/test/test.jl:15
```
During gradient accumulation, the integer `n = length(b)` reaches `zeros` as a `ReverseDiff.TrackedReal{Int64, Int64, Nothing}` instead of a plain `Int`, and `zeros` has no method for that type. A careful redesign of how ReverseDiff is used for AD is probably needed. (ReverseDiff is currently used as a stop-gap because it provides differentiation of arithmetic and linear algebra operations. Support for AD of new operations should be added by writing generative functions, e.g. using https://www.gen.dev/dev/ref/extending/#Gen.CustomGradientGF, instead of by extending ReverseDiff.)
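For illustration, a deterministic operation with a hand-written gradient, in the `CustomGradientGF` style linked above, might look like the minimal sketch below. It assumes that a subtype of `CustomGradientGF` needs to define `Gen.apply`, `Gen.gradient`, and `Gen.has_argument_grads`; the type `MyPlus` and the constant `my_plus` are hypothetical names, not part of Gen.

```julia
using Gen

# Hypothetical example type: a deterministic computation (addition) whose
# gradient is written by hand instead of by extending ReverseDiff.
struct MyPlus <: CustomGradientGF{Float64} end

# Forward computation on the argument tuple.
Gen.apply(::MyPlus, args) = args[1] + args[2]

# Gradients w.r.t. each argument, given the incoming gradient `retgrad`.
# Since d(a + b)/da = d(a + b)/db = 1, both are just `retgrad`.
Gen.gradient(::MyPlus, args, retval, retgrad) = (retgrad, retgrad)

# Both arguments are differentiable.
Gen.has_argument_grads(::MyPlus) = (true, true)

# Hypothetical instance that model code could call.
const my_plus = MyPlus()
```

Under this approach, the set of differentiable operations is extended by adding generative functions like this one, so ReverseDiff itself would not need new rules.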