From c05c94ec03f5d17808c9459fd1f877b74bba14a7 Mon Sep 17 00:00:00 2001 From: Michal-Novomestsky Date: Wed, 2 Jul 2025 17:11:28 +1000 Subject: [PATCH 1/7] changed laplace approx to return MvNormal --- pymc_extras/inference/laplace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index d64d2ada..1b9f3048 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -149,7 +149,7 @@ def get_conditional_gaussian_approximation( # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is # far from the mode x0 or in a neighbourhood which results in poor convergence. - return pytensor.function(args, [x0, conditional_gaussian_approx]) + return pytensor.function(args, pm.MvNormal(mu=x0, tau=Q-hess)) def laplace_draws_to_inferencedata( From 37fa1edc96f0d7cada50ec1773279209b4614738 Mon Sep 17 00:00:00 2001 From: Michal-Novomestsky Date: Wed, 2 Jul 2025 17:28:48 +1000 Subject: [PATCH 2/7] added seperate line for evaluating Q-hess --- pymc_extras/inference/laplace.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 1b9f3048..996f5ac0 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -143,13 +143,16 @@ def get_conditional_gaussian_approximation( # Full log(p(x | y, params)) using the Laplace approximation (up to a constant) _, logdetQ = pt.nlinalg.slogdet(Q) - conditional_gaussian_approx = ( - -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ - ) + # conditional_gaussian_approx = ( + # -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ + # ) + + # In the future, this could be made more efficient with only adding the diagonal of -hess + tau = Q - hess # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is # far from the mode x0 or in a neighbourhood which results in poor convergence. - return pytensor.function(args, pm.MvNormal(mu=x0, tau=Q-hess)) + return pytensor.function(args, [x0, pm.MvNormal(mu=x0, tau=tau)]) def laplace_draws_to_inferencedata( From 83bef75972e67249b10c3e7b66afbf4160414774 Mon Sep 17 00:00:00 2001 From: Michal-Novomestsky Date: Fri, 4 Jul 2025 16:15:36 +1000 Subject: [PATCH 3/7] WIP: minor refactor --- pymc_extras/inference/laplace.py | 20 ++++++++++++++------ tests/test_laplace.py | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 996f5ac0..11858c27 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -121,11 +121,13 @@ def get_conditional_gaussian_approximation( # f = log(p(y | x, params)) f_x = model.logp() - jac = pytensor.gradient.grad(f_x, x) - hess = pytensor.gradient.jacobian(jac.flatten(), x) + # jac = pytensor.gradient.grad(f_x, x) + # hess = pytensor.gradient.jacobian(jac.flatten(), x) # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x) - log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu) + log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ ( + x - mu + ) # TODO could be f + x.logp - IS X.LOGP DUPLICATED IN F? 
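+    # Note (editorial, hedged): model.logp() is the joint log-density, so it appears to already
+    # include the x ~ N(mu, Q^-1) prior term; the 2 * Q_val factor in the analytic x0_true used by
+    # the test is consistent with that term entering the objective twice here.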
# Maximize log(p(x | y, params)) wrt x to find mode x0 x0, _ = minimize( @@ -138,11 +140,13 @@ def get_conditional_gaussian_approximation( ) # require f'(x0) and f''(x0) for Laplace approx - jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) + # jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) + jac = pytensor.gradient.grad(f_x, x) + hess = pytensor.gradient.jacobian(jac.flatten(), x) hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) # Full log(p(x | y, params)) using the Laplace approximation (up to a constant) - _, logdetQ = pt.nlinalg.slogdet(Q) + # _, logdetQ = pt.nlinalg.slogdet(Q) # conditional_gaussian_approx = ( # -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ # ) @@ -152,7 +156,11 @@ def get_conditional_gaussian_approximation( # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is # far from the mode x0 or in a neighbourhood which results in poor convergence. - return pytensor.function(args, [x0, pm.MvNormal(mu=x0, tau=tau)]) + return ( + x0, + pm.MvNormal(f"{x.name}_laplace_approx", mu=x0, tau=tau), + tau, + ) # pytensor.function(args, [x0, pm.MvNormal(mu=x0, tau=tau)]) def laplace_draws_to_inferencedata( diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 72ff3e93..9aafa6e8 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -323,7 +323,7 @@ def test_get_conditional_gaussian_approximation(): Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov) # Pytensor currently doesn't support autograd for pt inverses, so we use a numeric Q instead - x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val)) + x = pm.MvNormal("x", mu=mu_param, tau=Q) # cov=np.linalg.inv(Q_val)) y = pm.MvNormal( "y", From 13961fab907209b3d1b7d00907f3ab5c62288b9a Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Sun, 6 Jul 2025 21:11:37 +1000 Subject: [PATCH 4/7] started writing fit_INLA routine --- pymc_extras/inference/fit.py | 12 ++- pymc_extras/inference/inla.py | 164 +++++++++++++++++++++++++++++++ pymc_extras/inference/laplace.py | 111 --------------------- tests/test_inla.py | 105 ++++++++++++++++++++ tests/test_laplace.py | 83 ---------------- 5 files changed, 280 insertions(+), 195 deletions(-) create mode 100644 pymc_extras/inference/inla.py create mode 100644 tests/test_inla.py diff --git a/pymc_extras/inference/fit.py b/pymc_extras/inference/fit.py index 5b83ff1f..8814ba3b 100644 --- a/pymc_extras/inference/fit.py +++ b/pymc_extras/inference/fit.py @@ -36,7 +36,17 @@ def fit(method: str, **kwargs) -> az.InferenceData: return fit_pathfinder(**kwargs) - if method == "laplace": + elif method == "laplace": from pymc_extras.inference.laplace import fit_laplace return fit_laplace(**kwargs) + + elif method == "INLA": + from pymc_extras.inference.laplace import fit_INLA + + return fit_INLA(**kwargs) + + else: + raise ValueError( + f"method '{method}' not supported. Use one of 'pathfinder', 'laplace' or 'INLA'." 
+ ) diff --git a/pymc_extras/inference/inla.py b/pymc_extras/inference/inla.py new file mode 100644 index 00000000..0eaa649c --- /dev/null +++ b/pymc_extras/inference/inla.py @@ -0,0 +1,164 @@ +import arviz as az +import numpy as np +import pymc as pm +import pytensor +import pytensor.tensor as pt + +from better_optimize.constants import minimize_method +from numpy.typing import ArrayLike +from pytensor.tensor import TensorVariable +from pytensor.tensor.optimize import minimize + + +def get_conditional_gaussian_approximation( + x: TensorVariable, + Q: TensorVariable | ArrayLike, + mu: TensorVariable | ArrayLike, + model: pm.Model | None = None, + method: minimize_method = "BFGS", + use_jac: bool = True, + use_hess: bool = False, + optimizer_kwargs: dict | None = None, +) -> list[TensorVariable]: + """ + Returns an estimate the a posteriori probability of a latent Gaussian field x and its mode x0 using the Laplace approximation. + + That is: + y | x, sigma ~ N(Ax, sigma^2 W) + x | params ~ N(mu, Q(params)^-1) + + We seek to estimate p(x | y, params) with a Gaussian: + + log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const + + Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). + + This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode. + + Thus: + + 1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0. + + 2. Use the Laplace approximation expanded about the mode: p(x | y, params) ~= N(mu=x0, tau=Q - f''(x0)). + + Parameters + ---------- + x: TensorVariable + The parameter with which to maximize wrt (that is, find the mode in x). In INLA, this is the latent Gaussian field x~N(mu,Q^-1). + Q: TensorVariable | ArrayLike + The precision matrix of the latent field x. + mu: TensorVariable | ArrayLike + The mean of the latent field x. + model: Model + PyMC model to use. + method: minimize_method + Which minimization algorithm to use. + use_jac: bool + If true, the minimizer will compute the gradient of log(p(x | y, params)). + use_hess: bool + If true, the minimizer will compute the Hessian log(p(x | y, params)). + optimizer_kwargs: dict + Kwargs to pass to scipy.optimize.minimize. + + Returns + ------- + x0, p(x | y, params): list[TensorVariable] + Mode and Laplace approximation for posterior. + """ + model = pm.modelcontext(model) + + # f = log(p(y | x, params)) + f_x = model.logp() + + # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x) + log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu) + + # Maximize log(p(x | y, params)) wrt x to find mode x0 + x0, _ = minimize( + objective=-log_x_posterior, + x=x, + method=method, + jac=use_jac, + hess=use_hess, + optimizer_kwargs=optimizer_kwargs, + ) + + # require f''(x0) for Laplace approx + hess = pytensor.gradient.hessian(f_x, x) + hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) + + # Could be made more efficient with adding diagonals only + tau = Q - hess + + # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is + # far from the mode x0 or in a neighbourhood which results in poor convergence. 
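+    # Minimal usage sketch (mirroring tests/test_inla.py; names like x_val and mu_val are example
+    # inputs, not part of this API): the symbolic outputs returned below can be compiled and
+    # evaluated at concrete values, e.g.
+    #     x0, x_g = get_conditional_gaussian_approximation(x=model.rvs_to_values[x], Q=Q, mu=mu_param)
+    #     cga = pytensor.function(args, [x0, pm.logp(x_g, model.rvs_to_values[x])])
+    #     mode, log_post = cga(x=x_val, mu_param=mu_val, ...)
+    # where args are the model value variables.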
+ return x0, pm.MvNormal(f"{x.name}_laplace_approx", mu=x0, tau=tau) + + +def get_log_marginal_likelihood( + x: TensorVariable, + Q: TensorVariable | ArrayLike, + mu: TensorVariable | ArrayLike, + model: pm.Model | None = None, + method: minimize_method = "BFGS", + use_jac: bool = True, + use_hess: bool = False, + optimizer_kwargs: dict | None = None, +) -> TensorVariable: + model = pm.modelcontext(model) + + x0, laplace_approx = get_conditional_gaussian_approximation( + x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs + ) + log_laplace_approx = pm.logp(laplace_approx, model.rvs_to_values[x]) + + _, logdetQ = pt.nlinalg.slogdet(Q) + log_x_likelihood = ( + -0.5 * (x - mu).T @ Q @ (x - mu) + 0.5 * logdetQ - 0.5 * x.shape[0] * np.log(2 * np.pi) + ) + + log_likelihood = ( # logp(y | params) = + model.logp() # logp(y | x, params) + + log_x_likelihood # * logp(x | params) + - log_laplace_approx # / logp(x | y, params) + ) + + return log_likelihood + + +def fit_INLA( + x: TensorVariable, + Q: TensorVariable | ArrayLike, + mu: TensorVariable | ArrayLike, + model: pm.Model | None = None, + method: minimize_method = "BFGS", + use_jac: bool = True, + use_hess: bool = False, + optimizer_kwargs: dict | None = None, +) -> az.InferenceData: + model = pm.modelcontext(model) + + # logp(y | params) + log_likelihood = get_log_marginal_likelihood( + x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs + ) + + # TODO How to obtain prior? It can parametrise Q, mu, y, etc. Not sure if we could extract from model.logp somehow. Otherwise simply specify as a user input + prior = None + params = None + log_prior = pm.logp(prior, model.rvs_to_values[params]) + + # logp(params | y) = logp(y | params) + logp(params) + const + log_posterior = log_likelihood + log_prior + + # TODO log_marginal_x_likelihood is almost the same as log_likelihood, but need to do some sampling? + log_marginal_x_likelihood = None + log_marginal_x_posterior = log_marginal_x_likelihood + log_prior + + # TODO can we sample over log likelihoods? + # Marginalize params + idata_params = log_posterior.sample() # TODO something like NUTS, QMC, etc.? 
+ idata_x = log_marginal_x_posterior.sample() + + # Bundle up idatas somehow + return idata_params, idata_x diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py index 11858c27..9c0a0d27 100644 --- a/pymc_extras/inference/laplace.py +++ b/pymc_extras/inference/laplace.py @@ -15,7 +15,6 @@ import logging -from collections.abc import Callable from functools import reduce from importlib.util import find_spec from itertools import product @@ -30,7 +29,6 @@ from arviz import dict_to_dataset from better_optimize.constants import minimize_method -from numpy.typing import ArrayLike from pymc import DictToArrayBijection from pymc.backends.arviz import ( coords_and_dims_for_inferencedata, @@ -41,8 +39,6 @@ from pymc.model.transform.conditioning import remove_value_transforms from pymc.model.transform.optimization import freeze_dims_and_data from pymc.util import get_default_varnames -from pytensor.tensor import TensorVariable -from pytensor.tensor.optimize import minimize from scipy import stats from pymc_extras.inference.find_map import ( @@ -56,113 +52,6 @@ _log = logging.getLogger(__name__) -def get_conditional_gaussian_approximation( - x: TensorVariable, - Q: TensorVariable | ArrayLike, - mu: TensorVariable | ArrayLike, - args: list[TensorVariable] | None = None, - model: pm.Model | None = None, - method: minimize_method = "BFGS", - use_jac: bool = True, - use_hess: bool = False, - optimizer_kwargs: dict | None = None, -) -> Callable: - """ - Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation. - - That is: - y | x, sigma ~ N(Ax, sigma^2 W) - x | params ~ N(mu, Q(params)^-1) - - We seek to estimate log(p(x | y, params)): - - log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const - - Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). - - This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode. - - Thus: - - 1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0. - - 2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q). - - Parameters - ---------- - x: TensorVariable - The parameter with which to maximize wrt (that is, find the mode in x). In INLA, this is the latent field x~N(mu,Q^-1). - Q: TensorVariable | ArrayLike - The precision matrix of the latent field x. - mu: TensorVariable | ArrayLike - The mean of the latent field x. - args: list[TensorVariable] - Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args. - model: Model - PyMC model to use. - method: minimize_method - Which minimization algorithm to use. - use_jac: bool - If true, the minimizer will compute the gradient of log(p(x | y, params)). - use_hess: bool - If true, the minimizer will compute the Hessian log(p(x | y, params)). - optimizer_kwargs: dict - Kwargs to pass to scipy.optimize.minimize. - - Returns - ------- - f: Callable - A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. 
x is currently both the point at which to evaluate logp and the initial guess for the minimizer. - """ - model = pm.modelcontext(model) - - if args is None: - args = model.continuous_value_vars + model.discrete_value_vars - - # f = log(p(y | x, params)) - f_x = model.logp() - # jac = pytensor.gradient.grad(f_x, x) - # hess = pytensor.gradient.jacobian(jac.flatten(), x) - - # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x) - log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ ( - x - mu - ) # TODO could be f + x.logp - IS X.LOGP DUPLICATED IN F? - - # Maximize log(p(x | y, params)) wrt x to find mode x0 - x0, _ = minimize( - objective=-log_x_posterior, - x=x, - method=method, - jac=use_jac, - hess=use_hess, - optimizer_kwargs=optimizer_kwargs, - ) - - # require f'(x0) and f''(x0) for Laplace approx - # jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) - jac = pytensor.gradient.grad(f_x, x) - hess = pytensor.gradient.jacobian(jac.flatten(), x) - hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) - - # Full log(p(x | y, params)) using the Laplace approximation (up to a constant) - # _, logdetQ = pt.nlinalg.slogdet(Q) - # conditional_gaussian_approx = ( - # -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ - # ) - - # In the future, this could be made more efficient with only adding the diagonal of -hess - tau = Q - hess - - # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is - # far from the mode x0 or in a neighbourhood which results in poor convergence. - return ( - x0, - pm.MvNormal(f"{x.name}_laplace_approx", mu=x0, tau=tau), - tau, - ) # pytensor.function(args, [x0, pm.MvNormal(mu=x0, tau=tau)]) - - def laplace_draws_to_inferencedata( posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None ) -> az.InferenceData: diff --git a/tests/test_inla.py b/tests/test_inla.py new file mode 100644 index 00000000..cda24bb5 --- /dev/null +++ b/tests/test_inla.py @@ -0,0 +1,105 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import pymc as pm +import pytensor + +from pymc_extras.inference.inla import get_conditional_gaussian_approximation + + +def test_get_conditional_gaussian_approximation(): + """ + Consider the trivial case of: + + y | x ~ N(x, cov_param) + x | param ~ N(mu_param, Q^-1) + + cov_param ~ N(cov_mu, cov_cov) + mu_param ~ N(mu_mu, mu_cov) + Q ~ N(Q_mu, Q_cov) + + This has an analytic solution at the mode which we can compare against. 
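+
+    Concretely, with the covariances fixed at their test values, the assertions below check the mode
+
+        x0 = (n * cov_param^-1 + 2*Q)^-1 @ (cov_param^-1 @ sum_i(y_i) + 2*Q @ mu)
+
+    and the Laplace log-density with precision tau = Q - hess, where hess = -n * cov_param^-1 - Q.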
+ """ + rng = np.random.default_rng(12345) + n = 10000 + d = 10 + + # Initialise arrays + mu_true = rng.random(d) + cov_true = np.diag(rng.random(d)) + Q_val = np.diag(rng.random(d)) + cov_param_val = np.diag(rng.random(d)) + + x_val = rng.random(d) + mu_val = rng.random(d) + + mu_mu = rng.random(d) + mu_cov = np.diag(np.ones(d)) + cov_mu = rng.random(d**2) + cov_cov = np.diag(np.ones(d**2)) + Q_mu = rng.random(d**2) + Q_cov = np.diag(np.ones(d**2)) + + with pm.Model() as model: + y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n) + + mu_param = pm.MvNormal("mu_param", mu=mu_mu, cov=mu_cov) + cov_param = pm.MvNormal("cov_param", mu=cov_mu, cov=cov_cov) + Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov) + + x = pm.MvNormal("x", mu=mu_param, tau=Q_val) + + y = pm.MvNormal( + "y", + mu=x, + cov=cov_param.reshape((d, d)), + observed=y_obs, + ) + + args = model.continuous_value_vars + model.discrete_value_vars + + # logp(x | y, params) + x0, x_g = get_conditional_gaussian_approximation( + x=model.rvs_to_values[x], + Q=Q.reshape((d, d)), + mu=mu_param, + optimizer_kwargs={"tol": 1e-25}, + ) + + cga = pytensor.function(args, [x0, pm.logp(x_g, model.rvs_to_values[x])]) + + x0, log_x_posterior = cga( + x=x_val, mu_param=mu_val, cov_param=cov_param_val.flatten(), Q=Q_val.flatten() + ) + + # Get analytic values of the mode and Laplace-approximated log posterior + cov_param_inv = np.linalg.inv(cov_param_val) + + x0_true = np.linalg.inv(n * cov_param_inv + 2 * Q_val) @ ( + cov_param_inv @ y_obs.sum(axis=0) + 2 * Q_val @ mu_val + ) + + hess_true = -n * cov_param_inv - Q_val + tau_true = Q_val - hess_true + + log_x_taylor = ( + -0.5 * (x_val - x0_true).T @ tau_true @ (x_val - x0_true) + + 0.5 * np.log(np.linalg.det(tau_true)) + - 0.5 * d * np.log(2 * np.pi) + ) + + np.testing.assert_allclose(x0, x0_true, atol=0.1, rtol=0.1) + np.testing.assert_allclose(log_x_posterior, log_x_taylor, atol=0.1, rtol=0.1) diff --git a/tests/test_laplace.py b/tests/test_laplace.py index 9aafa6e8..8f7a4c01 100644 --- a/tests/test_laplace.py +++ b/tests/test_laplace.py @@ -23,7 +23,6 @@ from pymc_extras.inference.laplace import ( fit_laplace, fit_mvn_at_MAP, - get_conditional_gaussian_approximation, sample_laplace_posterior, ) @@ -280,85 +279,3 @@ def test_laplace_scalar(): assert idata_laplace.fit.covariance_matrix.shape == (1, 1) np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1) - - -def test_get_conditional_gaussian_approximation(): - """ - Consider the trivial case of: - - y | x ~ N(x, cov_param) - x | param ~ N(mu_param, Q^-1) - - cov_param ~ N(cov_mu, cov_cov) - mu_param ~ N(mu_mu, mu_cov) - Q ~ N(Q_mu, Q_cov) - - This has an analytic solution at the mode which we can compare against. 
- """ - rng = np.random.default_rng(12345) - n = 10000 - d = 10 - - # Initialise arrays - mu_true = rng.random(d) - cov_true = np.diag(rng.random(d)) - Q_val = np.diag(rng.random(d)) - cov_param_val = np.diag(rng.random(d)) - - x_val = rng.random(d) - mu_val = rng.random(d) - - mu_mu = rng.random(d) - mu_cov = np.diag(np.ones(d)) - cov_mu = rng.random(d**2) - cov_cov = np.diag(np.ones(d**2)) - Q_mu = rng.random(d**2) - Q_cov = np.diag(np.ones(d**2)) - - with pm.Model() as model: - y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n) - - mu_param = pm.MvNormal("mu_param", mu=mu_mu, cov=mu_cov) - cov_param = pm.MvNormal("cov_param", mu=cov_mu, cov=cov_cov) - Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov) - - # Pytensor currently doesn't support autograd for pt inverses, so we use a numeric Q instead - x = pm.MvNormal("x", mu=mu_param, tau=Q) # cov=np.linalg.inv(Q_val)) - - y = pm.MvNormal( - "y", - mu=x, - cov=cov_param.reshape((d, d)), - observed=y_obs, - ) - - # logp(x | y, params) - cga = get_conditional_gaussian_approximation( - x=model.rvs_to_values[x], - Q=Q.reshape((d, d)), - mu=mu_param, - optimizer_kwargs={"tol": 1e-25}, - ) - - x0, log_x_posterior = cga( - x=x_val, mu_param=mu_val, cov_param=cov_param_val.flatten(), Q=Q_val.flatten() - ) - - # Get analytic values of the mode and Laplace-approximated log posterior - cov_param_inv = np.linalg.inv(cov_param_val) - - x0_true = np.linalg.inv(n * cov_param_inv + 2 * Q_val) @ ( - cov_param_inv @ y_obs.sum(axis=0) + 2 * Q_val @ mu_val - ) - - jac_true = cov_param_inv @ (y_obs - x0_true).sum(axis=0) - Q_val @ (x0_true - mu_val) - hess_true = -n * cov_param_inv - Q_val - - log_x_posterior_laplace_true = ( - -0.5 * x_val.T @ (-hess_true + Q_val) @ x_val - + x_val.T @ (Q_val @ mu_val + jac_true - hess_true @ x0_true) - + 0.5 * np.log(np.linalg.det(Q_val)) - ) - - np.testing.assert_allclose(x0, x0_true, atol=0.1, rtol=0.1) - np.testing.assert_allclose(log_x_posterior, log_x_posterior_laplace_true, atol=0.1, rtol=0.1) From 43fb626ed223a3d704b25b4da6dbd58fd33392c9 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Sun, 6 Jul 2025 21:43:54 +1000 Subject: [PATCH 5/7] changed minimizer tol to 1e-8 --- tests/test_inla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_inla.py b/tests/test_inla.py index cda24bb5..58ff9ae1 100644 --- a/tests/test_inla.py +++ b/tests/test_inla.py @@ -76,7 +76,7 @@ def test_get_conditional_gaussian_approximation(): x=model.rvs_to_values[x], Q=Q.reshape((d, d)), mu=mu_param, - optimizer_kwargs={"tol": 1e-25}, + optimizer_kwargs={"tol": 1e-8}, ) cga = pytensor.function(args, [x0, pm.logp(x_g, model.rvs_to_values[x])]) From 263f6121a45e4ed6e812d678a72074fc57f00a42 Mon Sep 17 00:00:00 2001 From: Michal Novomestsky Date: Wed, 16 Jul 2025 18:57:56 +1000 Subject: [PATCH 6/7] WIP: MarginalLaplaceRV --- pymc_extras/inference/__init__.py | 3 +- pymc_extras/inference/inla.py | 20 +++++--- pymc_extras/model/marginal/distributions.py | 57 +++++++++++++++++++++ 3 files changed, 72 insertions(+), 8 deletions(-) diff --git a/pymc_extras/inference/__init__.py b/pymc_extras/inference/__init__.py index a01fdd5c..3e4d781d 100644 --- a/pymc_extras/inference/__init__.py +++ b/pymc_extras/inference/__init__.py @@ -14,7 +14,8 @@ from pymc_extras.inference.find_map import find_MAP from pymc_extras.inference.fit import fit +from pymc_extras.inference.inla import fit_INLA from pymc_extras.inference.laplace import fit_laplace from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder -__all__ = 
["fit", "fit_pathfinder", "fit_laplace", "find_MAP"] +__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP", "fit_INLA"] diff --git a/pymc_extras/inference/inla.py b/pymc_extras/inference/inla.py index 0eaa649c..3f6db2cf 100644 --- a/pymc_extras/inference/inla.py +++ b/pymc_extras/inference/inla.py @@ -92,7 +92,8 @@ def get_conditional_gaussian_approximation( # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is # far from the mode x0 or in a neighbourhood which results in poor convergence. - return x0, pm.MvNormal(f"{x.name}_laplace_approx", mu=x0, tau=tau) + _, logdetTau = pt.nlinalg.slogdet(tau) + return x0, 0.5 * logdetTau - 0.5 * x0.shape[0] * np.log(2 * np.pi) def get_log_marginal_likelihood( @@ -107,14 +108,17 @@ def get_log_marginal_likelihood( ) -> TensorVariable: model = pm.modelcontext(model) - x0, laplace_approx = get_conditional_gaussian_approximation( + x0, log_laplace_approx = get_conditional_gaussian_approximation( x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs ) - log_laplace_approx = pm.logp(laplace_approx, model.rvs_to_values[x]) + # log_laplace_approx = pm.logp(laplace_approx, x)#model.rvs_to_values[x]) _, logdetQ = pt.nlinalg.slogdet(Q) + # log_x_likelihood = ( + # -0.5 * (x - mu).T @ Q @ (x - mu) + 0.5 * logdetQ - 0.5 * x.shape[0] * np.log(2 * np.pi) + # ) log_x_likelihood = ( - -0.5 * (x - mu).T @ Q @ (x - mu) + 0.5 * logdetQ - 0.5 * x.shape[0] * np.log(2 * np.pi) + -0.5 * (x0 - mu).T @ Q @ (x0 - mu) + 0.5 * logdetQ - 0.5 * x0.shape[0] * np.log(2 * np.pi) ) log_likelihood = ( # logp(y | params) = @@ -123,7 +127,7 @@ def get_log_marginal_likelihood( - log_laplace_approx # / logp(x | y, params) ) - return log_likelihood + return x0, log_likelihood def fit_INLA( @@ -139,23 +143,25 @@ def fit_INLA( model = pm.modelcontext(model) # logp(y | params) - log_likelihood = get_log_marginal_likelihood( + x0, log_likelihood = get_log_marginal_likelihood( x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs ) # TODO How to obtain prior? It can parametrise Q, mu, y, etc. Not sure if we could extract from model.logp somehow. Otherwise simply specify as a user input + # Perhaps obtain as RVs which y depends on which aren't x? prior = None params = None log_prior = pm.logp(prior, model.rvs_to_values[params]) # logp(params | y) = logp(y | params) + logp(params) + const log_posterior = log_likelihood + log_prior + log_posterior = pytensor.graph.replace.graph_replace(log_posterior, {x: x0}) # TODO log_marginal_x_likelihood is almost the same as log_likelihood, but need to do some sampling? log_marginal_x_likelihood = None log_marginal_x_posterior = log_marginal_x_likelihood + log_prior - # TODO can we sample over log likelihoods? + # TODO can we sample over log likelihoods?w # Marginalize params idata_params = log_posterior.sample() # TODO something like NUTS, QMC, etc.? 
idata_x = log_marginal_x_posterior.sample() diff --git a/pymc_extras/model/marginal/distributions.py b/pymc_extras/model/marginal/distributions.py index 86aa5f02..6e08352e 100644 --- a/pymc_extras/model/marginal/distributions.py +++ b/pymc_extras/model/marginal/distributions.py @@ -132,6 +132,10 @@ class MarginalDiscreteMarkovChainRV(MarginalRV): """Base class for Marginalized Discrete Markov Chain RVs""" +class MarginalLaplaceRV(MarginalRV): + """Base class for Marginalized Laplace-Approximated RVs""" + + def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: op = rv.owner.op dist_params = rv.owner.op.dist_params(rv.owner) @@ -371,3 +375,56 @@ def step_alpha(logp_emission, log_alpha, log_P): warn_non_separable_logp(values) dummy_logps = (DUMMY_ZERO,) * (len(values) - 1) return joint_logp, *dummy_logps + + +@_logprob.register(MarginalLaplaceRV) +def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs): + # Clone the inner RV graph of the Marginalized RV + x, *inner_rvs = inline_ofg_outputs(op, inputs) + + # Obtain the joint_logp graph of the inner RV graph + inner_rv_values = dict(zip(inner_rvs, values)) + marginalized_vv = x.clone() + rv_values = inner_rv_values | {x: marginalized_vv} + logps_dict = conditional_logp(rv_values=rv_values, **kwargs) + + logp = pt.sum( + [pt.sum(logps_dict[k]) for k in logps_dict] + ) # TODO check this gives the proper p(y | x, params) + + import pytensor + + from pytensor.tensor.optimize import minimize + + # Maximize log(p(x | y, params)) wrt x to find mode x0 + x0, _ = minimize( + objective=-logp, + x=marginalized_vv, + method="BFGS", + # jac=use_jac, + # hess=use_hess, + optimizer_kwargs={"tol": 1e-8}, + ) + + # require f''(x0) for Laplace approx + hess = pytensor.gradient.hessian(logp, marginalized_vv) + # hess = pytensor.graph.replace.graph_replace(hess, {marginalized_vv: x0}) + + # Could be made more efficient with adding diagonals only + rng = np.random.default_rng(12345) + d = 3 + Q = np.diag(rng.random(d)) + tau = Q - hess + + # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is + # far from the mode x0 or in a neighbourhood which results in poor convergence. + _, logdetTau = pt.nlinalg.slogdet(tau) + log_laplace_approx = 0.5 * logdetTau - 0.5 * x0.shape[0] * np.log(2 * np.pi) + + # Reduce logp dimensions corresponding to broadcasted variables + # marginalized_logp = logps_dict.pop(marginalized_vv) + joint_logp = logp - log_laplace_approx + + joint_logp = pytensor.graph.replace.graph_replace(joint_logp, {marginalized_vv: x0}) + + return joint_logp # TODO check if pm.sample adds on p(params). 
Otherwise this is p(y|params) not p(params|y) From c9d711bd3344681e8a27750d67ccf1ae64247be0 Mon Sep 17 00:00:00 2001 From: Michal-Novomestsky Date: Sat, 19 Jul 2025 21:44:28 +1000 Subject: [PATCH 7/7] WIP: Minimize inside logp --- notebooks/INLA_testing.ipynb | 875 ++++++++++++++++++++ pymc_extras/__init__.py | 2 +- pymc_extras/inference/__init__.py | 3 +- pymc_extras/model/marginal/distributions.py | 28 +- 4 files changed, 902 insertions(+), 6 deletions(-) create mode 100644 notebooks/INLA_testing.ipynb diff --git a/notebooks/INLA_testing.ipynb b/notebooks/INLA_testing.ipynb new file mode 100644 index 00000000..e2c5e27b --- /dev/null +++ b/notebooks/INLA_testing.ipynb @@ -0,0 +1,875 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ffd6780e-1bfb-42f0-ba6a-055e9ffd1490", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5a2819fd-6e01-47c0-88b2-f2b5e4215b9b", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pymc as pm\n", + "import pytensor.tensor as pt\n", + "\n", + "import pytensor\n", + "from pytensor.tensor.optimize import minimize\n", + "from pymc_extras.inference.inla import *\n", + "\n", + "from pymc.model.fgraph import fgraph_from_model, model_from_fgraph\n", + "from pymc_extras.model.marginal.marginal_model import marginalize" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0ad97d05-f577-4793-ba6c-dd5f1300c022", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ExpandDims{axis=0}.0\n", + "(1, None)\n", + "(1, None, 1)\n" + ] + }, + { + "data": { + "text/plain": [ + "Reshape{4}.0" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from pytensor.gradient import grad, hessian, jacobian\n", + "from pytensor.tensor.optimize import root\n", + "\n", + "x = pt.vector(\"x\")\n", + "var = pt.stack([x])\n", + "y = pt.stack([var[0], var[0] ** 2])\n", + "sol, _ = root(y, variables=var)\n", + "jacobian(sol, var, vectorize=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2f475324-9fba-48a1-a79a-563a5e7818c9", + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng(12345)\n", + "n = 10000\n", + "d = 10\n", + "\n", + "# Initialise arrays\n", + "mu_true = rng.random(d)\n", + "cov_true = np.diag(rng.random(d))\n", + "Q_val = np.diag(rng.random(d))\n", + "cov_param_val = np.diag(rng.random(d))\n", + "\n", + "x_val = rng.random(d)\n", + "mu_val = rng.random(d)\n", + "\n", + "mu_mu = rng.random(d)\n", + "mu_cov = np.diag(np.ones(d))\n", + "cov_mu = rng.random(d**2)\n", + "cov_cov = np.diag(np.ones(d**2))\n", + "Q_mu = rng.random(d**2)\n", + "Q_cov = np.diag(np.ones(d**2))\n", + "\n", + "with pm.Model() as model:\n", + " y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n)\n", + "\n", + " mu_param = pm.MvNormal(\"mu_param\", mu=mu_mu, cov=mu_cov)\n", + " # cov_param = np.abs(pm.MvNormal(\"cov_param\", mu=cov_mu, cov=cov_cov))\n", + " # Q = pm.MvNormal(\"Q\", mu=Q_mu, cov=Q_cov)\n", + "\n", + " x = pm.MvNormal(\"x\", mu=mu_param, tau=Q_val)\n", + "\n", + " y = pm.MvNormal(\n", + " \"y\",\n", + " mu=x,\n", + " cov=cov_param_val, # cov_param.reshape((d, d)),\n", + " observed=y_obs,\n", + " )\n", + "\n", + " # x0, log_likelihood = get_log_marginal_likelihood(\n", + " # x=model.rvs_to_values[x],\n", + " # Q=Q_val,#Q.reshape((d, d)),\n", + " 
# mu=mu_param,\n", + " # optimizer_kwargs={\"tol\": 1e-8},\n", + " # )\n", + "\n", + " # args = model.continuous_value_vars + model.discrete_value_vars\n", + " # for i, rv in enumerate(args):\n", + " # if rv == model.rvs_to_values[x]:\n", + " # args.pop(i)\n", + " # log_likelihood = pytensor.graph.replace.graph_replace(log_likelihood, {model.rvs_to_values[x]: rng.random(d)})\n", + " # log_laplace_approx = pytensor.function(args, log_likelihood)\n", + "\n", + " # pm.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "5121ab56-7841-4ff2-b9b0-016639d2bdb2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ModelFreeRV{transform=None} [id A] 'mu_param' 3\n", + " ├─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id B] 'mu_param' 2\n", + " │ ├─ RNG() [id C]\n", + " │ ├─ NoneConst{None} [id D]\n", + " │ ├─ Second [id E] 1\n", + " │ │ ├─ Subtensor{:, i} [id F] 0\n", + " │ │ │ ├─ [[1. 0. 0. ... 0. 0. 1.]] [id G]\n", + " │ │ │ └─ -1 [id H]\n", + " │ │ └─ [0.2552323 ... .18013059] [id I]\n", + " │ └─ [[1. 0. 0. ... 0. 0. 1.]] [id G]\n", + " └─ mu_param [id J]\n", + "ModelFreeRV{transform=None} [id K] 'x' 8\n", + " ├─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id L] 'x' 7\n", + " │ ├─ RNG() [id M]\n", + " │ ├─ NoneConst{None} [id D]\n", + " │ ├─ Second [id N] 6\n", + " │ │ ├─ Subtensor{:, i} [id O] 5\n", + " │ │ │ ├─ Blockwise{MatrixInverse, (m,m)->(m,m)} [id P] 4\n", + " │ │ │ │ └─ [[0.081594 ... 59856801]] [id Q]\n", + " │ │ │ └─ -1 [id R]\n", + " │ │ └─ ModelFreeRV{transform=None} [id A] 'mu_param' 3\n", + " │ │ └─ ···\n", + " │ └─ Blockwise{MatrixInverse, (m,m)->(m,m)} [id P] 4\n", + " │ └─ ···\n", + " └─ x [id S]\n", + "ModelObservedRV{transform=None} [id T] 'y' 14\n", + " ├─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id U] 'y' 13\n", + " │ ├─ RNG() [id V]\n", + " │ ├─ [10000] [id W]\n", + " │ ├─ ExpandDims{axis=0} [id X] 12\n", + " │ │ └─ Second [id Y] 11\n", + " │ │ ├─ Subtensor{:, i} [id Z] 10\n", + " │ │ │ ├─ [[0.854741 ... 27377318]] [id BA]\n", + " │ │ │ └─ -1 [id BB]\n", + " │ │ └─ ModelFreeRV{transform=None} [id K] 'x' 8\n", + " │ │ └─ ···\n", + " │ └─ ExpandDims{axis=0} [id BC] 9\n", + " │ └─ [[0.854741 ... 27377318]] [id BA]\n", + " └─ y{[[ 0.64235 ... 
56333986]]} [id BD]\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster10\n", + "\n", + "10\n", + "\n", + "\n", + "cluster10000 x 10\n", + "\n", + "10000 x 10\n", + "\n", + "\n", + "\n", + "x\n", + "\n", + "x\n", + "~\n", + "MvNormal\n", + "\n", + "\n", + "\n", + "y\n", + "\n", + "y\n", + "~\n", + "MvNormal\n", + "\n", + "\n", + "\n", + "x->y\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "mu_param\n", + "\n", + "mu_param\n", + "~\n", + "MvNormal\n", + "\n", + "\n", + "\n", + "mu_param->x\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rvs_to_marginalize = [x]\n", + "\n", + "fg, memo = fgraph_from_model(model)\n", + "fg.dprint()\n", + "rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]\n", + "toposort = fg.toposort()\n", + "\n", + "# fg.dprint()\n", + "# print(rvs_to_marginalize)\n", + "# print(toposort)\n", + "\n", + "from pymc.model.fgraph import (\n", + " ModelFreeRV,\n", + " ModelValuedVar,\n", + ")\n", + "\n", + "from pymc_extras.model.marginal.graph_analysis import (\n", + " find_conditional_dependent_rvs,\n", + " find_conditional_input_rvs,\n", + " is_conditional_dependent,\n", + " subgraph_batch_dim_connection,\n", + ")\n", + "\n", + "from pymc_extras.model.marginal.marginal_model import (\n", + " _unique,\n", + " collect_shared_vars,\n", + " remove_model_vars,\n", + ")\n", + "\n", + "from pymc_extras.model.marginal.distributions import (\n", + " MarginalLaplaceRV,\n", + ")\n", + "\n", + "from pymc.pytensorf import collect_default_updates\n", + "\n", + "from pytensor.graph import (\n", + " FunctionGraph,\n", + " Variable,\n", + " clone_replace,\n", + ")\n", + "\n", + "for rv_to_marginalize in sorted(\n", + " rvs_to_marginalize,\n", + " key=lambda rv: toposort.index(rv.owner),\n", + " reverse=True,\n", + "):\n", + " all_rvs = [node.out for node in fg.toposort() if isinstance(node.op, ModelValuedVar)]\n", + "\n", + " dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)\n", + " if not dependent_rvs:\n", + " # TODO: This should at most be a warning, not an error\n", + " raise ValueError(f\"No RVs depend on marginalized RV {rv_to_marginalize}\")\n", + "\n", + " # Issue warning for IntervalTransform on dependent RVs\n", + " for dependent_rv in dependent_rvs:\n", + " transform = dependent_rv.owner.op.transform\n", + "\n", + " # if isinstance(transform, IntervalTransform) or (\n", + " # isinstance(transform, Chain)\n", + " # and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list)\n", + " # ):\n", + " # warnings.warn(\n", + " # f\"The transform {transform} for the variable {dependent_rv}, which depends on the \"\n", + " # f\"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.\",\n", + " # UserWarning,\n", + " # )\n", + "\n", + " # Check that no deterministics or potentials depend on the rv to marginalize\n", + " for det in model.deterministics:\n", + " if is_conditional_dependent(memo[det], rv_to_marginalize, all_rvs):\n", + " raise NotImplementedError(\n", + " f\"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}\"\n", + " )\n", + " for pot in model.potentials:\n", + " if is_conditional_dependent(memo[pot], rv_to_marginalize, all_rvs):\n", + " raise NotImplementedError(\n", + " f\"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}\"\n", + " )\n", + 
"\n", + " marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)\n", + " other_direct_rv_ancestors = [\n", + " rv\n", + " for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)\n", + " if rv is not rv_to_marginalize\n", + " ]\n", + " input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))\n", + "\n", + "pm.model_to_graphviz(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7d2d4683-fc83-47a4-bca8-77085688c42f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[mu_param, RNG(), RNG()]\n", + "[x, y, MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0, MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0]\n", + "[mu_param, RNG(), RNG()]\n", + "[x, y, MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0, MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0]\n", + "[x, y, MarginalLaplaceRV{inline=False}.2, MarginalLaplaceRV{inline=False}.3]\n", + "ModelFreeRV{transform=None} [id A] 'mu_param' 3\n", + " ├─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id B] 'mu_param' 2\n", + " │ ├─ RNG() [id C]\n", + " │ ├─ NoneConst{None} [id D]\n", + " │ ├─ Second [id E] 1\n", + " │ │ ├─ Subtensor{:, i} [id F] 0\n", + " │ │ │ ├─ [[1. 0. 0. ... 0. 0. 1.]] [id G]\n", + " │ │ │ └─ -1 [id H]\n", + " │ │ └─ [0.2552323 ... .18013059] [id I]\n", + " │ └─ [[1. 0. 0. ... 0. 0. 1.]] [id G]\n", + " └─ mu_param [id J]\n", + "MarginalLaplaceRV{inline=False}.0 [id K] 'x' 4\n", + " ├─ ModelFreeRV{transform=None} [id A] 'mu_param' 3\n", + " │ └─ ···\n", + " ├─ RNG() [id L]\n", + " └─ RNG() [id M]\n", + "ModelObservedRV{transform=None} [id N] 'y' 5\n", + " ├─ MarginalLaplaceRV{inline=False}.1 [id K] 'y' 4\n", + " │ └─ ···\n", + " └─ y{[[ 0.64235 ... 56333986]]} [id O]\n", + "\n", + "Inner graphs:\n", + "\n", + "MarginalLaplaceRV{inline=False} [id K]\n", + " ← MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id P] 'x'\n", + " ├─ *2- [id Q]\n", + " ├─ NoneConst{None} [id D]\n", + " ├─ Second [id R]\n", + " │ ├─ Subtensor{:, i} [id S]\n", + " │ │ ├─ Blockwise{MatrixInverse, (m,m)->(m,m)} [id T]\n", + " │ │ │ └─ [[0.081594 ... 59856801]] [id U]\n", + " │ │ └─ -1 [id V]\n", + " │ └─ *0- [id W]\n", + " └─ Blockwise{MatrixInverse, (m,m)->(m,m)} [id T]\n", + " └─ ···\n", + " ← MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id X] 'y'\n", + " ├─ *1- [id Y]\n", + " ├─ [10000] [id Z]\n", + " ├─ ExpandDims{axis=0} [id BA]\n", + " │ └─ Second [id BB]\n", + " │ ├─ Subtensor{:, i} [id BC]\n", + " │ │ ├─ [[0.854741 ... 27377318]] [id BD]\n", + " │ │ └─ -1 [id BE]\n", + " │ └─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id P] 'x'\n", + " │ └─ ···\n", + " └─ ExpandDims{axis=0} [id BF]\n", + " └─ [[0.854741 ... 
27377318]] [id BD]\n", + " ← MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0 [id X]\n", + " └─ ···\n", + " ← MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0 [id P]\n", + " └─ ···\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster10\n", + "\n", + "10\n", + "\n", + "\n", + "cluster10000 x 10\n", + "\n", + "10000 x 10\n", + "\n", + "\n", + "\n", + "mu_param\n", + "\n", + "mu_param\n", + "~\n", + "MvNormal\n", + "\n", + "\n", + "\n", + "y\n", + "\n", + "y\n", + "~\n", + "MarginalLaplace\n", + "\n", + "\n", + "\n", + "mu_param->y\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_rvs = [rv_to_marginalize, *dependent_rvs]\n", + "rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)\n", + "outputs = output_rvs + list(rng_updates.values())\n", + "inputs = input_rvs + list(rng_updates.keys())\n", + "# Add any other shared variable inputs\n", + "inputs += collect_shared_vars(output_rvs, blockers=inputs)\n", + "\n", + "inner_inputs = [inp.clone() for inp in inputs]\n", + "inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs)))\n", + "inner_outputs = remove_model_vars(inner_outputs)\n", + "\n", + "marginalize_constructor = MarginalLaplaceRV\n", + "\n", + "_, _, *dims = rv_to_marginalize.owner.inputs\n", + "marginalization_op = marginalize_constructor(\n", + " inputs=inner_inputs,\n", + " outputs=inner_outputs,\n", + " dims_connections=[\n", + " (None,),\n", + " ], # dependent_rvs_dim_connections, # TODO NOT SURE WHAT THIS IS\n", + " dims=dims,\n", + " # x0=x0,\n", + " # marginalized_rv_input_rvs=marginalized_rv_input_rvs\n", + ")\n", + "\n", + "new_outputs = marginalization_op(*inputs)\n", + "for old_output, new_output in zip(outputs, new_outputs):\n", + " new_output.name = old_output.name\n", + "\n", + "model_replacements = []\n", + "for old_output, new_output in zip(outputs, new_outputs):\n", + " if old_output is rv_to_marginalize or not isinstance(old_output.owner.op, ModelValuedVar):\n", + " # Replace the marginalized ModelFreeRV (or non model-variables) themselves\n", + " var_to_replace = old_output\n", + " else:\n", + " # Replace the underlying RV, keeping the same value, transform and dims\n", + " var_to_replace = old_output.owner.inputs[0]\n", + " model_replacements.append((var_to_replace, new_output))\n", + "\n", + "print(inner_inputs)\n", + "print(inner_outputs)\n", + "print(inputs)\n", + "\n", + "print(outputs)\n", + "print(new_outputs)\n", + "\n", + "fg.replace_all(model_replacements)\n", + "fg.dprint()\n", + "\n", + "model_marg = model_from_fgraph(fg, mutate_fgraph=True)\n", + "pm.model_to_graphviz(model_marg)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "793248a1-8088-41fc-9dbd-c58e596e0df7", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Only tensors with the same number of dimensions can be joined. 
Input ndims were: [3, 2]", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[93]\u001b[39m\u001b[32m, line 8\u001b[39m\n\u001b[32m 5\u001b[39m b = pt.vector(\u001b[33m'\u001b[39m\u001b[33mb\u001b[39m\u001b[33m'\u001b[39m, shape=(\u001b[32m3\u001b[39m,))\n\u001b[32m 7\u001b[39m eqns = pt.stack([A @ x - b])\n\u001b[32m----> \u001b[39m\u001b[32m8\u001b[39m var = \u001b[43mpt\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstack\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43mb\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 10\u001b[39m soln, _ = root(eqns, variables=var)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pytensor/pytensor/tensor/basic.py:2977\u001b[39m, in \u001b[36mstack\u001b[39m\u001b[34m(tensors, axis)\u001b[39m\n\u001b[32m 2973\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mconcatenate\u001b[39m(tensor_list, axis=\u001b[32m0\u001b[39m):\n\u001b[32m 2974\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Alias for `join`(axis, *tensor_list).\u001b[39;00m\n\u001b[32m 2975\u001b[39m \n\u001b[32m 2976\u001b[39m \u001b[33;03m This function is similar to `join`, but uses the signature of\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m2977\u001b[39m \u001b[33;03m numpy's concatenate function.\u001b[39;00m\n\u001b[32m 2978\u001b[39m \n\u001b[32m 2979\u001b[39m \u001b[33;03m Raises\u001b[39;00m\n\u001b[32m 2980\u001b[39m \u001b[33;03m ------\u001b[39;00m\n\u001b[32m 2981\u001b[39m \u001b[33;03m TypeError\u001b[39;00m\n\u001b[32m 2982\u001b[39m \u001b[33;03m The tensor_list must be a tuple or list.\u001b[39;00m\n\u001b[32m 2983\u001b[39m \n\u001b[32m 2984\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m 2985\u001b[39m \u001b[38;5;66;03m# Check someone did not make the common mistake to do something like:\u001b[39;00m\n\u001b[32m 2986\u001b[39m \u001b[38;5;66;03m# c = concatenate(x, y)\u001b[39;00m\n\u001b[32m 2987\u001b[39m \u001b[38;5;66;03m# instead of\u001b[39;00m\n\u001b[32m 2988\u001b[39m \u001b[38;5;66;03m# c = concatenate((x, y))\u001b[39;00m\n\u001b[32m 2989\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor_list, \u001b[38;5;28mtuple\u001b[39m | \u001b[38;5;28mlist\u001b[39m):\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pytensor/pytensor/tensor/basic.py:2817\u001b[39m, in \u001b[36mjoin\u001b[39m\u001b[34m(axis, *tensors_list)\u001b[39m\n\u001b[32m 2815\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m tensors_list[\u001b[32m0\u001b[39m]\n\u001b[32m 2816\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2817\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_join\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43mtensors_list\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pytensor/pytensor/graph/op.py:293\u001b[39m, in \u001b[36mOp.__call__\u001b[39m\u001b[34m(self, name, return_list, *inputs, **kwargs)\u001b[39m\n\u001b[32m 249\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\n\u001b[32m 250\u001b[39m \u001b[38;5;28mself\u001b[39m, *inputs: Any, 
name=\u001b[38;5;28;01mNone\u001b[39;00m, return_list=\u001b[38;5;28;01mFalse\u001b[39;00m, **kwargs\n\u001b[32m 251\u001b[39m ) -> Variable | \u001b[38;5;28mlist\u001b[39m[Variable]:\n\u001b[32m 252\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33mr\u001b[39m\u001b[33;03m\"\"\"Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.\u001b[39;00m\n\u001b[32m 253\u001b[39m \n\u001b[32m 254\u001b[39m \u001b[33;03m This method is just a wrapper around :meth:`Op.make_node`.\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 291\u001b[39m \n\u001b[32m 292\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m293\u001b[39m node = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmake_node\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 294\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 295\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(node.outputs) == \u001b[32m1\u001b[39m:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pytensor/pytensor/tensor/basic.py:2510\u001b[39m, in \u001b[36mJoin.make_node\u001b[39m\u001b[34m(self, axis, *tensors)\u001b[39m\n\u001b[32m 2507\u001b[39m ndim = tensors[\u001b[32m0\u001b[39m].type.ndim\n\u001b[32m 2509\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m builtins.all(x.ndim == ndim \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m tensors):\n\u001b[32m-> \u001b[39m\u001b[32m2510\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 2511\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mOnly tensors with the same number of dimensions can be joined. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 2512\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mInput ndims were: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m[x.ndim\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mx\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39mtensors]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 2513\u001b[39m )\n\u001b[32m 2515\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 2516\u001b[39m static_axis = \u001b[38;5;28mint\u001b[39m(get_scalar_constant_value(axis))\n", + "\u001b[31mTypeError\u001b[39m: Only tensors with the same number of dimensions can be joined. 
Input ndims were: [3, 2]" + ] + } + ], + "source": [ + "from pytensor.tensor.optimize import root\n", + "\n", + "A = pt.matrix(\"A\", shape=(3, 3))\n", + "x = np.ones((3, 1))\n", + "b = pt.vector(\"b\", shape=(3,))\n", + "\n", + "eqns = pt.stack([A @ x - b])\n", + "var = pt.stack([A, b])\n", + "\n", + "soln, _ = root(eqns, variables=var)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "07f8abf3-6158-4d62-83d2-2cffe65aae91", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(10,)\n", + "(10, 1, 1)\n" + ] + }, + { + "data": { + "text/plain": [ + "Reshape{4}.0" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from pytensor.tensor.math import tensordot\n", + "\n", + "a = pt.tensor(\"a\", shape=(10,))\n", + "b = pt.tensor(\"b\", shape=(10, 1, 1))\n", + "\n", + "print(a.type.shape)\n", + "print(b.type.shape)\n", + "# print(b.T.type.shape)\n", + "\n", + "tensordot(a, b, axes=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "2eee8e73-5305-4472-8f9c-58a689f9471e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'MarginalLaplaceRV' object has no attribute 'owner'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[31]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m model_marg \u001b[38;5;28;01mas\u001b[39;00m m:\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43mpm\u001b[49m\u001b[43m.\u001b[49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/sampling/mcmc.py:783\u001b[39m, in \u001b[36msample\u001b[39m\u001b[34m(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)\u001b[39m\n\u001b[32m 780\u001b[39m _log.warning(msg)\n\u001b[32m 782\u001b[39m provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)\n\u001b[32m--> \u001b[39m\u001b[32m783\u001b[39m exclusive_nuts = (\n\u001b[32m 784\u001b[39m \u001b[38;5;66;03m# User provided an instantiated NUTS step, and nothing else is needed\u001b[39;00m\n\u001b[32m 785\u001b[39m (\u001b[38;5;129;01mnot\u001b[39;00m selected_steps \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(provided_steps) == \u001b[32m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(provided_steps[\u001b[32m0\u001b[39m], NUTS))\n\u001b[32m 786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[32m 787\u001b[39m \u001b[38;5;66;03m# Only automatically selected NUTS step is needed\u001b[39;00m\n\u001b[32m 788\u001b[39m (\n\u001b[32m 789\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m provided_steps\n\u001b[32m 790\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(selected_steps) == \u001b[32m1\u001b[39m\n\u001b[32m 791\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m 
\u001b[38;5;28missubclass\u001b[39m(\u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28miter\u001b[39m(selected_steps)), NUTS)\n\u001b[32m 792\u001b[39m )\n\u001b[32m 793\u001b[39m )\n\u001b[32m 795\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m nuts_sampler != \u001b[33m\"\u001b[39m\u001b[33mpymc\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 796\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exclusive_nuts:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/sampling/mcmc.py:245\u001b[39m, in \u001b[36massign_step_methods\u001b[39m\u001b[34m(model, step, methods)\u001b[39m\n\u001b[32m 243\u001b[39m methods_list: \u001b[38;5;28mlist\u001b[39m[\u001b[38;5;28mtype\u001b[39m[BlockedStep]] = \u001b[38;5;28mlist\u001b[39m(methods \u001b[38;5;129;01mor\u001b[39;00m pm.STEP_METHODS)\n\u001b[32m 244\u001b[39m selected_steps: \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mtype\u001b[39m[BlockedStep], \u001b[38;5;28mlist\u001b[39m] = {}\n\u001b[32m--> \u001b[39m\u001b[32m245\u001b[39m model_logp = \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mlogp\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 247\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m var \u001b[38;5;129;01min\u001b[39;00m model.value_vars:\n\u001b[32m 248\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m var \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m assigned_vars:\n\u001b[32m 249\u001b[39m \u001b[38;5;66;03m# determine if a gradient can be computed\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/model/core.py:691\u001b[39m, in \u001b[36mModel.logp\u001b[39m\u001b[34m(self, vars, jacobian, sum)\u001b[39m\n\u001b[32m 689\u001b[39m rv_logps: \u001b[38;5;28mlist\u001b[39m[TensorVariable] = []\n\u001b[32m 690\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m rvs:\n\u001b[32m--> \u001b[39m\u001b[32m691\u001b[39m rv_logps = \u001b[43mtransformed_conditional_logp\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 692\u001b[39m \u001b[43m \u001b[49m\u001b[43mrvs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrvs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 693\u001b[39m \u001b[43m \u001b[49m\u001b[43mrvs_to_values\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mrvs_to_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 694\u001b[39m \u001b[43m \u001b[49m\u001b[43mrvs_to_transforms\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mrvs_to_transforms\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 695\u001b[39m \u001b[43m \u001b[49m\u001b[43mjacobian\u001b[49m\u001b[43m=\u001b[49m\u001b[43mjacobian\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 696\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 697\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(rv_logps, \u001b[38;5;28mlist\u001b[39m)\n\u001b[32m 699\u001b[39m \u001b[38;5;66;03m# Replace random variables by their value variables in potential terms\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/logprob/basic.py:570\u001b[39m, in \u001b[36mtransformed_conditional_logp\u001b[39m\u001b[34m(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)\u001b[39m\n\u001b[32m 567\u001b[39m transform_rewrite = TransformValuesRewrite(values_to_transforms) \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[32m 569\u001b[39m kwargs.setdefault(\u001b[33m\"\u001b[39m\u001b[33mwarn_rvs\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m--> 
\u001b[39m\u001b[32m570\u001b[39m temp_logp_terms = \u001b[43mconditional_logp\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 571\u001b[39m \u001b[43m \u001b[49m\u001b[43mrvs_to_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 572\u001b[39m \u001b[43m \u001b[49m\u001b[43mextra_rewrites\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtransform_rewrite\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 573\u001b[39m \u001b[43m \u001b[49m\u001b[43muse_jacobian\u001b[49m\u001b[43m=\u001b[49m\u001b[43mjacobian\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 574\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 575\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 577\u001b[39m \u001b[38;5;66;03m# The function returns the logp for every single value term we provided to it.\u001b[39;00m\n\u001b[32m 578\u001b[39m \u001b[38;5;66;03m# This includes the extra values we plugged in above, so we filter those we\u001b[39;00m\n\u001b[32m 579\u001b[39m \u001b[38;5;66;03m# actually wanted in the same order they were given in.\u001b[39;00m\n\u001b[32m 580\u001b[39m logp_terms = {}\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/logprob/basic.py:500\u001b[39m, in \u001b[36mconditional_logp\u001b[39m\u001b[34m(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)\u001b[39m\n\u001b[32m 497\u001b[39m node_values = remapped_vars[: \u001b[38;5;28mlen\u001b[39m(node_values)]\n\u001b[32m 498\u001b[39m node_inputs = remapped_vars[\u001b[38;5;28mlen\u001b[39m(node_values) :]\n\u001b[32m--> \u001b[39m\u001b[32m500\u001b[39m node_logprobs = \u001b[43m_logprob\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 501\u001b[39m \u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m.\u001b[49m\u001b[43mop\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 502\u001b[39m \u001b[43m \u001b[49m\u001b[43mnode_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 503\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43mnode_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 504\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 505\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 507\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(node_logprobs, \u001b[38;5;28mlist\u001b[39m | \u001b[38;5;28mtuple\u001b[39m):\n\u001b[32m 508\u001b[39m node_logprobs = [node_logprobs]\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc-extras/.pixi/envs/default/lib/python3.12/functools.py:912\u001b[39m, in \u001b[36msingledispatch..wrapper\u001b[39m\u001b[34m(*args, **kw)\u001b[39m\n\u001b[32m 908\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m args:\n\u001b[32m 909\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfuncname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m requires at least \u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 910\u001b[39m \u001b[33m'\u001b[39m\u001b[33m1 positional argument\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m912\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mdispatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkw\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m:17\u001b[39m, in \u001b[36mlaplace_marginal_rv_logp\u001b[39m\u001b[34m(op, values, *inputs, **kwargs)\u001b[39m\n", + "\u001b[31mAttributeError\u001b[39m: 'MarginalLaplaceRV' object has no attribute 'owner'" + ] + } + ], + "source": [ + "with model_marg as m:\n", + " pm.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "991b7cbd-4ef3-487a-9ece-139226883502", + "metadata": {}, + "outputs": [], + "source": [ + "[2, 2, 3, 4, 2, 2, 2, 3, 3, 4]" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "42203b08-519a-4676-9a99-34a5d92d4c5d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [mu_param, Q, x]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c6253c53d3124ca1add318ff3828d5c2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "Not enough samples to build a trace.",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+      "\u001b[31mValueError\u001b[39m                                Traceback (most recent call last)",
+      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[40]\u001b[39m\u001b[32m, line 46\u001b[39m\n\u001b[32m     37\u001b[39m     \u001b[38;5;66;03m# model.logp().dprint()\u001b[39;00m\n\u001b[32m     39\u001b[39m     x0, log_likelihood = get_log_marginal_likelihood(\n\u001b[32m     40\u001b[39m         x=model.rvs_to_values[x],\n\u001b[32m     41\u001b[39m         Q=Q.reshape((d, d)),\u001b[38;5;66;03m#Q_val,\u001b[39;00m\n\u001b[32m     42\u001b[39m         mu=mu_param,\n\u001b[32m     43\u001b[39m         optimizer_kwargs={\u001b[33m\"\u001b[39m\u001b[33mtol\u001b[39m\u001b[33m\"\u001b[39m: \u001b[32m1e-8\u001b[39m},\n\u001b[32m     44\u001b[39m     )\n\u001b[32m---> \u001b[39m\u001b[32m46\u001b[39m     \u001b[43mpm\u001b[49m\u001b[43m.\u001b[49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     48\u001b[39m     \u001b[38;5;66;03m# print(model.free_RVs)\u001b[39;00m\n\u001b[32m     49\u001b[39m \n\u001b[32m     50\u001b[39m     \u001b[38;5;66;03m# # with pm.Model() as inla_model:\u001b[39;00m\n\u001b[32m   (...)\u001b[39m\u001b[32m     65\u001b[39m \n\u001b[32m     66\u001b[39m \u001b[38;5;66;03m# inla_model = marginalize(model, [mu_param, cov_param])\u001b[39;00m\n",
+      "\u001b[36mFile \u001b[39m\u001b[32m~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:964\u001b[39m, in \u001b[36msample\u001b[39m\u001b[34m(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)\u001b[39m\n\u001b[32m    960\u001b[39m t_sampling = time.time() - t_start\n\u001b[32m    962\u001b[39m \u001b[38;5;66;03m# Packaging, validating and returning the result was extracted\u001b[39;00m\n\u001b[32m    963\u001b[39m \u001b[38;5;66;03m# into a function to make it easier to test and refactor.\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m964\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_sample_return\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    965\u001b[39m \u001b[43m    \u001b[49m\u001b[43mrun\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrun\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    966\u001b[39m \u001b[43m    \u001b[49m\u001b[43mtraces\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43misinstance\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mZarrTrace\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtraces\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    967\u001b[39m \u001b[43m    \u001b[49m\u001b[43mtune\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtune\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    968\u001b[39m \u001b[43m    \u001b[49m\u001b[43mt_sampling\u001b[49m\u001b[43m=\u001b[49m\u001b[43mt_sampling\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    969\u001b[39m \u001b[43m    \u001b[49m\u001b[43mdiscard_tuned_samples\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdiscard_tuned_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    970\u001b[39m \u001b[43m    \u001b[49m\u001b[43mcompute_convergence_checks\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcompute_convergence_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    971\u001b[39m \u001b[43m    \u001b[49m\u001b[43mreturn_inferencedata\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_inferencedata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    972\u001b[39m \u001b[43m    \u001b[49m\u001b[43mkeep_warning_stat\u001b[49m\u001b[43m=\u001b[49m\u001b[43mkeep_warning_stat\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    973\u001b[39m \u001b[43m    \u001b[49m\u001b[43midata_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43midata_kwargs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    974\u001b[39m \u001b[43m    \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    975\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
+      "\u001b[36mFile \u001b[39m\u001b[32m~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1049\u001b[39m, in \u001b[36m_sample_return\u001b[39m\u001b[34m(run, traces, tune, t_sampling, discard_tuned_samples, compute_convergence_checks, return_inferencedata, keep_warning_stat, idata_kwargs, model)\u001b[39m\n\u001b[32m   1047\u001b[39m \u001b[38;5;66;03m# Pick and slice chains to keep the maximum number of samples\u001b[39;00m\n\u001b[32m   1048\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m discard_tuned_samples:\n\u001b[32m-> \u001b[39m\u001b[32m1049\u001b[39m     traces, length = \u001b[43m_choose_chains\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtraces\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtune\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1050\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m   1051\u001b[39m     traces, length = _choose_chains(traces, \u001b[32m0\u001b[39m)\n",
+      "\u001b[36mFile \u001b[39m\u001b[32m~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pymc/backends/base.py:624\u001b[39m, in \u001b[36m_choose_chains\u001b[39m\u001b[34m(traces, tune)\u001b[39m\n\u001b[32m    622\u001b[39m lengths = [\u001b[38;5;28mmax\u001b[39m(\u001b[32m0\u001b[39m, \u001b[38;5;28mlen\u001b[39m(trace) - tune) \u001b[38;5;28;01mfor\u001b[39;00m trace \u001b[38;5;129;01min\u001b[39;00m traces]\n\u001b[32m    623\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28msum\u001b[39m(lengths):\n\u001b[32m--> \u001b[39m\u001b[32m624\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mNot enough samples to build a trace.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m    626\u001b[39m idxs = np.argsort(lengths)\n\u001b[32m    627\u001b[39m l_sort = np.array(lengths)[idxs]\n",
+      "\u001b[31mValueError\u001b[39m: Not enough samples to build a trace."
+     ]
+    }
+   ],
+   "source": [
+    "rng = np.random.default_rng(12345)\n",
+    "n = 10000\n",
+    "d = 3\n",
+    "\n",
+    "# Initialise arrays\n",
+    "mu_true = rng.random(d)\n",
+    "cov_true = np.diag(rng.random(d))\n",
+    "Q_val = np.diag(rng.random(d))\n",
+    "cov_param_val = np.diag(rng.random(d))\n",
+    "\n",
+    "x_val = rng.random(d)\n",
+    "mu_val = rng.random(d)\n",
+    "\n",
+    "mu_mu = rng.random(d)\n",
+    "mu_cov = np.diag(np.ones(d))\n",
+    "cov_mu = rng.random(d**2)\n",
+    "cov_cov = np.diag(np.ones(d**2))\n",
+    "Q_mu = rng.random(d**2)\n",
+    "Q_cov = np.diag(np.ones(d**2))\n",
+    "\n",
+    "with pm.Model() as model:\n",
+    "    y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n)\n",
+    "\n",
+    "    mu_param = pm.MvNormal(\"mu_param\", mu=mu_mu, cov=mu_cov)\n",
+    "    # cov_param = np.abs(pm.MvNormal(\"cov_param\", mu=cov_mu, cov=cov_cov))\n",
+    "    # Q = pm.MvNormal(\"Q\", mu=Q_mu, cov=Q_cov)\n",
+    "\n",
+    "    x = pm.MvNormal(\"x\", mu=mu_param, tau=Q_val)\n",
+    "\n",
+    "    y = pm.MvNormal(\n",
+    "        \"y\",\n",
+    "        mu=x,\n",
+    "        cov=cov_param_val,  # cov_param.reshape((d, d)),\n",
+    "        observed=y_obs,\n",
+    "    )\n",
+    "\n",
+    "    # model.logp().dprint()\n",
+    "\n",
+    "    # x0, log_likelihood = get_log_marginal_likelihood(\n",
+    "    #     x=model.rvs_to_values[x],\n",
+    "    #     Q=Q_val,#Q.reshape((d, d)),\n",
+    "    #     mu=mu_param,\n",
+    "    #     optimizer_kwargs={\"tol\": 1e-8},\n",
+    "    # )\n",
+    "\n",
+    "    # print(model.free_RVs)\n",
+    "\n",
+    "    # # with pm.Model() as inla_model:\n",
+    "    # log_prior = pm.logp(mu_param, mu_mu)\n",
+    "    # log_posterior = log_likelihood + log_prior\n",
+    "    # # # log_posterior.dprint()\n",
+    "    # # # log_posterior = pytensor.graph.replace.graph_replace(log_posterior, {model.rvs_to_values[x]: x0})\n",
+    "    # # # log_posterior_model = model_from_fgraph(log_posterior, mutate_fgraph=True)\n",
+    "    # # idata = pm.sample()\n",
+    "    # # draws = pm.draw(mu_param)\n",
+    "    # idata = pm.sample_prior_predictive()\n",
+    "    # print(idata.prior)\n",
+    "    # print(draws)\n",
+    "\n",
+    "    # print(inla_model.free_RVs)\n",
+    "    # log_posterior.dprint()\n",
+    "    # pytensor.graph.fg.FunctionGraph(inputs=[model.rvs_to_values[mu_param], model.rvs_to_values[x]], outputs=[log_posterior])\n",
+    "\n",
+    "# inla_model = marginalize(model, [mu_param, cov_param])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "681e14ea-3629-4cc5-bb7f-e08cad5df276",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e9c829f7-cf06-4402-b909-2a27d0dea07a",
+   "metadata": {},
+   "source": [
+    "True dataset:\n",
+    "\n",
+    "$y \\sim N(\\mu_{true}, \\Sigma_{true})$\n",
+    "\n",
+    "Model:\n",
+    "\n",
+    "$y|x, \\sigma \\sim N(Ax, \\sigma W)$\n",
+    "\n",
+    "Let $A=I$, $W=I$:\n",
+    "\n",
+    "$y|x, \\sigma \\sim N(x, \\sigma)$\n",
+    "\n",
+    "Comparing model and true data:\n",
+    "\n",
+    "$x = \\mu_{true}$\n",
+    "\n",
+    "$x|\\theta \\sim N(\\mu, Q^{-1})$\n",
+    "\n",
+    "$\\theta = (\\mu, \\Sigma_b, \\sigma)$\n",
+    "\n",
+    "Set $Q = I$ for now.\n",
+    "\n",
+    "$\\theta = (\\mu, \\sigma)$"
+   ]
+  },
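+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3f5a1c2e-9b77-4d4a-8c55-1a2b3c4d5e6f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sketch of the model described above, with A = I, W = I and Q = I, so theta = (mu, sigma).\n",
+    "# Names here (toy_model, mu_theta, x_latent, y_lik) are placeholders; y_obs and d are reused from the earlier cell.\n",
+    "with pm.Model() as toy_model:\n",
+    "    mu_theta = pm.MvNormal(\"mu_theta\", mu=np.zeros(d), cov=np.eye(d))\n",
+    "    sigma = pm.HalfNormal(\"sigma\", sigma=1.0)\n",
+    "\n",
+    "    x_latent = pm.MvNormal(\"x_latent\", mu=mu_theta, tau=np.eye(d))  # x | theta ~ N(mu, Q^-1) with Q = I\n",
+    "    y_lik = pm.Normal(\"y_lik\", mu=x_latent, sigma=sigma, observed=y_obs)  # y | x, sigma ~ N(x, sigma)"
+   ]
+  },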
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "id": "e344b7d0-f76e-4a28-9be9-884a2ba1f2c4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[autoreload of cutils_ext failed: Traceback (most recent call last):\n",
+      "  File \"/home/michaln/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/IPython/extensions/autoreload.py\", line 283, in check\n",
+      "    superreload(m, reload, self.old_objects)\n",
+      "  File \"/home/michaln/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/IPython/extensions/autoreload.py\", line 483, in superreload\n",
+      "    module = reload(module)\n",
+      "             ^^^^^^^^^^^^^^\n",
+      "  File \"/home/michaln/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/importlib/__init__.py\", line 130, in reload\n",
+      "    raise ModuleNotFoundError(f\"spec not found for the module {name!r}\", name=name)\n",
+      "ModuleNotFoundError: spec not found for the module 'cutils_ext'\n",
+      "]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[1m============================= test session starts ==============================\u001b[0m\n",
+      "platform linux -- Python 3.12.10, pytest-8.4.1, pluggy-1.6.0 -- /home/michaln/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/bin/python\n",
+      "cachedir: .pytest_cache\n",
+      "rootdir: /home/michaln/Michal_Linux/git/GSoC/pymc-extras\n",
+      "configfile: pyproject.toml\n",
+      "plugins: anyio-4.9.0\n",
+      "collected 1 item                                                               \u001b[0m\u001b[1m\n",
+      "\n",
+      "../tests/test_inla.py::test_get_conditional_gaussian_approximation \u001b[32mPASSED\u001b[0m\u001b[32m [100%]\u001b[0m\n",
+      "\n",
+      "\u001b[32m============================== \u001b[32m\u001b[1m1 passed\u001b[0m\u001b[32m in 5.99s\u001b[0m\u001b[32m ===============================\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "!python -m pytest -v /home/michaln/Michal_Linux/git/GSoC/pymc-extras/tests/test_inla.py"
+   ]
+  },
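+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7c9d2b1a-5e44-4f3b-9a10-aa00bb11cc22",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Numerical sketch of the identity behind laplace_marginal_rv_logp:\n",
+    "# log p(y | theta) = log p(x, y | theta) - log p_G(x | y, theta), evaluated at the mode x0.\n",
+    "# With a Gaussian prior and likelihood the Gaussian approximation p_G is exact, so the identity\n",
+    "# can be checked against the closed-form marginal. All names below are local to this cell.\n",
+    "from scipy import stats\n",
+    "\n",
+    "rng_chk = np.random.default_rng(0)\n",
+    "d_chk = 3\n",
+    "mu_chk = rng_chk.normal(size=d_chk)\n",
+    "Q_chk = np.diag(rng_chk.random(d_chk) + 1.0)      # prior precision of x\n",
+    "Sigma_chk = np.diag(rng_chk.random(d_chk) + 1.0)  # likelihood covariance of y | x\n",
+    "y_chk = rng_chk.normal(size=d_chk)\n",
+    "\n",
+    "# Posterior p(x | y) is Gaussian with precision Q + Sigma^-1; its mean is the mode x0\n",
+    "post_cov = np.linalg.inv(Q_chk + np.linalg.inv(Sigma_chk))\n",
+    "x0_chk = post_cov @ (Q_chk @ mu_chk + np.linalg.solve(Sigma_chk, y_chk))\n",
+    "\n",
+    "log_joint = (\n",
+    "    stats.multivariate_normal(mu_chk, np.linalg.inv(Q_chk)).logpdf(x0_chk)\n",
+    "    + stats.multivariate_normal(x0_chk, Sigma_chk).logpdf(y_chk)\n",
+    ")\n",
+    "log_gaussian_approx = stats.multivariate_normal(x0_chk, post_cov).logpdf(x0_chk)\n",
+    "log_marginal = stats.multivariate_normal(mu_chk, Sigma_chk + np.linalg.inv(Q_chk)).logpdf(y_chk)\n",
+    "\n",
+    "print(log_joint - log_gaussian_approx, log_marginal)  # should agree to float precision"
+   ]
+  },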
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f1eb9a04-15dd-437d-bcf0-ce369feec912",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pymc_extras/__init__.py b/pymc_extras/__init__.py
index cee0ffeb..6566c337 100644
--- a/pymc_extras/__init__.py
+++ b/pymc_extras/__init__.py
@@ -17,7 +17,7 @@
 
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
-from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
+from pymc_extras.inference import find_MAP, fit, fit_laplace  # , fit_pathfinder
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
diff --git a/pymc_extras/inference/__init__.py b/pymc_extras/inference/__init__.py
index 3e4d781d..05041762 100644
--- a/pymc_extras/inference/__init__.py
+++ b/pymc_extras/inference/__init__.py
@@ -16,6 +16,7 @@
 from pymc_extras.inference.fit import fit
 from pymc_extras.inference.inla import fit_INLA
 from pymc_extras.inference.laplace import fit_laplace
-from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
+
+# from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
 
 __all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP", "fit_INLA"]
diff --git a/pymc_extras/model/marginal/distributions.py b/pymc_extras/model/marginal/distributions.py
index 6e08352e..42ee3ff8 100644
--- a/pymc_extras/model/marginal/distributions.py
+++ b/pymc_extras/model/marginal/distributions.py
@@ -384,6 +384,7 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
 
     # Obtain the joint_logp graph of the inner RV graph
     inner_rv_values = dict(zip(inner_rvs, values))
+
     marginalized_vv = x.clone()
     rv_values = inner_rv_values | {x: marginalized_vv}
     logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
@@ -396,7 +397,7 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
 
     from pytensor.tensor.optimize import minimize
 
-    # Maximize log(p(x | y, params)) wrt x to find mode x0
+    # Maximize log(p(x | y, params)) wrt x to find mode x0 # TODO args need to be user-supplied
     x0, _ = minimize(
         objective=-logp,
         x=marginalized_vv,
@@ -406,14 +407,32 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
         optimizer_kwargs={"tol": 1e-8},
     )
 
+    # print(op.__dict__)
+    # marginalized_rv_input_rvs = op.kwargs['marginalized_rv_input_rvs']
+    # x0 = op.kwargs['x0']
+    # log_laplace_approx = op.kwargs['log_laplace_approx']
+    # return logp - log_laplace_approx
+
+    rng = np.random.default_rng(12345)
+    d = 10  # TODO placeholder; the dimension should be inferred from x's shape
+    # Q = np.diag(rng.random(d))
+    from pymc import MvNormal
+
+    # NB: `op` is an Op and has no `.owner`; use the marginalized RV `x` extracted above directly.
+    if x.owner is None or not isinstance(x.owner.op, MvNormal.rv_type):
+        raise ValueError("Latent field x must be MvNormal.")
+    # TODO double check this grabs the right thing: the RV node inputs are (rng, size, mu, cov),
+    # so inputs[-1] is the prior covariance and still needs inverting to give the precision Q.
+    Q = x.owner.inputs[-1]
+    x0 = rng.random(d)  # TODO placeholder; should be the mode returned by `minimize` above
+
+    # x0 = pytensor.graph.replace.graph_replace(x0, {marginalized_vv: rng.random(d)})
+    # for rv in marginalized_rv_input_rvs:
+    #     x0 = pytensor.graph.replace.graph_replace(x0, {marginalized_vv: rng.random(d)})
+
     # require f''(x0) for Laplace approx
     hess = pytensor.gradient.hessian(logp, marginalized_vv)
     # hess = pytensor.graph.replace.graph_replace(hess, {marginalized_vv: x0})
 
     # Could be made more efficient with adding diagonals only
-    rng = np.random.default_rng(12345)
-    d = 3
-    Q = np.diag(rng.random(d))
     tau = Q - hess
 
     # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
@@ -425,6 +444,7 @@ def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
     # marginalized_logp = logps_dict.pop(marginalized_vv)
     joint_logp = logp - log_laplace_approx
 
+    # TODO this might cause circularity issues by overwriting x as an input to the x0 minimizer
     joint_logp = pytensor.graph.replace.graph_replace(joint_logp, {marginalized_vv: x0})
 
     return joint_logp  # TODO check if pm.sample adds on p(params). Otherwise this is p(y|params) not p(params|y)