Skip to content

Commit 34a1f8e

Browse files
Fix copy mechanism (#71)
* Fix copy mechanism * Update Project.toml Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
1 parent 118e45e commit 34a1f8e

File tree

6 files changed

+13
-13
lines changed

6 files changed

+13
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.4.2"
4+
version = "0.4.3"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

examples/gaussian-process/script.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ function gp_update(model::GPSSM, state, step)
5454
return Normal(μ[1], σ[1])
5555
end
5656

57-
Libtask.tape_copy(model::GPSSM) = deepcopy(model)
58-
5957
AdvancedPS.initialization(::GPSSM) = h(model)
6058
AdvancedPS.transition(model::GPSSM, state, step) = gp_update(model, state, step)
6159
AdvancedPS.observation(model::GPSSM, state, step) = logpdf(g(model, state, step), y[step])

examples/particle-gibbs/script.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ Parameters = @NamedTuple begin
3333
end
3434

3535
mutable struct NonLinearTimeSeries <: AbstractMCMC.AbstractModel
36-
X::TArray
36+
X::Array
3737
θ::Parameters
38-
NonLinearTimeSeries::Parameters) = new(TArray(Float64, θ.T), θ)
38+
NonLinearTimeSeries::Parameters) = new(zeros(Float64, θ.T), θ)
3939
end
4040

4141
f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q)
@@ -87,10 +87,6 @@ function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG)
8787
end
8888
end
8989

90-
# `AdvancedPS` relies on `Libtask` to copy models during their execution but we need to make sure the
91-
# internal data of each model is properly copied over as well.
92-
Libtask.tape_copy(model::NonLinearTimeSeries) = deepcopy(model)
93-
9490
# Here we use the particle gibbs kernel without adaptive resampling.
9591
model = NonLinearTimeSeries(θ₀)
9692
pgas = AdvancedPS.PG(Nₚ, 1.0)

src/model.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@ struct GenericModel{F1,F2} <: AbstractMCMC.AbstractModel
1010
GenericModel(f::F1, ctask::Libtask.TapedTask{F2}) where {F1,F2} = new{F1,F2}(f, ctask)
1111
end
1212

13-
GenericModel(f, args...) = GenericModel(f, Libtask.TapedTask(f, args...))
13+
function GenericModel(f, args...)
14+
return GenericModel(
15+
f, Libtask.TapedTask(f, args...; deepcopy_types=Union{TracedRNG,typeof(f)})
16+
)
17+
end
1418
Base.copy(model::GenericModel) = GenericModel(model.f, copy(model.ctask))
1519

1620
"""
@@ -74,7 +78,10 @@ function forkr(trace::GenericTrace)
7478
newf = reset_model(trace.model.f)
7579
Random123.set_counter!(trace.rng, 1)
7680

77-
ctask = Libtask.TapedTask(newf, trace.rng)
81+
ctask = Libtask.TapedTask(
82+
newf, trace.rng; deepcopy_types=Union{TracedRNG,typeof(trace.model.f)}
83+
)
84+
#ctask = Libtask.TapedTask(newf, trace.rng)
7885
new_tapedmodel = GenericModel(newf, ctask)
7986

8087
# add backward reference

test/container.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,6 @@
190190

191191
AdvancedPS.advance!(ref)
192192
child = AdvancedPS.fork(ref, true)
193-
@test length(new_tr.rng.keys) == 1
193+
@test length(child.rng.keys) == 1
194194
end
195195
end

test/pgas.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@
7979

8080
@testset "rng stability" begin
8181
model = BaseModel(Params(0.9, 0.32, 1))
82-
8382
seed = 10
8483
rng = Random.MersenneTwister(seed)
8584

0 commit comments

Comments
 (0)