Skip to content

Commit 321fd29

Browse files
committed
feat(inference): add nuts sampler
1 parent dce003c commit 321fd29

File tree

5 files changed

+241
-12
lines changed

5 files changed

+241
-12
lines changed

src/inference/hmc.jl

-12
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,3 @@
1-
function sample_momenta(n::Int)
2-
Float64[random(normal, 0, 1) for _=1:n]
3-
end
4-
5-
function assess_momenta(momenta)
6-
logprob = 0.
7-
for val in momenta
8-
logprob += logpdf(normal, val, 0, 1)
9-
end
10-
logprob
11-
end
12-
131
"""
142
(new_trace, accepted) = hmc(
153
trace, selection::Selection; L=10, eps=0.1,

src/inference/hmc_common.jl

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Draw `n` independent standard-normal momenta (one per selected choice).
function sample_momenta(n::Int)
    momenta = Vector{Float64}(undef, n)
    for i in 1:n
        momenta[i] = random(normal, 0, 1)
    end
    return momenta
end
4+
5+
# Log density of `momenta` under independent standard normals
# (i.e. the negative kinetic energy). Returns 0. for an empty collection.
function assess_momenta(momenta)
    return mapreduce(val -> logpdf(normal, val, 0, 1), +, momenta; init=0.)
end
12+
13+
"""
    out = add_choicemaps(a::ChoiceMap, b::ChoiceMap)

Return a new choice map whose value at each address is `a[addr] + b[addr]`,
recursing into nested sub-maps. Assumes `b` contains every address present
in `a` (addresses only in `b` are ignored).
"""
function add_choicemaps(a::ChoiceMap, b::ChoiceMap)
    out = choicemap()

    # elementwise sum of the top-level values
    for (name, val) in get_values_shallow(a)
        out[name] = val + b[name]
    end

    # recurse into sub-maps; use the public set_submap! API rather than
    # writing to the DynamicChoiceMap's internal_nodes field directly
    for (name, submap) in get_submaps_shallow(a)
        set_submap!(out, name, add_choicemaps(submap, get_submap(b, name)))
    end

    return out
end
26+
27+
"""
    out = scale_choicemap(a::ChoiceMap, scale)

Return a new choice map with every value of `a` multiplied by `scale`,
recursing into nested sub-maps.
"""
function scale_choicemap(a::ChoiceMap, scale)
    out = choicemap()

    # scale the top-level values
    for (name, val) in get_values_shallow(a)
        out[name] = val * scale
    end

    # recurse into sub-maps; use the public set_submap! API rather than
    # writing to the DynamicChoiceMap's internal_nodes field directly
    for (name, submap) in get_submaps_shallow(a)
        set_submap!(out, name, scale_choicemap(submap, scale))
    end

    return out
end

src/inference/inference.jl

+3
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ export logsumexp
1414

1515
include("trace_translators.jl")
1616

17+
include("hmc_common.jl")
18+
1719
# mcmc
1820
include("kernel_dsl.jl")
1921
include("mh.jl")
2022
include("hmc.jl")
23+
include("nuts.jl")
2124
include("mala.jl")
2225
include("elliptical_slice.jl")
2326

src/inference/nuts.jl

+179
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
using LinearAlgebra: dot
2+
3+
"""
A (sub)tree of the NUTS trajectory.

`val_left`/`momenta_left` and `val_right`/`momenta_right` are the position and
momentum states at the two endpoints of the trajectory segment covered by this
tree; `val_sample` is the candidate position sampled from within the segment.
"""
struct Tree
    val_left :: ChoiceMap
    momenta_left :: ChoiceMap
    val_right :: ChoiceMap
    momenta_right :: ChoiceMap
    val_sample :: ChoiceMap
    n :: Int             # number of leapfrog states in this subtree (1 at a leaf)
    weight :: Float64    # log of the subtree's total unnormalized probability mass
    stop :: Bool         # a U-turn was detected within this subtree
    diverging :: Bool    # the integrator diverged within this subtree
end
14+
15+
# Return true when the trajectory has made a U-turn: the momenta at both
# endpoints point back toward the opposite endpoint, so further integration
# in either direction would shrink the distance between the endpoints.
function u_turn(values_left, values_right, momenta_left, momenta_right)
    separation = values_right - values_left
    return (dot(separation, momenta_right) <= 0) &&
           (dot(separation, momenta_left) >= 0)
end
19+
20+
# One leapfrog integration step of size `eps` over the selected choices:
# half-step on momenta, full step on positions, half-step on momenta.
# `integrator_state` is the tuple `(selection, retval_grad, trace)`; the local
# `trace` binding is rebound by `update`, so the caller's trace is unaffected.
# Returns `(values_trie, momenta_trie, score)` where `score` is the model
# log-score at the new positions.
function leapfrog(values_trie, momenta_trie, eps, integrator_state)
    selection, retval_grad, trace = integrator_state

    # move the trace to the current positions and get the score gradient there
    (trace, _, _) = update(trace, values_trie)
    (_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)

    # half step on momenta (gradient ascent direction on the log density)
    momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))

    # full step on positions
    values_trie = add_choicemaps(values_trie, scale_choicemap(momenta_trie, eps))

    # get new gradient at the updated positions
    (trace, _, _) = update(trace, values_trie)
    (_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad)

    # half step on momenta
    momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2))
    return values_trie, momenta_trie, get_score(trace)
end
40+
41+
# Base case of tree doubling: take a single leapfrog step of size `eps` in
# `direction` (+1 right, -1 left) and wrap the result in a singleton Tree.
function build_root(val, momenta, eps, direction, weight_init, integrator_state)
    new_val, new_momenta, lp = leapfrog(val, momenta, direction * eps, integrator_state)

    # total log weight = model score + momenta score at the new state
    new_weight = lp + assess_momenta(to_array(new_momenta, Float64))

    # flag a numerically diverging trajectory
    # NOTE(review): this flags weight far ABOVE weight_init; confirm the
    # intended sign of the divergence criterion
    has_diverged = (new_weight - weight_init) > 1000

    return Tree(new_val, new_momenta, new_val, new_momenta, new_val,
                1, new_weight, false, has_diverged)
end
49+
50+
# Combine two adjacent subtrees into one covering both segments, keeping the
# left tree's left endpoint and the right tree's right endpoint.
function merge_trees(tree_left, tree_right)
    # progressive multinomial sampling: keep the right subtree's candidate
    # with probability proportional to its share of the total weight
    keep_right = log(rand()) < tree_right.weight - tree_left.weight
    sample = keep_right ? tree_right.val_sample : tree_left.val_sample

    total_weight = logsumexp(tree_left.weight, tree_right.weight)
    total_n = tree_left.n + tree_right.n

    # stop if either half already stopped, or the merged span makes a U-turn
    turned = u_turn(to_array(tree_left.val_left, Float64),
                    to_array(tree_right.val_right, Float64),
                    to_array(tree_left.momenta_left, Float64),
                    to_array(tree_right.momenta_right, Float64))
    stop = tree_left.stop || tree_right.stop || turned
    diverging = tree_left.diverging || tree_right.diverging

    return Tree(tree_left.val_left, tree_left.momenta_left,
                tree_right.val_right, tree_right.momenta_right,
                sample, total_n, total_weight, stop, diverging)
end
70+
71+
# Recursively build a balanced trajectory tree of height `depth`, growing in
# `direction` (+1 right, -1 left) from the state (`val`, `momenta`).
function build_tree(val, momenta, depth, eps, direction, weight_init, integrator_state)
    # base case: a single leapfrog step
    if depth == 0
        return build_root(val, momenta, eps, direction, weight_init, integrator_state)
    end

    first_half = build_tree(val, momenta, depth - 1, eps, direction,
                            weight_init, integrator_state)

    # abandon doubling as soon as a U-turn or divergence is detected
    if first_half.stop || first_half.diverging
        return first_half
    end

    if direction == 1
        # extend from the right endpoint; merged tree keeps left-to-right order
        second_half = build_tree(first_half.val_right, first_half.momenta_right,
                                 depth - 1, eps, direction, weight_init, integrator_state)
        return merge_trees(first_half, second_half)
    else
        # extend from the left endpoint
        second_half = build_tree(first_half.val_left, first_half.momenta_left,
                                 depth - 1, eps, direction, weight_init, integrator_state)
        return merge_trees(second_half, first_half)
    end
end
92+
93+
"""
94+
(new_trace, sampler_statistics) = nuts(
95+
trace, selection::Selection;eps=0.1,
96+
max_treedepth=15, check=false, observations=EmptyChoiceMap())
97+
98+
Apply a Hamiltonian Monte Carlo (HMC) update with a No U Turn stopping criterion that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a `Bool` indicating whether the move was accepted or not..
99+
100+
The NUT sampler allows for sampling trajectories of dynamic lengths, removing the need to specify the length of the trajectory as a parameter.
101+
The sample will be returned early if the height of the sampled tree exceeds `max_treedepth`.
102+
103+
`sampler_statistics` is a struct containing the following fields:
104+
- depth: the depth of the trajectory tree
105+
- n: the number of samples in the trajectory tree
106+
- sum_alpha: the sum of the individual mh acceptance probabilities for each sample in the tree
107+
- n_accept: how many intermediate samples were accepted
108+
- accept: whether the sample was accepted or not
109+
110+
# References
111+
Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. URL: https://doi.org/10.48550/arXiv.1701.02434
112+
Hoffman, M. D., & Gelman, A. (2022). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. URL: https://arxiv.org/abs/1111.4246
113+
"""
114+
function nuts(
115+
trace::Trace, selection::Selection; eps=0.1, max_treedepth=15,
116+
check=false, observations=EmptyChoiceMap())
117+
prev_model_score = get_score(trace)
118+
retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing
119+
120+
# values needed for a leapfrog step
121+
(_, values_trie, _) = choice_gradients(trace, selection, retval_grad)
122+
123+
momenta = sample_momenta(length(to_array(values_trie, Float64)))
124+
momenta_trie = from_array(values_trie, momenta)
125+
prev_momenta_score = assess_momenta(momenta)
126+
127+
weight_init = prev_model_score + prev_momenta_score
128+
129+
integrator_state = (selection, retval_grad, trace)
130+
131+
tree = Tree(values_trie, momenta_trie, values_trie, momenta_trie, values_trie, 1, -Inf, false, false)
132+
133+
direction = 0
134+
depth = 0
135+
stop = false
136+
while depth < max_treedepth
137+
direction = rand([-1, 1])
138+
139+
if direction == 1 # going right
140+
other_tree = build_tree(tree.val_right, tree.momenta_right, depth, eps, direction,
141+
weight_init, integrator_state)
142+
tree = merge_trees(tree, other_tree)
143+
else # going left
144+
other_tree = build_tree(tree.val_left, tree.momenta_left, depth, eps, direction,
145+
weight_init, integrator_state)
146+
tree = merge_trees(other_tree, tree)
147+
end
148+
149+
stop = stop || tree.stop || tree.diverging
150+
if stop
151+
break
152+
end
153+
depth += 1
154+
end
155+
156+
(new_trace, _, _) = update(trace, tree.val_sample)
157+
check && check_observations(get_choices(new_trace), observations)
158+
159+
# assess new model score (negative potential energy)
160+
new_model_score = get_score(new_trace)
161+
162+
# assess new momenta score (negative kinetic energy)
163+
if direction == 1
164+
new_momenta_score = assess_momenta(to_array(tree.momenta_right, Float64))
165+
else
166+
new_momenta_score = assess_momenta(to_array(tree.momenta_left, Float64))
167+
end
168+
169+
# accept or reject
170+
alpha = new_model_score + new_momenta_score - weight_init
171+
if log(rand()) < alpha
172+
return (new_trace, true)
173+
else
174+
return (trace, false)
175+
end
176+
end
177+
178+
export nuts
179+

test/inference/nuts.jl

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
@testset "nuts" begin
2+
3+
# smoke test a function without retval gradient
4+
@gen function foo()
5+
x = @trace(normal(0, 1), :x)
6+
return x
7+
end
8+
9+
(trace, _) = generate(foo, ())
10+
(new_trace, accepted) = nuts(trace, select(:x))
11+
12+
# smoke test a function with retval gradient
13+
@gen (grad) function foo()
14+
x = @trace(normal(0, 1), :x)
15+
return x
16+
end
17+
18+
(trace, _) = generate(foo, ())
19+
(new_trace, accepted) = nuts(trace, select(:x))
20+
end

0 commit comments

Comments
 (0)