Skip to content

Commit c2ec380

Browse files
committed
add compatibility with Turing.Experimental.Gibbs
1 parent 2e2efed commit c2ec380

File tree

5 files changed

+94
-5
lines changed

5 files changed

+94
-5
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "SliceSampling"
22
uuid = "43f4d3e8-9711-4a8c-bd1b-03ac73a255cf"
3-
version = "0.5.0"
3+
version = "0.6.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -31,7 +31,7 @@ LogDensityProblemsAD = "1"
3131
Random = "1"
3232
Requires = "1"
3333
SimpleUnPack = "1"
34-
Turing = "0.31, 0.32, 0.33"
34+
Turing = "0.33"
3535
julia = "1.7"
3636

3737
[extras]

docs/src/general.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,39 @@ model = demo()
6060
sample(model, externalsampler(sampler), n_samples)
6161
```
6262

63+
### Conditional sampling in a `Turing.Experimental.Gibbs` sampler
64+
`SliceSampling.jl` be used as a conditional sampler in `Turing.Experimental.Gibbs`.
65+
66+
```@example turinggibbs
67+
using Distributions
68+
using FillArrays
69+
using Turing
70+
using SliceSampling
71+
72+
@model function simple_choice(xs)
73+
p ~ Beta(2, 2)
74+
z ~ Bernoulli(p)
75+
for i in 1:length(xs)
76+
if z == 1
77+
xs[i] ~ Normal(0, 1)
78+
else
79+
xs[i] ~ Normal(2, 1)
80+
end
81+
end
82+
end
83+
84+
sampler = Turing.Experimental.Gibbs(
85+
(
86+
p = externalsampler(SliceSteppingOut(2.0)),
87+
z = PG(20, :z)
88+
)
89+
)
90+
91+
n_samples = 1000
92+
model = simple_choice([1.5, 2.0, 0.3])
93+
sample(model, sampler, n_samples)
94+
```
95+
6396
## Drawing Samples
6497
For drawing samples using the algorithms provided by `SliceSampling`, the user only needs to call:
6598
```julia

ext/SliceSamplingTuringExt.jl

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,53 @@ if isdefined(Base, :get_extension)
55
using LogDensityProblemsAD
66
using Random
77
using SliceSampling
8-
using Turing: Turing
8+
using Turing
9+
# using Turing: Turing, Experimental
910
else
1011
using ..LogDensityProblemsAD
1112
using ..Random
1213
using ..SliceSampling
13-
using ..Turing: Turing
14+
using ..Turing
15+
#using ..Turing: Turing, Experimental
1416
end
1517

18+
# Required for using the slice samplers as `externalsampler`s in Turing
19+
# begin
1620
Turing.Inference.getparams(
1721
::Turing.DynamicPPL.Model,
1822
sample::SliceSampling.Transition
1923
) = sample.params
24+
# end
25+
26+
# Required for using the slice samplers as `Experimental.Gibbs` samplers in Turing
27+
# begin
28+
Turing.Inference.getparams(
29+
::Turing.DynamicPPL.Model,
30+
state::SliceSampling.UnivariateSliceState
31+
) = state.transition.params
32+
33+
Turing.Inference.getparams(
34+
::Turing.DynamicPPL.Model,
35+
state::SliceSampling.GibbsState
36+
) = state.transition.params
37+
38+
Turing.Inference.getparams(
39+
::Turing.DynamicPPL.Model,
40+
state::SliceSampling.HitAndRunState
41+
) = state.transition.params
42+
43+
Turing.Experimental.gibbs_requires_recompute_logprob(
44+
model_dst,
45+
::Turing.DynamicPPL.Sampler{
46+
<: Turing.Inference.ExternalSampler{
47+
<: SliceSampling.AbstractSliceSampling, A, U
48+
}
49+
},
50+
sampler_src,
51+
state_dst,
52+
state_src
53+
) where {A,U} = false
54+
# end
2055

2156
function SliceSampling.initial_sample(
2257
rng::Random.AbstractRNG,

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,5 @@ MCMCTesting = "0.3"
1818
Random = "1"
1919
StableRNGs = "1"
2020
Test = "1"
21-
Turing = "0.31"
21+
Turing = "0.33"
2222
julia = "1.6"

test/turing.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,25 @@
3535
progress=false,
3636
)
3737
end
38+
39+
@testset "gibbs($sampler)" for sampler in [
40+
RandPermGibbs(Slice(1)),
41+
RandPermGibbs(SliceSteppingOut(1)),
42+
RandPermGibbs(SliceDoublingOut(1)),
43+
Slice(1),
44+
SliceSteppingOut(1),
45+
SliceDoublingOut(1),
46+
]
47+
sample(
48+
model,
49+
Turing.Experimental.Gibbs(
50+
(
51+
s = externalsampler(sampler),
52+
m = externalsampler(sampler),
53+
),
54+
),
55+
n_samples,
56+
progress=false,
57+
)
58+
end
3859
end

0 commit comments

Comments
 (0)