-
Notifications
You must be signed in to change notification settings - Fork 118
Add low-rank-modified metric #684
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -28,7 +28,8 @@ | |||||
We can also generate a relativistic dynamic :cite:p:`lu2017relativistic`. | ||||||
|
||||||
""" | ||||||
from typing import Callable, NamedTuple, Optional, Protocol, Union | ||||||
|
||||||
from typing import Any, Callable, NamedTuple, Optional, Protocol, Union | ||||||
|
||||||
import jax.numpy as jnp | ||||||
import jax.scipy as jscipy | ||||||
|
@@ -38,14 +39,18 @@ | |||||
from blackjax.types import Array, ArrayLikeTree, ArrayTree, PRNGKey | ||||||
from blackjax.util import generate_gaussian_noise | ||||||
|
||||||
__all__ = ["default_metric", "gaussian_euclidean", "gaussian_riemannian"] | ||||||
__all__ = [ | ||||||
"default_metric", | ||||||
"gaussian_euclidean", | ||||||
"gaussian_riemannian", | ||||||
"gaussian_euclidean_low_rank", | ||||||
] | ||||||
|
||||||
|
||||||
class KineticEnergy(Protocol): | ||||||
def __call__( | ||||||
self, momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None | ||||||
) -> float: | ||||||
... | ||||||
) -> float: ... | ||||||
|
||||||
|
||||||
class CheckTurning(Protocol): | ||||||
|
@@ -56,14 +61,14 @@ def __call__( | |||||
momentum_sum: ArrayLikeTree, | ||||||
position_left: Optional[ArrayLikeTree] = None, | ||||||
position_right: Optional[ArrayLikeTree] = None, | ||||||
) -> bool: | ||||||
... | ||||||
) -> bool: ... | ||||||
|
||||||
|
||||||
class Metric(NamedTuple): | ||||||
sample_momentum: Callable[[PRNGKey, ArrayLikeTree], ArrayLikeTree] | ||||||
kinetic_energy: KineticEnergy | ||||||
check_turning: CheckTurning | ||||||
data: Any = None | ||||||
|
||||||
|
||||||
MetricTypes = Union[Metric, Array, Callable[[ArrayLikeTree], Array]] | ||||||
|
@@ -208,6 +213,120 @@ def is_turning( | |||||
return Metric(momentum_generator, kinetic_energy, is_turning) | ||||||
|
||||||
|
||||||
def gaussian_euclidean_low_rank( | ||||||
diagonal_scale_std: Array, | ||||||
eigenvectors: Array, | ||||||
eigenvalues: Array, | ||||||
) -> Metric: | ||||||
r"""Hamiltonian dynamic on euclidean manifold with normally-distributed momentum | ||||||
:cite:p:`betancourt2013general`. | ||||||
|
||||||
The gaussian euclidean metric is a euclidean metric further characterized | ||||||
by setting the conditional probability density :math:`\pi(momentum|position)` | ||||||
to follow a standard gaussian distribution. A Newtonian hamiltonian | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
dynamics is assumed. | ||||||
|
||||||
This uses the mass matrix $(D^{-1}(V(\Sigma - I)V^T + I)D^{-1})^{-1}$. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
diagonal_scale_std | ||||||
The diagonal $D^{-1}$. This should for instance correspond to the standard deviation | ||||||
of the posterior. | ||||||
eigenvectors | ||||||
An arbitrary number of eigenvectors | ||||||
eigenvalues | ||||||
The corresponding eigenvalues | ||||||
|
||||||
Returns | ||||||
------- | ||||||
momentum_generator | ||||||
A function that generates a value for the momentum at random. | ||||||
kinetic_energy | ||||||
A function that returns the kinetic energy given the momentum. | ||||||
is_turning | ||||||
A function that determines whether a trajectory is turning back on | ||||||
itself given the values of the momentum along the trajectory. | ||||||
|
||||||
""" | ||||||
(ndim,) = jnp.shape(diagonal_scale_std) | ||||||
(ndim_, n_eigs) = jnp.shape(eigenvectors) | ||||||
if ndim != ndim_: | ||||||
raise ValueError("Shape mismatch in metric.") | ||||||
|
||||||
(n_eigs_,) = jnp.shape(eigenvalues) | ||||||
if n_eigs != n_eigs_: | ||||||
raise ValueError("Shape mismatch in metric.") | ||||||
|
||||||
# Compute (V(\Sigma - I)V^T + I)x | ||||||
def inner_matrix_mult(vals, vecs, x): | ||||||
projected = x @ vecs | ||||||
scaled = (vals - 1) * projected | ||||||
projected_back = vecs @ scaled | ||||||
return projected_back + x | ||||||
|
||||||
def inv_mass_matrix_mult(x): | ||||||
scaled = x * diagonal_scale_std | ||||||
product = inner_matrix_mult(eigenvalues, eigenvectors, scaled) | ||||||
return product * diagonal_scale_std | ||||||
|
||||||
def momentum_generator(rng_key: PRNGKey, position: ArrayLikeTree) -> ArrayTree: | ||||||
unit_draws = generate_gaussian_noise(rng_key, position) | ||||||
sqrt_vals = jnp.sqrt(jnp.reciprocal(eigenvalues)) | ||||||
sqrt_inv_diag = jnp.sqrt(jnp.reciprocal(diagonal_scale_std)) | ||||||
return inner_matrix_mult(sqrt_vals, eigenvectors, unit_draws) * sqrt_inv_diag | ||||||
|
||||||
def kinetic_energy( | ||||||
momentum: ArrayLikeTree, position: Optional[ArrayLikeTree] = None | ||||||
) -> float: | ||||||
del position | ||||||
momentum, _ = ravel_pytree(momentum) | ||||||
velocity = inv_mass_matrix_mult(momentum) | ||||||
kinetic_energy_val = 0.5 * jnp.dot(velocity, momentum) | ||||||
return kinetic_energy_val | ||||||
|
||||||
def is_turning( | ||||||
momentum_left: ArrayLikeTree, | ||||||
momentum_right: ArrayLikeTree, | ||||||
momentum_sum: ArrayLikeTree, | ||||||
position_left: Optional[ArrayLikeTree] = None, | ||||||
position_right: Optional[ArrayLikeTree] = None, | ||||||
) -> bool: | ||||||
"""Generalized U-turn criterion :cite:p:`betancourt2013generalizing,nuts_uturn`. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
momentum_left | ||||||
Momentum of the leftmost point of the trajectory. | ||||||
momentum_right | ||||||
Momentum of the rightmost point of the trajectory. | ||||||
momentum_sum | ||||||
Sum of the momenta along the trajectory. | ||||||
|
||||||
""" | ||||||
del position_left, position_right | ||||||
|
||||||
m_left, _ = ravel_pytree(momentum_left) | ||||||
m_right, _ = ravel_pytree(momentum_right) | ||||||
m_sum, _ = ravel_pytree(momentum_sum) | ||||||
|
||||||
velocity_left = inv_mass_matrix_mult(m_left) | ||||||
velocity_right = inv_mass_matrix_mult(m_right) | ||||||
|
||||||
# rho = m_sum | ||||||
rho = m_sum - (m_right + m_left) / 2 | ||||||
turning_at_left = jnp.dot(velocity_left, rho) <= 0 | ||||||
turning_at_right = jnp.dot(velocity_right, rho) <= 0 | ||||||
return turning_at_left | turning_at_right | ||||||
|
||||||
return Metric( | ||||||
momentum_generator, | ||||||
kinetic_energy, | ||||||
is_turning, | ||||||
data=(diagonal_scale_std, eigenvalues, eigenvectors), | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess you want to store these as properties so it is easier for tuning later on? In any case I suggest removing it here, and add it in the subsequent PR when tuning is introduced (so we can discuss whether it is necessary) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, that was mostly for debugging and I forgot to take it out. Sorry :-) |
||||||
) | ||||||
|
||||||
|
||||||
def gaussian_riemannian( | ||||||
mass_matrix_fn: Callable, | ||||||
) -> Metric: | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.