Skip to content

Commit 24d1b00

Browse files
Guided Filter Implementation (#74)
* fixed type stability of linear filter * added MLE demonstration * flipped sign of objective function * reorganized and added Mooncake MWE * fixed KF type stability in Enzyme * add MWE for Kalman filtering * replaced second order optimizer and added backend testing * fixed type stability of linear filter * added MLE demonstration * flipped sign of objective function * reorganized and added Mooncake MWE * fixed KF type stability in Enzyme * add MWE for Kalman filtering * replaced second order optimizer and added backend testing * fixed Mooncake errors for Bootstrap filter * Add guided filter draft * add VSMC replication * switch proposals for demonstration * fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add fix for Flux and Mooncake * add plots and fix formatting * restructure particle filters * update example * fixed type signatures for bootstrap filter * guess who forgot to run the formatter?? it was me Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * update forward algorithm * additional merge fixes * remove redundant files * consolidate MLE examples * suggested changes and house cleaning * update proposal definition * add unit testing * formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 213a59a commit 24d1b00

File tree

12 files changed

+574
-220
lines changed

12 files changed

+574
-220
lines changed

GeneralisedFilters/src/GeneralisedFilters.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function initialise(model, alg; kwargs...)
5454
return initialise(default_rng(), model, alg; kwargs...)
5555
end
5656

57-
function predict(model, alg, step, filtered; kwargs...)
57+
function predict(model, alg, step, filtered, observation; kwargs...)
5858
return predict(default_rng(), model, alg, step, filtered; kwargs...)
5959
end
6060

@@ -108,7 +108,7 @@ function step(
108108
callback::Union{AbstractCallback,Nothing}=nothing,
109109
kwargs...,
110110
)
111-
state = predict(rng, model, alg, iter, state; kwargs...)
111+
state = predict(rng, model, alg, iter, state, observation; kwargs...)
112112
isnothing(callback) ||
113113
callback(model, alg, iter, state, observation, PostPredict; kwargs...)
114114

@@ -132,7 +132,7 @@ include("models/discrete.jl")
132132
include("models/hierarchical.jl")
133133

134134
# Filtering/smoothing algorithms
135-
include("algorithms/bootstrap.jl")
135+
include("algorithms/particles.jl")
136136
include("algorithms/kalman.jl")
137137
include("algorithms/forward.jl")
138138
include("algorithms/rbpf.jl")

GeneralisedFilters/src/algorithms/bootstrap.jl

Lines changed: 0 additions & 121 deletions
This file was deleted.

GeneralisedFilters/src/algorithms/forward.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ function predict(
1414
model::DiscreteStateSpaceModel{T},
1515
filter::ForwardAlgorithm,
1616
step::Integer,
17-
states::AbstractVector;
17+
states::AbstractVector,
18+
observation;
1819
kwargs...,
1920
) where {T}
2021
P = calc_P(model.dyn, step; kwargs...)

GeneralisedFilters/src/algorithms/kalman.jl

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
export KalmanFilter, filter, BatchKalmanFilter
22
using GaussianDistributions
33
using CUDA: i32
4+
import LinearAlgebra: hermitianpart
45

56
export KalmanFilter, KF, KalmanSmoother, KS
67

@@ -18,42 +19,40 @@ end
1819
function predict(
1920
rng::AbstractRNG,
2021
model::LinearGaussianStateSpaceModel,
21-
filter::KalmanFilter,
22-
step::Integer,
23-
filtered::Gaussian;
22+
algo::KalmanFilter,
23+
iter::Integer,
24+
state::Gaussian,
25+
observation=nothing;
2426
kwargs...,
2527
)
26-
μ, Σ = GaussianDistributions.pair(filtered)
27-
A, b, Q = calc_params(model.dyn, step; kwargs...)
28+
μ, Σ = GaussianDistributions.pair(state)
29+
A, b, Q = calc_params(model.dyn, iter; kwargs...)
2830
return Gaussian(A * μ + b, A * Σ * A' + Q)
2931
end
3032

3133
function update(
3234
model::LinearGaussianStateSpaceModel,
33-
filter::KalmanFilter,
34-
step::Integer,
35-
proposed::Gaussian,
36-
obs::AbstractVector;
35+
algo::KalmanFilter,
36+
iter::Integer,
37+
state::Gaussian,
38+
observation::AbstractVector;
3739
kwargs...,
3840
)
39-
μ, Σ = GaussianDistributions.pair(proposed)
40-
H, c, R = calc_params(model.obs, step; kwargs...)
41+
μ, Σ = GaussianDistributions.pair(state)
42+
H, c, R = calc_params(model.obs, iter; kwargs...)
4143

4244
# Update state
4345
m = H * μ + c
44-
y = obs - m
45-
S = H * Σ * H' + R
46+
y = observation - m
47+
S = hermitianpart(H * Σ * H' + R)
4648
K = Σ * H' / S
4749

48-
# HACK: force the covariance to be positive definite
49-
S = (S + S') / 2
50-
51-
filtered = Gaussian+ K * y, Σ - K * H * Σ)
50+
state = Gaussian+ K * y, Σ - K * H * Σ)
5251

5352
# Compute log-likelihood
54-
ll = logpdf(MvNormal(m, S), obs)
53+
ll = logpdf(MvNormal(m, S), observation)
5554

56-
return filtered, ll
55+
return state, ll
5756
end
5857

5958
struct BatchKalmanFilter <: AbstractBatchFilter
@@ -74,12 +73,13 @@ function predict(
7473
rng::AbstractRNG,
7574
model::LinearGaussianStateSpaceModel{T},
7675
algo::BatchKalmanFilter,
77-
step::Integer,
78-
state::BatchGaussianDistribution;
76+
iter::Integer,
77+
state::BatchGaussianDistribution,
78+
observation;
7979
kwargs...,
8080
) where {T}
8181
μs, Σs = state.μs, state.Σs
82-
As, bs, Qs = batch_calc_params(model.dyn, step, algo.batch_size; kwargs...)
82+
As, bs, Qs = batch_calc_params(model.dyn, iter, algo.batch_size; kwargs...)
8383
μ̂s = NNlib.batched_vec(As, μs) .+ bs
8484
Σ̂s = NNlib.batched_mul(NNlib.batched_mul(As, Σs), NNlib.batched_transpose(As)) .+ Qs
8585
return BatchGaussianDistribution(μ̂s, Σ̂s)
@@ -88,17 +88,17 @@ end
8888
function update(
8989
model::LinearGaussianStateSpaceModel{T},
9090
algo::BatchKalmanFilter,
91-
step::Integer,
91+
iter::Integer,
9292
state::BatchGaussianDistribution,
93-
obs;
93+
observation;
9494
kwargs...,
9595
) where {T}
9696
μs, Σs = state.μs, state.Σs
97-
Hs, cs, Rs = batch_calc_params(model.obs, step, algo.batch_size; kwargs...)
98-
D = size(obs, 1)
97+
Hs, cs, Rs = batch_calc_params(model.obs, iter, algo.batch_size; kwargs...)
98+
D = size(observation, 1)
9999

100100
m = NNlib.batched_vec(Hs, μs) .+ cs
101-
y_res = cu(obs) .- m
101+
y_res = cu(observation) .- m
102102
S = NNlib.batched_mul(Hs, NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))) .+ Rs
103103

104104
ΣH_T = NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))
@@ -151,7 +151,7 @@ function (callback::StateCallback)(
151151
algo::KalmanFilter,
152152
iter::Integer,
153153
state,
154-
obs,
154+
observation,
155155
::PostPredictCallback;
156156
kwargs...,
157157
)
@@ -164,7 +164,7 @@ function (callback::StateCallback)(
164164
algo::KalmanFilter,
165165
iter::Integer,
166166
state,
167-
obs,
167+
observation,
168168
::PostUpdateCallback;
169169
kwargs...,
170170
)
@@ -175,7 +175,7 @@ end
175175
function smooth(
176176
rng::AbstractRNG,
177177
model::LinearGaussianStateSpaceModel{T},
178-
alg::KalmanSmoother,
178+
algo::KalmanSmoother,
179179
observations::AbstractVector;
180180
t_smooth=1,
181181
callback=nothing,
@@ -190,7 +190,7 @@ function smooth(
190190
back_state = filtered
191191
for t in (length(observations) - 1):-1:t_smooth
192192
back_state = backward(
193-
rng, model, alg, t, back_state, observations[t]; states_cache=cache, kwargs...
193+
rng, model, algo, t, back_state, observations[t]; states_cache=cache, kwargs...
194194
)
195195
end
196196

@@ -200,7 +200,7 @@ end
200200
function backward(
201201
rng::AbstractRNG,
202202
model::LinearGaussianStateSpaceModel{T},
203-
alg::KalmanSmoother,
203+
algo::KalmanSmoother,
204204
iter::Integer,
205205
back_state,
206206
obs;

0 commit comments

Comments
 (0)