|
| 1 | +using LinearAlgebra: dot |
| 2 | + |
| 3 | +struct Tree |
| 4 | + val_left :: ChoiceMap |
| 5 | + momenta_left :: ChoiceMap |
| 6 | + val_right :: ChoiceMap |
| 7 | + momenta_right :: ChoiceMap |
| 8 | + val_sample :: ChoiceMap |
| 9 | + n :: Int |
| 10 | + weight :: Float64 |
| 11 | + stop :: Bool |
| 12 | + diverging :: Bool |
| 13 | +end |
| 14 | + |
| 15 | +function u_turn(values_left, values_right, momenta_left, momenta_right) |
| 16 | + return (dot(values_left - values_right, momenta_right) >= 0) && |
| 17 | + (dot(values_right - values_left, momenta_left) >= 0) |
| 18 | +end |
| 19 | + |
| 20 | +function leapfrog(values_trie, momenta_trie, eps, integrator_state) |
| 21 | + selection, retval_grad, trace = integrator_state |
| 22 | + |
| 23 | + (trace, _, _) = update(trace, values_trie) |
| 24 | + (_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad) |
| 25 | + |
| 26 | + # half step on momenta |
| 27 | + momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2)) |
| 28 | + |
| 29 | + # full step on positions |
| 30 | + values_trie = add_choicemaps(values_trie, scale_choicemap(momenta_trie, eps)) |
| 31 | + |
| 32 | + # get new gradient |
| 33 | + (trace, _, _) = update(trace, values_trie) |
| 34 | + (_, _, gradient_trie) = choice_gradients(trace, selection, retval_grad) |
| 35 | + |
| 36 | + # half step on momenta |
| 37 | + momenta_trie = add_choicemaps(momenta_trie, scale_choicemap(gradient_trie, eps / 2)) |
| 38 | + return values_trie, momenta_trie, get_score(trace) |
| 39 | +end |
| 40 | + |
| 41 | +function build_root(val, momenta, eps, direction, weight_init, integrator_state) |
| 42 | + val, momenta, lp = leapfrog(val, momenta, direction * eps, integrator_state) |
| 43 | + weight = lp + assess_momenta(to_array(momenta, Float64)) |
| 44 | + |
| 45 | + diverging = weight - weight_init > 1000 |
| 46 | + |
| 47 | + return Tree(val, momenta, val, momenta, val, 1, weight, false, diverging) |
| 48 | +end |
| 49 | + |
| 50 | +function merge_trees(tree_left, tree_right) |
| 51 | + # multinomial sampling |
| 52 | + if log(rand()) < tree_right.weight - tree_left.weight |
| 53 | + sample = tree_right.val_sample |
| 54 | + else |
| 55 | + sample = tree_left.val_sample |
| 56 | + end |
| 57 | + |
| 58 | + weight = logsumexp(tree_left.weight, tree_right.weight) |
| 59 | + n = tree_left.n + tree_right.n |
| 60 | + |
| 61 | + stop = tree_left.stop || tree_right.stop || u_turn(to_array(tree_left.val_left, Float64), |
| 62 | + to_array(tree_right.val_right, Float64), |
| 63 | + to_array(tree_left.momenta_left, Float64), |
| 64 | + to_array(tree_right.momenta_right, Float64)) |
| 65 | + diverging = tree_left.diverging || tree_right.diverging |
| 66 | + |
| 67 | + return Tree(tree_left.val_left, tree_left.momenta_left, tree_right.val_right, |
| 68 | + tree_right.momenta_right, sample, n, weight, stop, diverging) |
| 69 | +end |
| 70 | + |
| 71 | +function build_tree(val, momenta, depth, eps, direction, weight_init, integrator_state) |
| 72 | + if depth == 0 |
| 73 | + return build_root(val, momenta, eps, direction, weight_init, integrator_state) |
| 74 | + end |
| 75 | + |
| 76 | + tree = build_tree(val, momenta, depth - 1, eps, direction, weight_init, integrator_state) |
| 77 | + |
| 78 | + if tree.stop || tree.diverging |
| 79 | + return tree |
| 80 | + end |
| 81 | + |
| 82 | + if direction == 1 |
| 83 | + other_tree = build_tree(tree.val_right, tree.momenta_right, depth - 1, eps, direction, |
| 84 | + weight_init, integrator_state) |
| 85 | + return merge_trees(tree, other_tree) |
| 86 | + else |
| 87 | + other_tree = build_tree(tree.val_left, tree.momenta_left, depth - 1, eps, direction, |
| 88 | + weight_init, integrator_state) |
| 89 | + return merge_trees(other_tree, tree) |
| 90 | + end |
| 91 | +end |
| 92 | + |
| 93 | +""" |
| 94 | + (new_trace, sampler_statistics) = nuts( |
| 95 | + trace, selection::Selection;eps=0.1, |
| 96 | + max_treedepth=15, check=false, observations=EmptyChoiceMap()) |
| 97 | +
|
| 98 | +Apply a Hamiltonian Monte Carlo (HMC) update with a No U Turn stopping criterion that proposes new values for the selected addresses, returning the new trace (which is equal to the previous trace if the move was not accepted) and a `Bool` indicating whether the move was accepted or not.. |
| 99 | +
|
| 100 | +The NUT sampler allows for sampling trajectories of dynamic lengths, removing the need to specify the length of the trajectory as a parameter. |
| 101 | +The sample will be returned early if the height of the sampled tree exceeds `max_treedepth`. |
| 102 | +
|
| 103 | +`sampler_statistics` is a struct containing the following fields: |
| 104 | + - depth: the depth of the trajectory tree |
| 105 | + - n: the number of samples in the trajectory tree |
| 106 | + - sum_alpha: the sum of the individual mh acceptance probabilities for each sample in the tree |
| 107 | + - n_accept: how many intermediate samples were accepted |
| 108 | + - accept: whether the sample was accepted or not |
| 109 | +
|
| 110 | +# References |
| 111 | +Betancourt, M. (2017). A Conceptual Introduction to Hamiltonian Monte Carlo. URL: https://doi.org/10.48550/arXiv.1701.02434 |
| 112 | +Hoffman, M. D., & Gelman, A. (2022). The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo. URL: https://arxiv.org/abs/1111.4246 |
| 113 | +""" |
| 114 | +function nuts( |
| 115 | + trace::Trace, selection::Selection; eps=0.1, max_treedepth=15, |
| 116 | + check=false, observations=EmptyChoiceMap()) |
| 117 | + prev_model_score = get_score(trace) |
| 118 | + retval_grad = accepts_output_grad(get_gen_fn(trace)) ? zero(get_retval(trace)) : nothing |
| 119 | + |
| 120 | + # values needed for a leapfrog step |
| 121 | + (_, values_trie, _) = choice_gradients(trace, selection, retval_grad) |
| 122 | + |
| 123 | + momenta = sample_momenta(length(to_array(values_trie, Float64))) |
| 124 | + momenta_trie = from_array(values_trie, momenta) |
| 125 | + prev_momenta_score = assess_momenta(momenta) |
| 126 | + |
| 127 | + weight_init = prev_model_score + prev_momenta_score |
| 128 | + |
| 129 | + integrator_state = (selection, retval_grad, trace) |
| 130 | + |
| 131 | + tree = Tree(values_trie, momenta_trie, values_trie, momenta_trie, values_trie, 1, -Inf, false, false) |
| 132 | + |
| 133 | + direction = 0 |
| 134 | + depth = 0 |
| 135 | + stop = false |
| 136 | + while depth < max_treedepth |
| 137 | + direction = rand([-1, 1]) |
| 138 | + |
| 139 | + if direction == 1 # going right |
| 140 | + other_tree = build_tree(tree.val_right, tree.momenta_right, depth, eps, direction, |
| 141 | + weight_init, integrator_state) |
| 142 | + tree = merge_trees(tree, other_tree) |
| 143 | + else # going left |
| 144 | + other_tree = build_tree(tree.val_left, tree.momenta_left, depth, eps, direction, |
| 145 | + weight_init, integrator_state) |
| 146 | + tree = merge_trees(other_tree, tree) |
| 147 | + end |
| 148 | + |
| 149 | + stop = stop || tree.stop || tree.diverging |
| 150 | + if stop |
| 151 | + break |
| 152 | + end |
| 153 | + depth += 1 |
| 154 | + end |
| 155 | + |
| 156 | + (new_trace, _, _) = update(trace, tree.val_sample) |
| 157 | + check && check_observations(get_choices(new_trace), observations) |
| 158 | + |
| 159 | + # assess new model score (negative potential energy) |
| 160 | + new_model_score = get_score(new_trace) |
| 161 | + |
| 162 | + # assess new momenta score (negative kinetic energy) |
| 163 | + if direction == 1 |
| 164 | + new_momenta_score = assess_momenta(to_array(tree.momenta_right, Float64)) |
| 165 | + else |
| 166 | + new_momenta_score = assess_momenta(to_array(tree.momenta_left, Float64)) |
| 167 | + end |
| 168 | + |
| 169 | + # accept or reject |
| 170 | + alpha = new_model_score + new_momenta_score - weight_init |
| 171 | + if log(rand()) < alpha |
| 172 | + return (new_trace, true) |
| 173 | + else |
| 174 | + return (trace, false) |
| 175 | + end |
| 176 | +end |
| 177 | + |
| 178 | +export nuts |
| 179 | + |
0 commit comments