diff --git a/notebooks/INLA_testing.ipynb b/notebooks/INLA_testing.ipynb new file mode 100644 index 00000000..e2c5e27b --- /dev/null +++ b/notebooks/INLA_testing.ipynb @@ -0,0 +1,875 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "ffd6780e-1bfb-42f0-ba6a-055e9ffd1490", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5a2819fd-6e01-47c0-88b2-f2b5e4215b9b", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pymc as pm\n", + "import pytensor.tensor as pt\n", + "\n", + "import pytensor\n", + "from pytensor.tensor.optimize import minimize\n", + "from pymc_extras.inference.inla import *\n", + "\n", + "from pymc.model.fgraph import fgraph_from_model, model_from_fgraph\n", + "from pymc_extras.model.marginal.marginal_model import marginalize" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0ad97d05-f577-4793-ba6c-dd5f1300c022", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ExpandDims{axis=0}.0\n", + "(1, None)\n", + "(1, None, 1)\n" + ] + }, + { + "data": { + "text/plain": [ + "Reshape{4}.0" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from pytensor.gradient import grad, hessian, jacobian\n", + "from pytensor.tensor.optimize import root\n", + "\n", + "x = pt.vector(\"x\")\n", + "var = pt.stack([x])\n", + "y = pt.stack([var[0], var[0] ** 2])\n", + "sol, _ = root(y, variables=var)\n", + "jacobian(sol, var, vectorize=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2f475324-9fba-48a1-a79a-563a5e7818c9", + "metadata": {}, + "outputs": [], + "source": [ + "rng = np.random.default_rng(12345)\n", + "n = 10000\n", + "d = 10\n", + "\n", + "# Initialise arrays\n", + "mu_true = rng.random(d)\n", + "cov_true = np.diag(rng.random(d))\n", + "Q_val = np.diag(rng.random(d))\n", + "cov_param_val = np.diag(rng.random(d))\n", + "\n", + "x_val = rng.random(d)\n", + "mu_val = rng.random(d)\n", + "\n", + "mu_mu = rng.random(d)\n", + "mu_cov = np.diag(np.ones(d))\n", + "cov_mu = rng.random(d**2)\n", + "cov_cov = np.diag(np.ones(d**2))\n", + "Q_mu = rng.random(d**2)\n", + "Q_cov = np.diag(np.ones(d**2))\n", + "\n", + "with pm.Model() as model:\n", + " y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n)\n", + "\n", + " mu_param = pm.MvNormal(\"mu_param\", mu=mu_mu, cov=mu_cov)\n", + " # cov_param = np.abs(pm.MvNormal(\"cov_param\", mu=cov_mu, cov=cov_cov))\n", + " # Q = pm.MvNormal(\"Q\", mu=Q_mu, cov=Q_cov)\n", + "\n", + " x = pm.MvNormal(\"x\", mu=mu_param, tau=Q_val)\n", + "\n", + " y = pm.MvNormal(\n", + " \"y\",\n", + " mu=x,\n", + " cov=cov_param_val, # cov_param.reshape((d, d)),\n", + " observed=y_obs,\n", + " )\n", + "\n", + " # x0, log_likelihood = get_log_marginal_likelihood(\n", + " # x=model.rvs_to_values[x],\n", + " # Q=Q_val,#Q.reshape((d, d)),\n", + " # mu=mu_param,\n", + " # optimizer_kwargs={\"tol\": 1e-8},\n", + " # )\n", + "\n", + " # args = model.continuous_value_vars + model.discrete_value_vars\n", + " # for i, rv in enumerate(args):\n", + " # if rv == model.rvs_to_values[x]:\n", + " # args.pop(i)\n", + " # log_likelihood = pytensor.graph.replace.graph_replace(log_likelihood, {model.rvs_to_values[x]: rng.random(d)})\n", + " # log_laplace_approx = pytensor.function(args, log_likelihood)\n", + "\n", + " # pm.sample()" + ] + }, + { + "cell_type": 
"code", + "execution_count": 5, + "id": "5121ab56-7841-4ff2-b9b0-016639d2bdb2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ModelFreeRV{transform=None} [id A] 'mu_param' 3\n", + " ├─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id B] 'mu_param' 2\n", + " │ ├─ RNG() [id C]\n", + " │ ├─ NoneConst{None} [id D]\n", + " │ ├─ Second [id E] 1\n", + " │ │ ├─ Subtensor{:, i} [id F] 0\n", + " │ │ │ ├─ [[1. 0. 0. ... 0. 0. 1.]] [id G]\n", + " │ │ │ └─ -1 [id H]\n", + " │ │ └─ [0.2552323 ... .18013059] [id I]\n", + " │ └─ [[1. 0. 0. ... 0. 0. 1.]] [id G]\n", + " └─ mu_param [id J]\n", + "ModelFreeRV{transform=None} [id K] 'x' 8\n", + " ├─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id L] 'x' 7\n", + " │ ├─ RNG() [id M]\n", + " │ ├─ NoneConst{None} [id D]\n", + " │ ├─ Second [id N] 6\n", + " │ │ ├─ Subtensor{:, i} [id O] 5\n", + " │ │ │ ├─ Blockwise{MatrixInverse, (m,m)->(m,m)} [id P] 4\n", + " │ │ │ │ └─ [[0.081594 ... 59856801]] [id Q]\n", + " │ │ │ └─ -1 [id R]\n", + " │ │ └─ ModelFreeRV{transform=None} [id A] 'mu_param' 3\n", + " │ │ └─ ···\n", + " │ └─ Blockwise{MatrixInverse, (m,m)->(m,m)} [id P] 4\n", + " │ └─ ···\n", + " └─ x [id S]\n", + "ModelObservedRV{transform=None} [id T] 'y' 14\n", + " ├─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id U] 'y' 13\n", + " │ ├─ RNG() [id V]\n", + " │ ├─ [10000] [id W]\n", + " │ ├─ ExpandDims{axis=0} [id X] 12\n", + " │ │ └─ Second [id Y] 11\n", + " │ │ ├─ Subtensor{:, i} [id Z] 10\n", + " │ │ │ ├─ [[0.854741 ... 27377318]] [id BA]\n", + " │ │ │ └─ -1 [id BB]\n", + " │ │ └─ ModelFreeRV{transform=None} [id K] 'x' 8\n", + " │ │ └─ ···\n", + " │ └─ ExpandDims{axis=0} [id BC] 9\n", + " │ └─ [[0.854741 ... 27377318]] [id BA]\n", + " └─ y{[[ 0.64235 ... 
56333986]]} [id BD]\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster10\n", + "\n", + "10\n", + "\n", + "\n", + "cluster10000 x 10\n", + "\n", + "10000 x 10\n", + "\n", + "\n", + "\n", + "x\n", + "\n", + "x\n", + "~\n", + "MvNormal\n", + "\n", + "\n", + "\n", + "y\n", + "\n", + "y\n", + "~\n", + "MvNormal\n", + "\n", + "\n", + "\n", + "x->y\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "mu_param\n", + "\n", + "mu_param\n", + "~\n", + "MvNormal\n", + "\n", + "\n", + "\n", + "mu_param->x\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rvs_to_marginalize = [x]\n", + "\n", + "fg, memo = fgraph_from_model(model)\n", + "fg.dprint()\n", + "rvs_to_marginalize = [memo[rv] for rv in rvs_to_marginalize]\n", + "toposort = fg.toposort()\n", + "\n", + "# fg.dprint()\n", + "# print(rvs_to_marginalize)\n", + "# print(toposort)\n", + "\n", + "from pymc.model.fgraph import (\n", + " ModelFreeRV,\n", + " ModelValuedVar,\n", + ")\n", + "\n", + "from pymc_extras.model.marginal.graph_analysis import (\n", + " find_conditional_dependent_rvs,\n", + " find_conditional_input_rvs,\n", + " is_conditional_dependent,\n", + " subgraph_batch_dim_connection,\n", + ")\n", + "\n", + "from pymc_extras.model.marginal.marginal_model import (\n", + " _unique,\n", + " collect_shared_vars,\n", + " remove_model_vars,\n", + ")\n", + "\n", + "from pymc_extras.model.marginal.distributions import (\n", + " MarginalLaplaceRV,\n", + ")\n", + "\n", + "from pymc.pytensorf import collect_default_updates\n", + "\n", + "from pytensor.graph import (\n", + " FunctionGraph,\n", + " Variable,\n", + " clone_replace,\n", + ")\n", + "\n", + "for rv_to_marginalize in sorted(\n", + " rvs_to_marginalize,\n", + " key=lambda rv: toposort.index(rv.owner),\n", + " reverse=True,\n", + "):\n", + " all_rvs = [node.out for node in fg.toposort() if isinstance(node.op, ModelValuedVar)]\n", + "\n", + " dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)\n", + " if not dependent_rvs:\n", + " # TODO: This should at most be a warning, not an error\n", + " raise ValueError(f\"No RVs depend on marginalized RV {rv_to_marginalize}\")\n", + "\n", + " # Issue warning for IntervalTransform on dependent RVs\n", + " for dependent_rv in dependent_rvs:\n", + " transform = dependent_rv.owner.op.transform\n", + "\n", + " # if isinstance(transform, IntervalTransform) or (\n", + " # isinstance(transform, Chain)\n", + " # and any(isinstance(tr, IntervalTransform) for tr in transform.transform_list)\n", + " # ):\n", + " # warnings.warn(\n", + " # f\"The transform {transform} for the variable {dependent_rv}, which depends on the \"\n", + " # f\"marginalized {rv_to_marginalize} may no longer work if bounds depended on other variables.\",\n", + " # UserWarning,\n", + " # )\n", + "\n", + " # Check that no deterministics or potentials depend on the rv to marginalize\n", + " for det in model.deterministics:\n", + " if is_conditional_dependent(memo[det], rv_to_marginalize, all_rvs):\n", + " raise NotImplementedError(\n", + " f\"Cannot marginalize {rv_to_marginalize} due to dependent Deterministic {det}\"\n", + " )\n", + " for pot in model.potentials:\n", + " if is_conditional_dependent(memo[pot], rv_to_marginalize, all_rvs):\n", + " raise NotImplementedError(\n", + " f\"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}\"\n", + " )\n", + 
"\n", + " marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)\n", + " other_direct_rv_ancestors = [\n", + " rv\n", + " for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)\n", + " if rv is not rv_to_marginalize\n", + " ]\n", + " input_rvs = _unique((*marginalized_rv_input_rvs, *other_direct_rv_ancestors))\n", + "\n", + "pm.model_to_graphviz(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7d2d4683-fc83-47a4-bca8-77085688c42f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[mu_param, RNG(), RNG()]\n", + "[x, y, MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0, MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0]\n", + "[mu_param, RNG(), RNG()]\n", + "[x, y, MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0, MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0]\n", + "[x, y, MarginalLaplaceRV{inline=False}.2, MarginalLaplaceRV{inline=False}.3]\n", + "ModelFreeRV{transform=None} [id A] 'mu_param' 3\n", + " ├─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id B] 'mu_param' 2\n", + " │ ├─ RNG() [id C]\n", + " │ ├─ NoneConst{None} [id D]\n", + " │ ├─ Second [id E] 1\n", + " │ │ ├─ Subtensor{:, i} [id F] 0\n", + " │ │ │ ├─ [[1. 0. 0. ... 0. 0. 1.]] [id G]\n", + " │ │ │ └─ -1 [id H]\n", + " │ │ └─ [0.2552323 ... .18013059] [id I]\n", + " │ └─ [[1. 0. 0. ... 0. 0. 1.]] [id G]\n", + " └─ mu_param [id J]\n", + "MarginalLaplaceRV{inline=False}.0 [id K] 'x' 4\n", + " ├─ ModelFreeRV{transform=None} [id A] 'mu_param' 3\n", + " │ └─ ···\n", + " ├─ RNG() [id L]\n", + " └─ RNG() [id M]\n", + "ModelObservedRV{transform=None} [id N] 'y' 5\n", + " ├─ MarginalLaplaceRV{inline=False}.1 [id K] 'y' 4\n", + " │ └─ ···\n", + " └─ y{[[ 0.64235 ... 56333986]]} [id O]\n", + "\n", + "Inner graphs:\n", + "\n", + "MarginalLaplaceRV{inline=False} [id K]\n", + " ← MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id P] 'x'\n", + " ├─ *2- [id Q]\n", + " ├─ NoneConst{None} [id D]\n", + " ├─ Second [id R]\n", + " │ ├─ Subtensor{:, i} [id S]\n", + " │ │ ├─ Blockwise{MatrixInverse, (m,m)->(m,m)} [id T]\n", + " │ │ │ └─ [[0.081594 ... 59856801]] [id U]\n", + " │ │ └─ -1 [id V]\n", + " │ └─ *0- [id W]\n", + " └─ Blockwise{MatrixInverse, (m,m)->(m,m)} [id T]\n", + " └─ ···\n", + " ← MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id X] 'y'\n", + " ├─ *1- [id Y]\n", + " ├─ [10000] [id Z]\n", + " ├─ ExpandDims{axis=0} [id BA]\n", + " │ └─ Second [id BB]\n", + " │ ├─ Subtensor{:, i} [id BC]\n", + " │ │ ├─ [[0.854741 ... 27377318]] [id BD]\n", + " │ │ └─ -1 [id BE]\n", + " │ └─ MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.1 [id P] 'x'\n", + " │ └─ ···\n", + " └─ ExpandDims{axis=0} [id BF]\n", + " └─ [[0.854741 ... 
27377318]] [id BD]\n", + " ← MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0 [id X]\n", + " └─ ···\n", + " ← MvNormalRV{name='multivariate_normal', signature='(n),(n,n)->(n)', dtype='float64', inplace=False, method='cholesky'}.0 [id P]\n", + " └─ ···\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "cluster10\n", + "\n", + "10\n", + "\n", + "\n", + "cluster10000 x 10\n", + "\n", + "10000 x 10\n", + "\n", + "\n", + "\n", + "mu_param\n", + "\n", + "mu_param\n", + "~\n", + "MvNormal\n", + "\n", + "\n", + "\n", + "y\n", + "\n", + "y\n", + "~\n", + "MarginalLaplace\n", + "\n", + "\n", + "\n", + "mu_param->y\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output_rvs = [rv_to_marginalize, *dependent_rvs]\n", + "rng_updates = collect_default_updates(output_rvs, inputs=input_rvs, must_be_shared=False)\n", + "outputs = output_rvs + list(rng_updates.values())\n", + "inputs = input_rvs + list(rng_updates.keys())\n", + "# Add any other shared variable inputs\n", + "inputs += collect_shared_vars(output_rvs, blockers=inputs)\n", + "\n", + "inner_inputs = [inp.clone() for inp in inputs]\n", + "inner_outputs = clone_replace(outputs, replace=dict(zip(inputs, inner_inputs)))\n", + "inner_outputs = remove_model_vars(inner_outputs)\n", + "\n", + "marginalize_constructor = MarginalLaplaceRV\n", + "\n", + "_, _, *dims = rv_to_marginalize.owner.inputs\n", + "marginalization_op = marginalize_constructor(\n", + " inputs=inner_inputs,\n", + " outputs=inner_outputs,\n", + " dims_connections=[\n", + " (None,),\n", + " ], # dependent_rvs_dim_connections, # TODO NOT SURE WHAT THIS IS\n", + " dims=dims,\n", + " # x0=x0,\n", + " # marginalized_rv_input_rvs=marginalized_rv_input_rvs\n", + ")\n", + "\n", + "new_outputs = marginalization_op(*inputs)\n", + "for old_output, new_output in zip(outputs, new_outputs):\n", + " new_output.name = old_output.name\n", + "\n", + "model_replacements = []\n", + "for old_output, new_output in zip(outputs, new_outputs):\n", + " if old_output is rv_to_marginalize or not isinstance(old_output.owner.op, ModelValuedVar):\n", + " # Replace the marginalized ModelFreeRV (or non model-variables) themselves\n", + " var_to_replace = old_output\n", + " else:\n", + " # Replace the underlying RV, keeping the same value, transform and dims\n", + " var_to_replace = old_output.owner.inputs[0]\n", + " model_replacements.append((var_to_replace, new_output))\n", + "\n", + "print(inner_inputs)\n", + "print(inner_outputs)\n", + "print(inputs)\n", + "\n", + "print(outputs)\n", + "print(new_outputs)\n", + "\n", + "fg.replace_all(model_replacements)\n", + "fg.dprint()\n", + "\n", + "model_marg = model_from_fgraph(fg, mutate_fgraph=True)\n", + "pm.model_to_graphviz(model_marg)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "id": "793248a1-8088-41fc-9dbd-c58e596e0df7", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Only tensors with the same number of dimensions can be joined. 
Input ndims were: [3, 2]", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mTypeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[93]\u001b[39m\u001b[32m, line 8\u001b[39m\n\u001b[32m 5\u001b[39m b = pt.vector(\u001b[33m'\u001b[39m\u001b[33mb\u001b[39m\u001b[33m'\u001b[39m, shape=(\u001b[32m3\u001b[39m,))\n\u001b[32m 7\u001b[39m eqns = pt.stack([A @ x - b])\n\u001b[32m----> \u001b[39m\u001b[32m8\u001b[39m var = \u001b[43mpt\u001b[49m\u001b[43m.\u001b[49m\u001b[43mstack\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mA\u001b[49m\u001b[43m,\u001b[49m\u001b[43mb\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 10\u001b[39m soln, _ = root(eqns, variables=var)\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pytensor/pytensor/tensor/basic.py:2977\u001b[39m, in \u001b[36mstack\u001b[39m\u001b[34m(tensors, axis)\u001b[39m\n\u001b[32m 2973\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mconcatenate\u001b[39m(tensor_list, axis=\u001b[32m0\u001b[39m):\n\u001b[32m 2974\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33;03m\"\"\"Alias for `join`(axis, *tensor_list).\u001b[39;00m\n\u001b[32m 2975\u001b[39m \n\u001b[32m 2976\u001b[39m \u001b[33;03m This function is similar to `join`, but uses the signature of\u001b[39;00m\n\u001b[32m-> \u001b[39m\u001b[32m2977\u001b[39m \u001b[33;03m numpy's concatenate function.\u001b[39;00m\n\u001b[32m 2978\u001b[39m \n\u001b[32m 2979\u001b[39m \u001b[33;03m Raises\u001b[39;00m\n\u001b[32m 2980\u001b[39m \u001b[33;03m ------\u001b[39;00m\n\u001b[32m 2981\u001b[39m \u001b[33;03m TypeError\u001b[39;00m\n\u001b[32m 2982\u001b[39m \u001b[33;03m The tensor_list must be a tuple or list.\u001b[39;00m\n\u001b[32m 2983\u001b[39m \n\u001b[32m 2984\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m 2985\u001b[39m \u001b[38;5;66;03m# Check someone did not make the common mistake to do something like:\u001b[39;00m\n\u001b[32m 2986\u001b[39m \u001b[38;5;66;03m# c = concatenate(x, y)\u001b[39;00m\n\u001b[32m 2987\u001b[39m \u001b[38;5;66;03m# instead of\u001b[39;00m\n\u001b[32m 2988\u001b[39m \u001b[38;5;66;03m# c = concatenate((x, y))\u001b[39;00m\n\u001b[32m 2989\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(tensor_list, \u001b[38;5;28mtuple\u001b[39m | \u001b[38;5;28mlist\u001b[39m):\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pytensor/pytensor/tensor/basic.py:2817\u001b[39m, in \u001b[36mjoin\u001b[39m\u001b[34m(axis, *tensors_list)\u001b[39m\n\u001b[32m 2815\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m tensors_list[\u001b[32m0\u001b[39m]\n\u001b[32m 2816\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m2817\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_join\u001b[49m\u001b[43m(\u001b[49m\u001b[43maxis\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43mtensors_list\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pytensor/pytensor/graph/op.py:293\u001b[39m, in \u001b[36mOp.__call__\u001b[39m\u001b[34m(self, name, return_list, *inputs, **kwargs)\u001b[39m\n\u001b[32m 249\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34m__call__\u001b[39m(\n\u001b[32m 250\u001b[39m \u001b[38;5;28mself\u001b[39m, *inputs: Any, 
name=\u001b[38;5;28;01mNone\u001b[39;00m, return_list=\u001b[38;5;28;01mFalse\u001b[39;00m, **kwargs\n\u001b[32m 251\u001b[39m ) -> Variable | \u001b[38;5;28mlist\u001b[39m[Variable]:\n\u001b[32m 252\u001b[39m \u001b[38;5;250m \u001b[39m\u001b[33mr\u001b[39m\u001b[33;03m\"\"\"Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.\u001b[39;00m\n\u001b[32m 253\u001b[39m \n\u001b[32m 254\u001b[39m \u001b[33;03m This method is just a wrapper around :meth:`Op.make_node`.\u001b[39;00m\n\u001b[32m (...)\u001b[39m\u001b[32m 291\u001b[39m \n\u001b[32m 292\u001b[39m \u001b[33;03m \"\"\"\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m293\u001b[39m node = \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmake_node\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 294\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[32m 295\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(node.outputs) == \u001b[32m1\u001b[39m:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pytensor/pytensor/tensor/basic.py:2510\u001b[39m, in \u001b[36mJoin.make_node\u001b[39m\u001b[34m(self, axis, *tensors)\u001b[39m\n\u001b[32m 2507\u001b[39m ndim = tensors[\u001b[32m0\u001b[39m].type.ndim\n\u001b[32m 2509\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m builtins.all(x.ndim == ndim \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m tensors):\n\u001b[32m-> \u001b[39m\u001b[32m2510\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\n\u001b[32m 2511\u001b[39m \u001b[33m\"\u001b[39m\u001b[33mOnly tensors with the same number of dimensions can be joined. \u001b[39m\u001b[33m\"\u001b[39m\n\u001b[32m 2512\u001b[39m \u001b[33mf\u001b[39m\u001b[33m\"\u001b[39m\u001b[33mInput ndims were: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m[x.ndim\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mfor\u001b[39;00m\u001b[38;5;250m \u001b[39mx\u001b[38;5;250m \u001b[39m\u001b[38;5;129;01min\u001b[39;00m\u001b[38;5;250m \u001b[39mtensors]\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m\"\u001b[39m\n\u001b[32m 2513\u001b[39m )\n\u001b[32m 2515\u001b[39m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[32m 2516\u001b[39m static_axis = \u001b[38;5;28mint\u001b[39m(get_scalar_constant_value(axis))\n", + "\u001b[31mTypeError\u001b[39m: Only tensors with the same number of dimensions can be joined. 
Input ndims were: [3, 2]" + ] + } + ], + "source": [ + "from pytensor.tensor.optimize import root\n", + "\n", + "A = pt.matrix(\"A\", shape=(3, 3))\n", + "x = np.ones((3, 1))\n", + "b = pt.vector(\"b\", shape=(3,))\n", + "\n", + "eqns = pt.stack([A @ x - b])\n", + "var = pt.stack([A, b])\n", + "\n", + "soln, _ = root(eqns, variables=var)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "07f8abf3-6158-4d62-83d2-2cffe65aae91", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(10,)\n", + "(10, 1, 1)\n" + ] + }, + { + "data": { + "text/plain": [ + "Reshape{4}.0" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from pytensor.tensor.math import tensordot\n", + "\n", + "a = pt.tensor(\"a\", shape=(10,))\n", + "b = pt.tensor(\"b\", shape=(10, 1, 1))\n", + "\n", + "print(a.type.shape)\n", + "print(b.type.shape)\n", + "# print(b.T.type.shape)\n", + "\n", + "tensordot(a, b, axes=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "2eee8e73-5305-4472-8f9c-58a689f9471e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "ename": "AttributeError", + "evalue": "'MarginalLaplaceRV' object has no attribute 'owner'", + "output_type": "error", + "traceback": [ + "\u001b[31m---------------------------------------------------------------------------\u001b[39m", + "\u001b[31mAttributeError\u001b[39m Traceback (most recent call last)", + "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[31]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m model_marg \u001b[38;5;28;01mas\u001b[39;00m m:\n\u001b[32m----> \u001b[39m\u001b[32m2\u001b[39m \u001b[43mpm\u001b[49m\u001b[43m.\u001b[49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/sampling/mcmc.py:783\u001b[39m, in \u001b[36msample\u001b[39m\u001b[34m(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)\u001b[39m\n\u001b[32m 780\u001b[39m _log.warning(msg)\n\u001b[32m 782\u001b[39m provided_steps, selected_steps = assign_step_methods(model, step, methods=pm.STEP_METHODS)\n\u001b[32m--> \u001b[39m\u001b[32m783\u001b[39m exclusive_nuts = (\n\u001b[32m 784\u001b[39m \u001b[38;5;66;03m# User provided an instantiated NUTS step, and nothing else is needed\u001b[39;00m\n\u001b[32m 785\u001b[39m (\u001b[38;5;129;01mnot\u001b[39;00m selected_steps \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(provided_steps) == \u001b[32m1\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(provided_steps[\u001b[32m0\u001b[39m], NUTS))\n\u001b[32m 786\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[32m 787\u001b[39m \u001b[38;5;66;03m# Only automatically selected NUTS step is needed\u001b[39;00m\n\u001b[32m 788\u001b[39m (\n\u001b[32m 789\u001b[39m \u001b[38;5;129;01mnot\u001b[39;00m provided_steps\n\u001b[32m 790\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(selected_steps) == \u001b[32m1\u001b[39m\n\u001b[32m 791\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m 
\u001b[38;5;28missubclass\u001b[39m(\u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28miter\u001b[39m(selected_steps)), NUTS)\n\u001b[32m 792\u001b[39m )\n\u001b[32m 793\u001b[39m )\n\u001b[32m 795\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m nuts_sampler != \u001b[33m\"\u001b[39m\u001b[33mpymc\u001b[39m\u001b[33m\"\u001b[39m:\n\u001b[32m 796\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m exclusive_nuts:\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/sampling/mcmc.py:245\u001b[39m, in \u001b[36massign_step_methods\u001b[39m\u001b[34m(model, step, methods)\u001b[39m\n\u001b[32m 243\u001b[39m methods_list: \u001b[38;5;28mlist\u001b[39m[\u001b[38;5;28mtype\u001b[39m[BlockedStep]] = \u001b[38;5;28mlist\u001b[39m(methods \u001b[38;5;129;01mor\u001b[39;00m pm.STEP_METHODS)\n\u001b[32m 244\u001b[39m selected_steps: \u001b[38;5;28mdict\u001b[39m[\u001b[38;5;28mtype\u001b[39m[BlockedStep], \u001b[38;5;28mlist\u001b[39m] = {}\n\u001b[32m--> \u001b[39m\u001b[32m245\u001b[39m model_logp = \u001b[43mmodel\u001b[49m\u001b[43m.\u001b[49m\u001b[43mlogp\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 247\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m var \u001b[38;5;129;01min\u001b[39;00m model.value_vars:\n\u001b[32m 248\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m var \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;129;01min\u001b[39;00m assigned_vars:\n\u001b[32m 249\u001b[39m \u001b[38;5;66;03m# determine if a gradient can be computed\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/model/core.py:691\u001b[39m, in \u001b[36mModel.logp\u001b[39m\u001b[34m(self, vars, jacobian, sum)\u001b[39m\n\u001b[32m 689\u001b[39m rv_logps: \u001b[38;5;28mlist\u001b[39m[TensorVariable] = []\n\u001b[32m 690\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m rvs:\n\u001b[32m--> \u001b[39m\u001b[32m691\u001b[39m rv_logps = \u001b[43mtransformed_conditional_logp\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 692\u001b[39m \u001b[43m \u001b[49m\u001b[43mrvs\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrvs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 693\u001b[39m \u001b[43m \u001b[49m\u001b[43mrvs_to_values\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mrvs_to_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 694\u001b[39m \u001b[43m \u001b[49m\u001b[43mrvs_to_transforms\u001b[49m\u001b[43m=\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mrvs_to_transforms\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 695\u001b[39m \u001b[43m \u001b[49m\u001b[43mjacobian\u001b[49m\u001b[43m=\u001b[49m\u001b[43mjacobian\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 696\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 697\u001b[39m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(rv_logps, \u001b[38;5;28mlist\u001b[39m)\n\u001b[32m 699\u001b[39m \u001b[38;5;66;03m# Replace random variables by their value variables in potential terms\u001b[39;00m\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/logprob/basic.py:570\u001b[39m, in \u001b[36mtransformed_conditional_logp\u001b[39m\u001b[34m(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)\u001b[39m\n\u001b[32m 567\u001b[39m transform_rewrite = TransformValuesRewrite(values_to_transforms) \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[32m 569\u001b[39m kwargs.setdefault(\u001b[33m\"\u001b[39m\u001b[33mwarn_rvs\u001b[39m\u001b[33m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[32m--> 
\u001b[39m\u001b[32m570\u001b[39m temp_logp_terms = \u001b[43mconditional_logp\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 571\u001b[39m \u001b[43m \u001b[49m\u001b[43mrvs_to_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 572\u001b[39m \u001b[43m \u001b[49m\u001b[43mextra_rewrites\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtransform_rewrite\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 573\u001b[39m \u001b[43m \u001b[49m\u001b[43muse_jacobian\u001b[49m\u001b[43m=\u001b[49m\u001b[43mjacobian\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 574\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 575\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 577\u001b[39m \u001b[38;5;66;03m# The function returns the logp for every single value term we provided to it.\u001b[39;00m\n\u001b[32m 578\u001b[39m \u001b[38;5;66;03m# This includes the extra values we plugged in above, so we filter those we\u001b[39;00m\n\u001b[32m 579\u001b[39m \u001b[38;5;66;03m# actually wanted in the same order they were given in.\u001b[39;00m\n\u001b[32m 580\u001b[39m logp_terms = {}\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc/pymc/logprob/basic.py:500\u001b[39m, in \u001b[36mconditional_logp\u001b[39m\u001b[34m(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)\u001b[39m\n\u001b[32m 497\u001b[39m node_values = remapped_vars[: \u001b[38;5;28mlen\u001b[39m(node_values)]\n\u001b[32m 498\u001b[39m node_inputs = remapped_vars[\u001b[38;5;28mlen\u001b[39m(node_values) :]\n\u001b[32m--> \u001b[39m\u001b[32m500\u001b[39m node_logprobs = \u001b[43m_logprob\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 501\u001b[39m \u001b[43m \u001b[49m\u001b[43mnode\u001b[49m\u001b[43m.\u001b[49m\u001b[43mop\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 502\u001b[39m \u001b[43m \u001b[49m\u001b[43mnode_values\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 503\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43mnode_inputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 504\u001b[39m \u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 505\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 507\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(node_logprobs, \u001b[38;5;28mlist\u001b[39m | \u001b[38;5;28mtuple\u001b[39m):\n\u001b[32m 508\u001b[39m node_logprobs = [node_logprobs]\n", + "\u001b[36mFile \u001b[39m\u001b[32m~/git/pymc-extras/.pixi/envs/default/lib/python3.12/functools.py:912\u001b[39m, in \u001b[36msingledispatch..wrapper\u001b[39m\u001b[34m(*args, **kw)\u001b[39m\n\u001b[32m 908\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m args:\n\u001b[32m 909\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[33mf\u001b[39m\u001b[33m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfuncname\u001b[38;5;132;01m}\u001b[39;00m\u001b[33m requires at least \u001b[39m\u001b[33m'\u001b[39m\n\u001b[32m 910\u001b[39m \u001b[33m'\u001b[39m\u001b[33m1 positional argument\u001b[39m\u001b[33m'\u001b[39m)\n\u001b[32m--> \u001b[39m\u001b[32m912\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m 
\u001b[43mdispatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m[\u001b[49m\u001b[32;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m.\u001b[49m\u001b[34;43m__class__\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkw\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[36mFile \u001b[39m\u001b[32m:17\u001b[39m, in \u001b[36mlaplace_marginal_rv_logp\u001b[39m\u001b[34m(op, values, *inputs, **kwargs)\u001b[39m\n", + "\u001b[31mAttributeError\u001b[39m: 'MarginalLaplaceRV' object has no attribute 'owner'" + ] + } + ], + "source": [ + "with model_marg as m:\n", + " pm.sample()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "991b7cbd-4ef3-487a-9ece-139226883502", + "metadata": {}, + "outputs": [], + "source": [ + "[2, 2, 3, 4, 2, 2, 2, 3, 3, 4]" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "42203b08-519a-4676-9a99-34a5d92d4c5d", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [mu_param, Q, x]\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c6253c53d3124ca1add318ff3828d5c2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "ename": "ValueError",
+     "evalue": "Not enough samples to build a trace.",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
+      "\u001b[31mValueError\u001b[39m                                Traceback (most recent call last)",
+      "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[40]\u001b[39m\u001b[32m, line 46\u001b[39m\n\u001b[32m     37\u001b[39m     \u001b[38;5;66;03m# model.logp().dprint()\u001b[39;00m\n\u001b[32m     39\u001b[39m     x0, log_likelihood = get_log_marginal_likelihood(\n\u001b[32m     40\u001b[39m         x=model.rvs_to_values[x],\n\u001b[32m     41\u001b[39m         Q=Q.reshape((d, d)),\u001b[38;5;66;03m#Q_val,\u001b[39;00m\n\u001b[32m     42\u001b[39m         mu=mu_param,\n\u001b[32m     43\u001b[39m         optimizer_kwargs={\u001b[33m\"\u001b[39m\u001b[33mtol\u001b[39m\u001b[33m\"\u001b[39m: \u001b[32m1e-8\u001b[39m},\n\u001b[32m     44\u001b[39m     )\n\u001b[32m---> \u001b[39m\u001b[32m46\u001b[39m     \u001b[43mpm\u001b[49m\u001b[43m.\u001b[49m\u001b[43msample\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m     48\u001b[39m     \u001b[38;5;66;03m# print(model.free_RVs)\u001b[39;00m\n\u001b[32m     49\u001b[39m \n\u001b[32m     50\u001b[39m     \u001b[38;5;66;03m# # with pm.Model() as inla_model:\u001b[39;00m\n\u001b[32m   (...)\u001b[39m\u001b[32m     65\u001b[39m \n\u001b[32m     66\u001b[39m \u001b[38;5;66;03m# inla_model = marginalize(model, [mu_param, cov_param])\u001b[39;00m\n",
+      "\u001b[36mFile \u001b[39m\u001b[32m~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:964\u001b[39m, in \u001b[36msample\u001b[39m\u001b[34m(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)\u001b[39m\n\u001b[32m    960\u001b[39m t_sampling = time.time() - t_start\n\u001b[32m    962\u001b[39m \u001b[38;5;66;03m# Packaging, validating and returning the result was extracted\u001b[39;00m\n\u001b[32m    963\u001b[39m \u001b[38;5;66;03m# into a function to make it easier to test and refactor.\u001b[39;00m\n\u001b[32m--> \u001b[39m\u001b[32m964\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_sample_return\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m    965\u001b[39m \u001b[43m    \u001b[49m\u001b[43mrun\u001b[49m\u001b[43m=\u001b[49m\u001b[43mrun\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    966\u001b[39m \u001b[43m    \u001b[49m\u001b[43mtraces\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;28;43misinstance\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mZarrTrace\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01melse\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mtraces\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    967\u001b[39m \u001b[43m    \u001b[49m\u001b[43mtune\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtune\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    968\u001b[39m \u001b[43m    \u001b[49m\u001b[43mt_sampling\u001b[49m\u001b[43m=\u001b[49m\u001b[43mt_sampling\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    969\u001b[39m \u001b[43m    \u001b[49m\u001b[43mdiscard_tuned_samples\u001b[49m\u001b[43m=\u001b[49m\u001b[43mdiscard_tuned_samples\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    970\u001b[39m \u001b[43m    \u001b[49m\u001b[43mcompute_convergence_checks\u001b[49m\u001b[43m=\u001b[49m\u001b[43mcompute_convergence_checks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    971\u001b[39m \u001b[43m    \u001b[49m\u001b[43mreturn_inferencedata\u001b[49m\u001b[43m=\u001b[49m\u001b[43mreturn_inferencedata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    972\u001b[39m \u001b[43m    \u001b[49m\u001b[43mkeep_warning_stat\u001b[49m\u001b[43m=\u001b[49m\u001b[43mkeep_warning_stat\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    973\u001b[39m \u001b[43m    \u001b[49m\u001b[43midata_kwargs\u001b[49m\u001b[43m=\u001b[49m\u001b[43midata_kwargs\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43m{\u001b[49m\u001b[43m}\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    974\u001b[39m \u001b[43m    \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m    975\u001b[39m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
+      "\u001b[36mFile \u001b[39m\u001b[32m~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1049\u001b[39m, in \u001b[36m_sample_return\u001b[39m\u001b[34m(run, traces, tune, t_sampling, discard_tuned_samples, compute_convergence_checks, return_inferencedata, keep_warning_stat, idata_kwargs, model)\u001b[39m\n\u001b[32m   1047\u001b[39m \u001b[38;5;66;03m# Pick and slice chains to keep the maximum number of samples\u001b[39;00m\n\u001b[32m   1048\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m discard_tuned_samples:\n\u001b[32m-> \u001b[39m\u001b[32m1049\u001b[39m     traces, length = \u001b[43m_choose_chains\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtraces\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtune\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m   1050\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m   1051\u001b[39m     traces, length = _choose_chains(traces, \u001b[32m0\u001b[39m)\n",
+      "\u001b[36mFile \u001b[39m\u001b[32m~/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/pymc/backends/base.py:624\u001b[39m, in \u001b[36m_choose_chains\u001b[39m\u001b[34m(traces, tune)\u001b[39m\n\u001b[32m    622\u001b[39m lengths = [\u001b[38;5;28mmax\u001b[39m(\u001b[32m0\u001b[39m, \u001b[38;5;28mlen\u001b[39m(trace) - tune) \u001b[38;5;28;01mfor\u001b[39;00m trace \u001b[38;5;129;01min\u001b[39;00m traces]\n\u001b[32m    623\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28msum\u001b[39m(lengths):\n\u001b[32m--> \u001b[39m\u001b[32m624\u001b[39m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[33m\"\u001b[39m\u001b[33mNot enough samples to build a trace.\u001b[39m\u001b[33m\"\u001b[39m)\n\u001b[32m    626\u001b[39m idxs = np.argsort(lengths)\n\u001b[32m    627\u001b[39m l_sort = np.array(lengths)[idxs]\n",
+      "\u001b[31mValueError\u001b[39m: Not enough samples to build a trace."
+     ]
+    }
+   ],
+   "source": [
+    "rng = np.random.default_rng(12345)\n",
+    "n = 10000\n",
+    "d = 3\n",
+    "\n",
+    "# Initialise arrays\n",
+    "mu_true = rng.random(d)\n",
+    "cov_true = np.diag(rng.random(d))\n",
+    "Q_val = np.diag(rng.random(d))\n",
+    "cov_param_val = np.diag(rng.random(d))\n",
+    "\n",
+    "x_val = rng.random(d)\n",
+    "mu_val = rng.random(d)\n",
+    "\n",
+    "mu_mu = rng.random(d)\n",
+    "mu_cov = np.diag(np.ones(d))\n",
+    "cov_mu = rng.random(d**2)\n",
+    "cov_cov = np.diag(np.ones(d**2))\n",
+    "Q_mu = rng.random(d**2)\n",
+    "Q_cov = np.diag(np.ones(d**2))\n",
+    "\n",
+    "with pm.Model() as model:\n",
+    "    y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n)\n",
+    "\n",
+    "    mu_param = pm.MvNormal(\"mu_param\", mu=mu_mu, cov=mu_cov)\n",
+    "    # cov_param = np.abs(pm.MvNormal(\"cov_param\", mu=cov_mu, cov=cov_cov))\n",
+    "    # Q = pm.MvNormal(\"Q\", mu=Q_mu, cov=Q_cov)\n",
+    "\n",
+    "    x = pm.MvNormal(\"x\", mu=mu_param, tau=Q_val)\n",
+    "\n",
+    "    y = pm.MvNormal(\n",
+    "        \"y\",\n",
+    "        mu=x,\n",
+    "        cov=cov_param_val,  # cov_param.reshape((d, d)),\n",
+    "        observed=y_obs,\n",
+    "    )\n",
+    "\n",
+    "    # model.logp().dprint()\n",
+    "\n",
+    "    # x0, log_likelihood = get_log_marginal_likelihood(\n",
+    "    #     x=model.rvs_to_values[x],\n",
+    "    #     Q=Q_val,#Q.reshape((d, d)),\n",
+    "    #     mu=mu_param,\n",
+    "    #     optimizer_kwargs={\"tol\": 1e-8},\n",
+    "    # )\n",
+    "\n",
+    "    # print(model.free_RVs)\n",
+    "\n",
+    "    # # with pm.Model() as inla_model:\n",
+    "    # log_prior = pm.logp(mu_param, mu_mu)\n",
+    "    # log_posterior = log_likelihood + log_prior\n",
+    "    # # # log_posterior.dprint()\n",
+    "    # # # log_posterior = pytensor.graph.replace.graph_replace(log_posterior, {model.rvs_to_values[x]: x0})\n",
+    "    # # # log_posterior_model = model_from_fgraph(log_posterior, mutate_fgraph=True)\n",
+    "    # # idata = pm.sample()\n",
+    "    # # draws = pm.draw(mu_param)\n",
+    "    # idata = pm.sample_prior_predictive()\n",
+    "    # print(idata.prior)\n",
+    "    # print(draws)\n",
+    "\n",
+    "    # print(inla_model.free_RVs)\n",
+    "    # log_posterior.dprint()\n",
+    "    # pytensor.graph.fg.FunctionGraph(inputs=[model.rvs_to_values[mu_param], model.rvs_to_values[x]], outputs=[log_posterior])\n",
+    "\n",
+    "# inla_model = marginalize(model, [mu_param, cov_param])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "681e14ea-3629-4cc5-bb7f-e08cad5df276",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e9c829f7-cf06-4402-b909-2a27d0dea07a",
+   "metadata": {},
+   "source": [
+    "True dataset:\n",
+    "\n",
+    "$y \\sim N(\\mu_{true}, \\Sigma_{true})$\n",
+    "\n",
+    "Model:\n",
+    "\n",
+    "$y|x, \\sigma \\sim N(Ax, \\sigma W)$\n",
+    "\n",
+    "Let $A=I$, $W=I$:\n",
+    "\n",
+    "$y|x, \\sigma \\sim N(x, \\sigma)$\n",
+    "\n",
+    "Comparing model and true data:\n",
+    "\n",
+    "$x = \\mu_{true}$\n",
+    "\n",
+    "$x|\\theta \\sim N(\\mu, Q^{-1})$\n",
+    "\n",
+    "$\\theta = (\\mu, \\Sigma_b, \\sigma)$\n",
+    "\n",
+    "Set $Q = I$ for now.\n",
+    "\n",
+    "$\\theta = (\\mu, \\sigma)$"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "id": "e344b7d0-f76e-4a28-9be9-884a2ba1f2c4",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "[autoreload of cutils_ext failed: Traceback (most recent call last):\n",
+      "  File \"/home/michaln/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/IPython/extensions/autoreload.py\", line 283, in check\n",
+      "    superreload(m, reload, self.old_objects)\n",
+      "  File \"/home/michaln/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/site-packages/IPython/extensions/autoreload.py\", line 483, in superreload\n",
+      "    module = reload(module)\n",
+      "             ^^^^^^^^^^^^^^\n",
+      "  File \"/home/michaln/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/lib/python3.12/importlib/__init__.py\", line 130, in reload\n",
+      "    raise ModuleNotFoundError(f\"spec not found for the module {name!r}\", name=name)\n",
+      "ModuleNotFoundError: spec not found for the module 'cutils_ext'\n",
+      "]\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\u001b[1m============================= test session starts ==============================\u001b[0m\n",
+      "platform linux -- Python 3.12.10, pytest-8.4.1, pluggy-1.6.0 -- /home/michaln/Michal_Linux/git/GSoC/pymc-extras/.pixi/envs/default/bin/python\n",
+      "cachedir: .pytest_cache\n",
+      "rootdir: /home/michaln/Michal_Linux/git/GSoC/pymc-extras\n",
+      "configfile: pyproject.toml\n",
+      "plugins: anyio-4.9.0\n",
+      "collected 1 item                                                               \u001b[0m\u001b[1m\n",
+      "\n",
+      "../tests/test_inla.py::test_get_conditional_gaussian_approximation \u001b[32mPASSED\u001b[0m\u001b[32m [100%]\u001b[0m\n",
+      "\n",
+      "\u001b[32m============================== \u001b[32m\u001b[1m1 passed\u001b[0m\u001b[32m in 5.99s\u001b[0m\u001b[32m ===============================\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "!python -m pytest -v /home/michaln/Michal_Linux/git/GSoC/pymc-extras/tests/test_inla.py"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "f1eb9a04-15dd-437d-bcf0-ce369feec912",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.11"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pymc_extras/__init__.py b/pymc_extras/__init__.py
index cee0ffeb..6566c337 100644
--- a/pymc_extras/__init__.py
+++ b/pymc_extras/__init__.py
@@ -17,7 +17,7 @@
 
 from pymc_extras import gp, statespace, utils
 from pymc_extras.distributions import *
-from pymc_extras.inference import find_MAP, fit, fit_laplace, fit_pathfinder
+from pymc_extras.inference import find_MAP, fit, fit_laplace  # , fit_pathfinder
 from pymc_extras.model.marginal.marginal_model import (
     MarginalModel,
     marginalize,
diff --git a/pymc_extras/inference/__init__.py b/pymc_extras/inference/__init__.py
index a01fdd5c..05041762 100644
--- a/pymc_extras/inference/__init__.py
+++ b/pymc_extras/inference/__init__.py
@@ -14,7 +14,9 @@
 
 from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
+from pymc_extras.inference.inla import fit_INLA
 from pymc_extras.inference.laplace import fit_laplace
-from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
 
-__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
+# from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
+
+__all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP", "fit_INLA"]
diff --git a/pymc_extras/inference/fit.py b/pymc_extras/inference/fit.py
index 5b83ff1f..8814ba3b 100644
--- a/pymc_extras/inference/fit.py
+++ b/pymc_extras/inference/fit.py
@@ -36,7 +36,17 @@ def fit(method: str, **kwargs) -> az.InferenceData:
 
         return fit_pathfinder(**kwargs)
 
-    if method == "laplace":
+    elif method == "laplace":
         from pymc_extras.inference.laplace import fit_laplace
 
         return fit_laplace(**kwargs)
+
+    elif method == "INLA":
+        from pymc_extras.inference.inla import fit_INLA
+
+        return fit_INLA(**kwargs)
+
+    else:
+        raise ValueError(
+            f"method '{method}' not supported. Use one of 'pathfinder', 'laplace' or 'INLA'."
+        )
diff --git a/pymc_extras/inference/inla.py b/pymc_extras/inference/inla.py
new file mode 100644
index 00000000..3f6db2cf
--- /dev/null
+++ b/pymc_extras/inference/inla.py
@@ -0,0 +1,170 @@
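+"""Prototype Integrated Nested Laplace Approximation (INLA) routines for pymc-extras."""
+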
+import arviz as az
+import numpy as np
+import pymc as pm
+import pytensor
+import pytensor.tensor as pt
+
+from better_optimize.constants import minimize_method
+from numpy.typing import ArrayLike
+from pytensor.tensor import TensorVariable
+from pytensor.tensor.optimize import minimize
+
+
+def get_conditional_gaussian_approximation(
+    x: TensorVariable,
+    Q: TensorVariable | ArrayLike,
+    mu: TensorVariable | ArrayLike,
+    model: pm.Model | None = None,
+    method: minimize_method = "BFGS",
+    use_jac: bool = True,
+    use_hess: bool = False,
+    optimizer_kwargs: dict | None = None,
+) -> tuple[TensorVariable, TensorVariable]:
+    """
+    Returns an estimate of the a posteriori probability of a latent Gaussian field x and its mode x0, using the Laplace approximation.
+
+    That is:
+    y | x, sigma ~ N(Ax, sigma^2 W)
+    x | params ~ N(mu, Q(params)^-1)
+
+    We seek to estimate p(x | y, params) with a Gaussian:
+
+    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
+
+    Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).
+
+    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.
+
+    Thus:
+
+    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.
+
+    2. Use the Laplace approximation expanded about the mode: p(x | y, params) ~= N(mu=x0, tau=Q - f''(x0)).
+
+    Parameters
+    ----------
+    x: TensorVariable
+        The variable to maximize with respect to (that is, the variable in which to find the mode). In INLA, this is the latent Gaussian field x ~ N(mu, Q^-1).
+    Q: TensorVariable | ArrayLike
+        The precision matrix of the latent field x.
+    mu: TensorVariable | ArrayLike
+        The mean of the latent field x.
+    model: Model
+        PyMC model to use.
+    method: minimize_method
+        Which minimization algorithm to use.
+    use_jac: bool
+        If true, the minimizer will compute the gradient of log(p(x | y, params)).
+    use_hess: bool
+        If true, the minimizer will compute the Hessian of log(p(x | y, params)).
+    optimizer_kwargs: dict
+        Kwargs to pass to scipy.optimize.minimize.
+
+    Returns
+    -------
+    x0, log_laplace_approx: tuple[TensorVariable, TensorVariable]
+        The mode x0 of the latent field, and the log density of the Laplace-approximate posterior evaluated at that mode.
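+
+    Examples
+    --------
+    A minimal sketch of the intended usage, mirroring the testing notebook in this
+    PR (the concrete values here are illustrative assumptions, not part of the API):
+
+    .. code-block:: python
+
+        import numpy as np
+        import pymc as pm
+
+        d = 3
+        Q_val = np.eye(d)  # precision of the latent field
+        mu_val = np.zeros(d)  # mean of the latent field
+
+        with pm.Model() as model:
+            x = pm.MvNormal("x", mu=mu_val, tau=Q_val)
+            y = pm.MvNormal(
+                "y",
+                mu=x,
+                cov=np.eye(d),
+                observed=np.random.default_rng(0).normal(size=(100, d)),
+            )
+
+            # Mode, and log Laplace-approximate posterior density at the mode
+            x0, log_laplace_approx = get_conditional_gaussian_approximation(
+                x=model.rvs_to_values[x],
+                Q=Q_val,
+                mu=mu_val,
+                optimizer_kwargs={"tol": 1e-8},
+            )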
+    """
+    model = pm.modelcontext(model)
+
+    # f = log(p(y | x, params))
+    f_x = model.logp()
+
+    # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
+    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
+
+    # Maximize log(p(x | y, params)) wrt x to find mode x0
+    x0, _ = minimize(
+        objective=-log_x_posterior,
+        x=x,
+        method=method,
+        jac=use_jac,
+        hess=use_hess,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    # require f''(x0) for Laplace approx
+    hess = pytensor.gradient.hessian(f_x, x)
+    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
+
+    # Could be made more efficient by adding only the diagonal terms
+    tau = Q - hess
+
+    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query
+    # point is far from the mode x0, or lies in a neighbourhood where the optimizer converges poorly.
+    _, logdetTau = pt.nlinalg.slogdet(tau)
+    return x0, 0.5 * logdetTau - 0.5 * x0.shape[0] * np.log(2 * np.pi)
+
+
+def get_log_marginal_likelihood(
+    x: TensorVariable,
+    Q: TensorVariable | ArrayLike,
+    mu: TensorVariable | ArrayLike,
+    model: pm.Model | None = None,
+    method: minimize_method = "BFGS",
+    use_jac: bool = True,
+    use_hess: bool = False,
+    optimizer_kwargs: dict | None = None,
+) -> tuple[TensorVariable, TensorVariable]:
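+    """
+    Returns the mode x0 of the latent field x and a Laplace estimate of the log marginal likelihood logp(y | params).
+
+    Uses the identity
+
+        logp(y | params) = logp(y | x, params) + logp(x | params) - logp(x | y, params),
+
+    where the intractable posterior term logp(x | y, params) is replaced by its Laplace
+    approximation from get_conditional_gaussian_approximation, and the x-dependent
+    terms are taken at (or substituted with) the mode x0.
+
+    Parameters are as for get_conditional_gaussian_approximation.
+    """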
+    model = pm.modelcontext(model)
+
+    x0, log_laplace_approx = get_conditional_gaussian_approximation(
+        x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
+    )
+    # log_laplace_approx = pm.logp(laplace_approx, x)#model.rvs_to_values[x])
+
+    _, logdetQ = pt.nlinalg.slogdet(Q)
+    # log_x_likelihood = (
+    #     -0.5 * (x - mu).T @ Q @ (x - mu) + 0.5 * logdetQ - 0.5 * x.shape[0] * np.log(2 * np.pi)
+    # )
+    log_x_likelihood = (
+        -0.5 * (x0 - mu).T @ Q @ (x0 - mu) + 0.5 * logdetQ - 0.5 * x0.shape[0] * np.log(2 * np.pi)
+    )
+
+    log_likelihood = (  # logp(y | params) =
+        model.logp()  # logp(y | x, params)
+        + log_x_likelihood  # * logp(x | params)
+        - log_laplace_approx  # / logp(x | y, params)
+    )
+
+    return x0, log_likelihood
+
+
+def fit_INLA(
+    x: TensorVariable,
+    Q: TensorVariable | ArrayLike,
+    mu: TensorVariable | ArrayLike,
+    model: pm.Model | None = None,
+    method: minimize_method = "BFGS",
+    use_jac: bool = True,
+    use_hess: bool = False,
+    optimizer_kwargs: dict | None = None,
+) -> az.InferenceData:
+    model = pm.modelcontext(model)
+
+    # logp(y | params)
+    x0, log_likelihood = get_log_marginal_likelihood(
+        x, Q, mu, model, method, use_jac, use_hess, optimizer_kwargs
+    )
+
+    # TODO: How to obtain the prior? It can parametrise Q, mu, y, etc. Not sure if we could extract it from model.logp somehow; otherwise simply take it as a user input.
+    # Perhaps obtain it as the RVs which y depends on, other than x?
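+    # One possible sketch (an untested assumption, not settled design): treat every
+    # free RV other than the latent field as a hyperparameter, e.g.
+    #     params = [rv for rv in model.free_RVs if rv is not model.values_to_rvs[x]]
+    #     log_prior = sum(pm.logp(rv, model.rvs_to_values[rv]).sum() for rv in params)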
+    prior = None
+    params = None
+    log_prior = pm.logp(prior, model.rvs_to_values[params])
+
+    # logp(params | y) = logp(y | params) + logp(params) + const
+    log_posterior = log_likelihood + log_prior
+    log_posterior = pytensor.graph.replace.graph_replace(log_posterior, {x: x0})
+
+    # TODO log_marginal_x_likelihood is almost the same as log_likelihood, but need to do some sampling?
+    log_marginal_x_likelihood = None
+    log_marginal_x_posterior = log_marginal_x_likelihood + log_prior
+
+    # TODO can we sample over log likelihoods?
+    # Marginalize params
+    idata_params = log_posterior.sample()  # TODO something like NUTS, QMC, etc.?
+    idata_x = log_marginal_x_posterior.sample()
+
+    # Bundle up idatas somehow
+    return idata_params, idata_x
diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py
index d64d2ada..9c0a0d27 100644
--- a/pymc_extras/inference/laplace.py
+++ b/pymc_extras/inference/laplace.py
@@ -15,7 +15,6 @@
 
 import logging
 
-from collections.abc import Callable
 from functools import reduce
 from importlib.util import find_spec
 from itertools import product
@@ -30,7 +29,6 @@
 
 from arviz import dict_to_dataset
 from better_optimize.constants import minimize_method
-from numpy.typing import ArrayLike
 from pymc import DictToArrayBijection
 from pymc.backends.arviz import (
     coords_and_dims_for_inferencedata,
@@ -41,8 +39,6 @@
 from pymc.model.transform.conditioning import remove_value_transforms
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.util import get_default_varnames
-from pytensor.tensor import TensorVariable
-from pytensor.tensor.optimize import minimize
 from scipy import stats
 
 from pymc_extras.inference.find_map import (
@@ -56,102 +52,6 @@
 _log = logging.getLogger(__name__)
 
 
-def get_conditional_gaussian_approximation(
-    x: TensorVariable,
-    Q: TensorVariable | ArrayLike,
-    mu: TensorVariable | ArrayLike,
-    args: list[TensorVariable] | None = None,
-    model: pm.Model | None = None,
-    method: minimize_method = "BFGS",
-    use_jac: bool = True,
-    use_hess: bool = False,
-    optimizer_kwargs: dict | None = None,
-) -> Callable:
-    """
-    Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation.
-
-    That is:
-    y | x, sigma ~ N(Ax, sigma^2 W)
-    x | params ~ N(mu, Q(params)^-1)
-
-    We seek to estimate log(p(x | y, params)):
-
-    log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const
-
-    Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q).
-
-    This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode.
-
-    Thus:
-
-    1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0.
-
-    2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q).
-
-    Parameters
-    ----------
-    x: TensorVariable
-        The parameter with which to maximize wrt (that is, find the mode in x). In INLA, this is the latent field x~N(mu,Q^-1).
-    Q: TensorVariable | ArrayLike
-        The precision matrix of the latent field x.
-    mu: TensorVariable | ArrayLike
-        The mean of the latent field x.
-    args: list[TensorVariable]
-        Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args.
-    model: Model
-        PyMC model to use.
-    method: minimize_method
-        Which minimization algorithm to use.
-    use_jac: bool
-        If true, the minimizer will compute the gradient of log(p(x | y, params)).
-    use_hess: bool
-        If true, the minimizer will compute the Hessian log(p(x | y, params)).
-    optimizer_kwargs: dict
-        Kwargs to pass to scipy.optimize.minimize.
-
-    Returns
-    -------
-    f: Callable
-        A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer.
-    """
-    model = pm.modelcontext(model)
-
-    if args is None:
-        args = model.continuous_value_vars + model.discrete_value_vars
-
-    # f = log(p(y | x, params))
-    f_x = model.logp()
-    jac = pytensor.gradient.grad(f_x, x)
-    hess = pytensor.gradient.jacobian(jac.flatten(), x)
-
-    # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
-    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
-
-    # Maximize log(p(x | y, params)) wrt x to find mode x0
-    x0, _ = minimize(
-        objective=-log_x_posterior,
-        x=x,
-        method=method,
-        jac=use_jac,
-        hess=use_hess,
-        optimizer_kwargs=optimizer_kwargs,
-    )
-
-    # require f'(x0) and f''(x0) for Laplace approx
-    jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
-    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
-
-    # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
-    _, logdetQ = pt.nlinalg.slogdet(Q)
-    conditional_gaussian_approx = (
-        -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
-    )
-
-    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
-    # far from the mode x0 or in a neighbourhood which results in poor convergence.
-    return pytensor.function(args, [x0, conditional_gaussian_approx])
-
-
 def laplace_draws_to_inferencedata(
     posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
 ) -> az.InferenceData:
diff --git a/pymc_extras/model/marginal/distributions.py b/pymc_extras/model/marginal/distributions.py
index 86aa5f02..42ee3ff8 100644
--- a/pymc_extras/model/marginal/distributions.py
+++ b/pymc_extras/model/marginal/distributions.py
@@ -132,6 +132,10 @@ class MarginalDiscreteMarkovChainRV(MarginalRV):
     """Base class for Marginalized Discrete Markov Chain RVs"""
 
 
+class MarginalLaplaceRV(MarginalRV):
+    """Base class for Marginalized Laplace-Approximated RVs"""
+
+
 def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
     op = rv.owner.op
     dist_params = rv.owner.op.dist_params(rv.owner)
@@ -371,3 +375,76 @@ def step_alpha(logp_emission, log_alpha, log_P):
     warn_non_separable_logp(values)
     dummy_logps = (DUMMY_ZERO,) * (len(values) - 1)
     return joint_logp, *dummy_logps
+
+
+@_logprob.register(MarginalLaplaceRV)
+def laplace_marginal_rv_logp(op: MarginalLaplaceRV, values, *inputs, **kwargs):
+    import pytensor
+
+    from pytensor.tensor.optimize import minimize
+
+    # Clone the inner RV graph of the Marginalized RV
+    x, *inner_rvs = inline_ofg_outputs(op, inputs)
+
+    # Obtain the joint_logp graph of the inner RV graph
+    inner_rv_values = dict(zip(inner_rvs, values))
+
+    marginalized_vv = x.clone()
+    rv_values = inner_rv_values | {x: marginalized_vv}
+    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
+
+    logp = pt.sum(
+        [pt.sum(logps_dict[k]) for k in logps_dict]
+    )  # TODO check: this is the joint logp(y, x | params), including logp(x | params)
+
+    # Maximize log(p(x | y, params)) wrt x to find mode x0 # TODO args need to be user-supplied
+    x0, _ = minimize(
+        objective=-logp,
+        x=marginalized_vv,
+        method="BFGS",
+        # jac=use_jac,
+        # hess=use_hess,
+        optimizer_kwargs={"tol": 1e-8},
+    )
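+
+    # For reference (an illustrative sketch, not part of this change):
+    # pytensor.tensor.optimize.minimize is symbolic, returning the argmin as a
+    # tensor together with a success output, e.g.:
+    #     z = pt.vector("z")
+    #     z_min, _ = minimize(objective=pt.sum((z - 1.0) ** 2), x=z, method="BFGS")
+    #     pytensor.function([z], z_min)(np.zeros(3))  # ~= np.ones(3)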
+
+    # TODO temporary placeholders from development: the dimension should come from
+    # the graph, and the minimizer's x0 above should be used instead of a random draw.
+    rng = np.random.default_rng(12345)
+    d = 10
+
+    from pytensor.tensor.random.basic import MvNormalRV
+
+    # Check the latent field's distribution via its owner op; isinstance against
+    # pm.MvNormal itself is always False for a tensor variable.
+    if not isinstance(x.owner.op, MvNormalRV):
+        raise ValueError("Latent field x must be MvNormal.")
+    # MvNormalRV inputs are (rng, size, mu, cov), so grab the covariance and invert
+    # it to get the precision. TODO double check this grabs the right thing
+    Q = pt.nlinalg.matrix_inverse(x.owner.inputs[-1])
+    x0 = rng.random(d)
+
+    # require f''(x0) for Laplace approx
+    hess = pytensor.gradient.hessian(logp, marginalized_vv)
+
+    # Could be made more efficient by updating only the diagonal entries
+    tau = Q - hess
+
+    # Evaluated at the mode x0, the quadratic term of the Gaussian approximation
+    # vanishes, so logp(x0 | y, params) reduces to the normalizing constant.
+    # Note x is currently used both as the query point and as the initial guess for
+    # the minimizer, which may converge poorly if the query point is far from the mode.
+    _, logdetTau = pt.nlinalg.slogdet(tau)
+    log_laplace_approx = 0.5 * logdetTau - 0.5 * x0.shape[0] * np.log(2 * np.pi)
+
+    # logp(y | params) = logp(y, x0 | params) - logp(x0 | y, params)
+    joint_logp = logp - log_laplace_approx
+
+    # TODO this might cause circularity issues by overwriting x as an input to the x0 minimizer
+    joint_logp = pytensor.graph.replace.graph_replace(joint_logp, {marginalized_vv: x0})
+
+    return joint_logp  # TODO check whether pm.sample adds logp(params); otherwise this is logp(y | params), not logp(params | y)
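+
+
+# Sanity check for the formula above (an illustrative sketch, not part of the
+# module's API): in a linear-Gaussian model the Laplace approximation is exact,
+# so logp(y | params) equals the joint logp at the mode x0 minus
+# 0.5*logdet(tau) - 0.5*d*log(2*pi). Assumes scipy is available.
+def _demo_laplace_marginal_exactness():
+    from scipy.stats import multivariate_normal as mvn
+
+    rng = np.random.default_rng(0)
+    d, n = 3, 5
+    mu = rng.normal(size=d)
+    Q = np.diag(rng.random(d) + 1.0)  # prior precision of x
+    C = np.diag(rng.random(d) + 1.0)  # observation covariance
+    y = rng.normal(size=(n, d))
+
+    C_inv = np.linalg.inv(C)
+    tau = Q + n * C_inv  # posterior precision: Q minus the Hessian of logp(y | x)
+    x0 = np.linalg.solve(tau, Q @ mu + C_inv @ y.sum(axis=0))  # posterior mode
+
+    joint_at_mode = mvn.logpdf(x0, mean=mu, cov=np.linalg.inv(Q)) + sum(
+        mvn.logpdf(yi, mean=x0, cov=C) for yi in y
+    )
+    laplace_norm = 0.5 * np.linalg.slogdet(tau)[1] - 0.5 * d * np.log(2 * np.pi)
+
+    # Exact marginal: the stacked y is Gaussian with a Kronecker-structured cov
+    cov_y = np.kron(np.eye(n), C) + np.kron(np.ones((n, n)), np.linalg.inv(Q))
+    marginal = mvn.logpdf(y.ravel(), mean=np.tile(mu, n), cov=cov_y)
+
+    np.testing.assert_allclose(joint_at_mode - laplace_norm, marginal)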
diff --git a/tests/test_inla.py b/tests/test_inla.py
new file mode 100644
index 00000000..58ff9ae1
--- /dev/null
+++ b/tests/test_inla.py
@@ -0,0 +1,105 @@
+#   Copyright 2024 The PyMC Developers
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+
+
+import numpy as np
+import pymc as pm
+import pytensor
+
+from pymc_extras.inference.inla import get_conditional_gaussian_approximation
+
+
+def test_get_conditional_gaussian_approximation():
+    """
+    Consider the trivial case of:
+
+    y | x ~ N(x, cov_param)
+    x | param ~ N(mu_param, Q^-1)
+
+    cov_param ~ N(cov_mu, cov_cov)
+    mu_param ~ N(mu_mu, mu_cov)
+    Q ~ N(Q_mu, Q_cov)
+
+    This has an analytic solution at the mode which we can compare against.
+    """
+    rng = np.random.default_rng(12345)
+    n = 10000
+    d = 10
+
+    # Initialise arrays
+    mu_true = rng.random(d)
+    cov_true = np.diag(rng.random(d))
+    Q_val = np.diag(rng.random(d))
+    cov_param_val = np.diag(rng.random(d))
+
+    x_val = rng.random(d)
+    mu_val = rng.random(d)
+
+    mu_mu = rng.random(d)
+    mu_cov = np.diag(np.ones(d))
+    cov_mu = rng.random(d**2)
+    cov_cov = np.diag(np.ones(d**2))
+    Q_mu = rng.random(d**2)
+    Q_cov = np.diag(np.ones(d**2))
+
+    with pm.Model() as model:
+        y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n)
+
+        mu_param = pm.MvNormal("mu_param", mu=mu_mu, cov=mu_cov)
+        cov_param = pm.MvNormal("cov_param", mu=cov_mu, cov=cov_cov)
+        Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov)
+
+        x = pm.MvNormal("x", mu=mu_param, tau=Q_val)
+
+        y = pm.MvNormal(
+            "y",
+            mu=x,
+            cov=cov_param.reshape((d, d)),
+            observed=y_obs,
+        )
+
+        args = model.continuous_value_vars + model.discrete_value_vars
+
+        # logp(x | y, params)
+        x0, x_g = get_conditional_gaussian_approximation(
+            x=model.rvs_to_values[x],
+            Q=Q.reshape((d, d)),
+            mu=mu_param,
+            optimizer_kwargs={"tol": 1e-8},
+        )
+
+    cga = pytensor.function(args, [x0, pm.logp(x_g, model.rvs_to_values[x])])
+
+    x0, log_x_posterior = cga(
+        x=x_val, mu_param=mu_val, cov_param=cov_param_val.flatten(), Q=Q_val.flatten()
+    )
+
+    # Get analytic values of the mode and Laplace-approximated log posterior
+    cov_param_inv = np.linalg.inv(cov_param_val)
+
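+    # Note the factor of 2 multiplying Q_val below: model.logp() already includes
+    # logp(x | params), and get_conditional_gaussian_approximation adds its own
+    # -0.5*(x - mu).T @ Q @ (x - mu) term on top, so the effective prior precision
+    # in the objective (and hence at the mode) is doubled.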
+    x0_true = np.linalg.inv(n * cov_param_inv + 2 * Q_val) @ (
+        cov_param_inv @ y_obs.sum(axis=0) + 2 * Q_val @ mu_val
+    )
+
+    hess_true = -n * cov_param_inv - Q_val
+    tau_true = Q_val - hess_true
+
+    log_x_taylor = (
+        -0.5 * (x_val - x0_true).T @ tau_true @ (x_val - x0_true)
+        + 0.5 * np.log(np.linalg.det(tau_true))
+        - 0.5 * d * np.log(2 * np.pi)
+    )
+
+    np.testing.assert_allclose(x0, x0_true, atol=0.1, rtol=0.1)
+    np.testing.assert_allclose(log_x_posterior, log_x_taylor, atol=0.1, rtol=0.1)
diff --git a/tests/test_laplace.py b/tests/test_laplace.py
index 72ff3e93..8f7a4c01 100644
--- a/tests/test_laplace.py
+++ b/tests/test_laplace.py
@@ -23,7 +23,6 @@
 from pymc_extras.inference.laplace import (
     fit_laplace,
     fit_mvn_at_MAP,
-    get_conditional_gaussian_approximation,
     sample_laplace_posterior,
 )
 
@@ -280,85 +279,3 @@ def test_laplace_scalar():
     assert idata_laplace.fit.covariance_matrix.shape == (1, 1)
 
     np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1)
-
-
-def test_get_conditional_gaussian_approximation():
-    """
-    Consider the trivial case of:
-
-    y | x ~ N(x, cov_param)
-    x | param ~ N(mu_param, Q^-1)
-
-    cov_param ~ N(cov_mu, cov_cov)
-    mu_param ~ N(mu_mu, mu_cov)
-    Q ~ N(Q_mu, Q_cov)
-
-    This has an analytic solution at the mode which we can compare against.
-    """
-    rng = np.random.default_rng(12345)
-    n = 10000
-    d = 10
-
-    # Initialise arrays
-    mu_true = rng.random(d)
-    cov_true = np.diag(rng.random(d))
-    Q_val = np.diag(rng.random(d))
-    cov_param_val = np.diag(rng.random(d))
-
-    x_val = rng.random(d)
-    mu_val = rng.random(d)
-
-    mu_mu = rng.random(d)
-    mu_cov = np.diag(np.ones(d))
-    cov_mu = rng.random(d**2)
-    cov_cov = np.diag(np.ones(d**2))
-    Q_mu = rng.random(d**2)
-    Q_cov = np.diag(np.ones(d**2))
-
-    with pm.Model() as model:
-        y_obs = rng.multivariate_normal(mean=mu_true, cov=cov_true, size=n)
-
-        mu_param = pm.MvNormal("mu_param", mu=mu_mu, cov=mu_cov)
-        cov_param = pm.MvNormal("cov_param", mu=cov_mu, cov=cov_cov)
-        Q = pm.MvNormal("Q", mu=Q_mu, cov=Q_cov)
-
-        # Pytensor currently doesn't support autograd for pt inverses, so we use a numeric Q instead
-        x = pm.MvNormal("x", mu=mu_param, cov=np.linalg.inv(Q_val))
-
-        y = pm.MvNormal(
-            "y",
-            mu=x,
-            cov=cov_param.reshape((d, d)),
-            observed=y_obs,
-        )
-
-        # logp(x | y, params)
-        cga = get_conditional_gaussian_approximation(
-            x=model.rvs_to_values[x],
-            Q=Q.reshape((d, d)),
-            mu=mu_param,
-            optimizer_kwargs={"tol": 1e-25},
-        )
-
-    x0, log_x_posterior = cga(
-        x=x_val, mu_param=mu_val, cov_param=cov_param_val.flatten(), Q=Q_val.flatten()
-    )
-
-    # Get analytic values of the mode and Laplace-approximated log posterior
-    cov_param_inv = np.linalg.inv(cov_param_val)
-
-    x0_true = np.linalg.inv(n * cov_param_inv + 2 * Q_val) @ (
-        cov_param_inv @ y_obs.sum(axis=0) + 2 * Q_val @ mu_val
-    )
-
-    jac_true = cov_param_inv @ (y_obs - x0_true).sum(axis=0) - Q_val @ (x0_true - mu_val)
-    hess_true = -n * cov_param_inv - Q_val
-
-    log_x_posterior_laplace_true = (
-        -0.5 * x_val.T @ (-hess_true + Q_val) @ x_val
-        + x_val.T @ (Q_val @ mu_val + jac_true - hess_true @ x0_true)
-        + 0.5 * np.log(np.linalg.det(Q_val))
-    )
-
-    np.testing.assert_allclose(x0, x0_true, atol=0.1, rtol=0.1)
-    np.testing.assert_allclose(log_x_posterior, log_x_posterior_laplace_true, atol=0.1, rtol=0.1)