Skip to content

Commit c791095

Browse files
Merge gpu code
1 parent b78efce commit c791095

File tree

6 files changed

+30
-50
lines changed

6 files changed

+30
-50
lines changed

src/GeneralisedFilters.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ using NNlib
1313

1414
abstract type AbstractFilter <: AbstractSampler end
1515

16+
abstract type AbstractParticleFilter{N} <: AbstractFilter end
17+
1618
"""
1719
predict([rng,] model, alg, iter, state; kwargs...)
1820

src/algorithms/apf.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
export AuxiliaryParticleFilter, APF
22

3-
mutable struct AuxiliaryParticleFilter{RS<:AbstractConditionalResampler} <: AbstractFilter
4-
N::Integer
3+
mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <: AbstractParticleFilter{N}
54
resampler::RS
65
aux::Vector # Auxiliary weights
76
end
@@ -10,20 +9,20 @@ function AuxiliaryParticleFilter(
109
N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic()
1110
)
1211
conditional_resampler = ESSResampler(threshold, resampler)
13-
return AuxiliaryParticleFilter(N, conditional_resampler, zeros(N))
12+
return AuxiliaryParticleFilter{N,typeof(conditional_resampler)}(conditional_resampler, zeros(N))
1413
end
1514

1615
const APF = AuxiliaryParticleFilter
1716

1817
function initialise(
1918
rng::AbstractRNG,
2019
model::StateSpaceModel{T},
21-
filter::AuxiliaryParticleFilter;
20+
filter::AuxiliaryParticleFilter{N},
2221
ref_state::Union{Nothing,AbstractVector}=nothing,
2322
kwargs...,
24-
) where {T}
25-
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N))
26-
initial_weights = fill(-log(T(filter.N)), filter.N)
23+
) where {N,T}
24+
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
25+
initial_weights = zeros(T, N)
2726

2827
return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter)
2928
end
@@ -86,7 +85,7 @@ function update(
8685
states.filtered.log_weights = states.proposed.log_weights + log_increments
8786
states.filtered.particles = states.proposed.particles
8887

89-
return (states, logsumexp(log_increments) - log(T(filter.N)))
88+
return states, logmarginal(states, filter)
9089
end
9190

9291
function step(

src/algorithms/bootstrap.jl

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,28 @@
11
export BootstrapFilter, BF
22

3-
struct BootstrapFilter{RS<:AbstractResampler} <: AbstractFilter
4-
N::Integer
3+
struct BootstrapFilter{N,RS<:AbstractResampler} <: AbstractParticleFilter{N}
54
resampler::RS
65
end
76

87
function BootstrapFilter(
98
N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
109
)
1110
conditional_resampler = ESSResampler(threshold, resampler)
12-
return BootstrapFilter(N, conditional_resampler)
11+
return BootstrapFilter{N, typeof(conditional_resampler)}(conditional_resampler)
1312
end
1413

1514
"""Shorthand for `BootstrapFilter`"""
1615
const BF = BootstrapFilter
1716

18-
function BootstrapFilter(
19-
N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic()
20-
)
21-
conditional_resampler = ESSResampler(threshold, resampler)
22-
return BootstrapFilter(N, conditional_resampler)
23-
end
24-
2517
function initialise(
2618
rng::AbstractRNG,
2719
model::StateSpaceModel{T},
28-
filter::BootstrapFilter;
20+
filter::BootstrapFilter{N};
2921
ref_state::Union{Nothing,AbstractVector}=nothing,
3022
kwargs...,
31-
) where {T}
32-
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N))
33-
initial_weights = zeros(T, filter.N)
23+
) where {N,T}
24+
initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N)
25+
initial_weights = zeros(T, N)
3426

3527
return update_ref!(
3628
ParticleContainer(initial_states, initial_weights), ref_state, filter
@@ -71,7 +63,7 @@ function update(
7163
states.filtered.log_weights = states.proposed.log_weights + log_increments
7264
states.filtered.particles = states.proposed.particles
7365

74-
return states, logmarginal(states)
66+
return states, logmarginal(states, filter)
7567
end
7668

7769
function reset_weights!(

src/algorithms/rbpf.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ function update(
108108

109109
states.filtered.log_weights = states.proposed.log_weights + log_increments
110110

111-
return states, logmarginal(states)
111+
return states, logmarginal(states, algo)
112112
end
113113

114114
#################################

src/containers.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,15 @@ Base.keys(state::ParticleState) = LinearIndices(state.particles)
105105
Base.@propagate_inbounds Base.getindex(state::ParticleState, i) = state.particles[i]
106106
# Base.@propagate_inbounds Base.getindex(state::ParticleState, i::Vector{Int}) = state.particles[i]
107107

108-
function reset_weights!(state::ParticleState{T,WT}) where {T,WT<:Real}
108+
function reset_weights!(state::ParticleState{T,WT}, idx, ::AbstractFilter) where {T,WT<:Real}
109109
fill!(state.log_weights, zero(WT))
110110
return state.log_weights
111111
end
112112

113+
function logmarginal(states::ParticleContainer, ::AbstractFilter)
114+
return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights)
115+
end
116+
113117
function update_ref!(
114118
pc::ParticleContainer{T},
115119
ref_state::Union{Nothing,AbstractVector{T}},

src/resamplers.jl

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@ function resample(
1111
rng::AbstractRNG,
1212
resampler::AbstractResampler,
1313
states::ParticleState{PT,WT},
14-
filter::AbstractFilter,
14+
filter::AbstractFilter;
15+
weights::AbstractVector{WT}=StatsBase.weights(states)
1516
) where {PT,WT}
16-
weights = StatsBase.weights(states)
1717
idxs = sample_ancestors(rng, resampler, weights)
18-
19-
new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states)))
20-
18+
new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states)))
19+
reset_weights!(new_state, idxs, filter)
2120
return new_state, idxs
2221
end
2322

@@ -26,8 +25,9 @@ function resample(
2625
rng::AbstractRNG,
2726
resampler::AbstractResampler,
2827
states::RaoBlackwellisedParticleState{T,M,ZT},
28+
::AbstractFilter;
29+
weights=StatsBase.weights(states)
2930
) where {T,M,ZT}
30-
weights = StatsBase.weights(states)
3131
idxs = sample_ancestors(rng, resampler, weights)
3232

3333
new_state = RaoBlackwellisedParticleState(
@@ -39,23 +39,6 @@ function resample(
3939
return new_state, idxs
4040
end
4141

42-
# TODO: combine this with above definition
43-
function resample(
44-
rng::AbstractRNG,
45-
resampler::AbstractResampler,
46-
states::RaoBlackwellisedParticleState{T,M,ZT},
47-
) where {T,M,ZT}
48-
weights = StatsBase.weights(states)
49-
idxs = sample_ancestors(rng, resampler, weights)
50-
51-
new_state = RaoBlackwellisedParticleState(
52-
deepcopy(states.x_particles[:, idxs]),
53-
deepcopy(states.z_particles[idxs]),
54-
CUDA.zeros(T, length(states)),
55-
)
56-
return reset_weights!(state, idxs, filter)
57-
end
58-
5942
## CONDITIONAL RESAMPLING ##################################################################
6043

6144
abstract type AbstractConditionalResampler <: AbstractResampler end
@@ -69,7 +52,7 @@ struct ESSResampler <: AbstractConditionalResampler
6952
end
7053

7154
function resample(
72-
rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT}
55+
rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT}, filter::AbstractFilter
7356
) where {PT,WT}
7457
n = length(state)
7558
# TODO: computing weights twice. Should create a wrapper to avoid this
@@ -78,7 +61,7 @@ function resample(
7861
@debug "ESS: $ess"
7962

8063
if cond_resampler.threshold * n ess
81-
return resample(rng, cond_resampler.resampler, state)
64+
return resample(rng, cond_resampler.resampler, state, filter; weights=weights)
8265
else
8366
return deepcopy(state), collect(1:n)
8467
end

0 commit comments

Comments
 (0)