# --- bayesflow/experimental/__init__.py (import section after this change) ---
from ..utils._docs import _add_imports_to_all
from .cif import CIF
from .continuous_time_consistency_model import ContinuousTimeConsistencyModel
from .diffusion_model import DiffusionModel
from .free_form_flow import FreeFormFlow
from .graphical_simulator import GraphicalSimulator

_add_imports_to_all(include_modules=["diffusion_model"])


# --- bayesflow/experimental/graphical_simulator/__init__.py ---
from .graphical_simulator import GraphicalSimulator
from . import example_simulators


# --- bayesflow/experimental/graphical_simulator/example_simulators/__init__.py ---
from .single_level_simulator import single_level_simulator
from .two_level_simulator import two_level_simulator
from .crossed_design_irt_simulator import crossed_design_irt_simulator


# --- bayesflow/experimental/graphical_simulator/example_simulators/crossed_design_irt_simulator.py ---
import numpy as np

from ..graphical_simulator import GraphicalSimulator


def crossed_design_irt_simulator():
    r"""
    Item Response Theory (IRT) model implemented as a graphical simulator.

            schools
           /       \
        exams    students
          |          |
      questions      |
           \        /
          observations

    The number of exams, questions per exam and students is drawn once per
    `sample` call by the meta function (2-3 exams, 10-20 questions,
    100-200 students; the upper bounds of `np.random.randint` are exclusive).

    Returns
    -------
    GraphicalSimulator
        A simulator with the nodes and edges shown above already registered.
    """

    # schools have different exam difficulties
    def sample_school():
        mu_exam_mean = np.random.normal(loc=1.1, scale=0.2)
        sigma_exam_mean = abs(np.random.normal(loc=0, scale=1))

        # hierarchical mu/sigma for the exam difficulty standard deviation (logscale)
        mu_exam_std = np.random.normal(loc=0.5, scale=0.3)
        sigma_exam_std = abs(np.random.normal(loc=0, scale=0.5))

        return dict(
            mu_exam_mean=mu_exam_mean,
            sigma_exam_mean=sigma_exam_mean,
            mu_exam_std=mu_exam_std,
            sigma_exam_std=sigma_exam_std,
        )

    # exams have different question difficulties
    def sample_exam(mu_exam_mean, sigma_exam_mean, mu_exam_std, sigma_exam_std):
        # mean question difficulty for an exam
        exam_mean = np.random.normal(loc=mu_exam_mean, scale=sigma_exam_mean)

        # standard deviation of question difficulty
        log_exam_std = np.random.normal(loc=mu_exam_std, scale=sigma_exam_std)
        exam_std = float(np.exp(log_exam_std))

        return dict(exam_mean=exam_mean, exam_std=exam_std)

    # realizations of individual question difficulties
    def sample_question(exam_mean, exam_std):
        question_difficulty = np.random.normal(loc=exam_mean, scale=exam_std)

        return dict(question_difficulty=question_difficulty)

    # realizations of individual student abilities
    def sample_student():
        student_ability = np.random.normal(loc=0, scale=1)

        return dict(student_ability=student_ability)

    # realizations of individual observations
    def sample_observation(question_difficulty, student_ability):
        logit = question_difficulty + student_ability

        # sigmoid(logit) written as exp(-log(1 + exp(-logit))): the naive
        # exp(x) / (1 + exp(x)) overflows to inf / inf = nan for large logits,
        # and np.random.binomial rejects p = nan.
        theta = np.exp(-np.logaddexp(0.0, -logit))

        obs = np.random.binomial(n=1, p=theta)

        return dict(obs=obs)

    def meta_fn():
        return {
            "num_exams": np.random.randint(2, 4),
            "num_questions": np.random.randint(10, 21),
            "num_students": np.random.randint(100, 201),
        }

    simulator = GraphicalSimulator(meta_fn=meta_fn)

    simulator.add_node("schools", sample_fn=sample_school)
    simulator.add_node("exams", sample_fn=sample_exam, reps="num_exams")
    simulator.add_node("questions", sample_fn=sample_question, reps="num_questions")
    simulator.add_node("students", sample_fn=sample_student, reps="num_students")
    # one observation per (question, student) pair produced by the parent nodes
    simulator.add_node("observations", sample_fn=sample_observation)

    simulator.add_edge("schools", "exams")
    simulator.add_edge("schools", "students")
    simulator.add_edge("exams", "questions")
    simulator.add_edge("questions", "observations")
    simulator.add_edge("students", "observations")

    return simulator


# --- bayesflow/experimental/graphical_simulator/example_simulators/single_level_simulator.py ---
import numpy as np

from ..graphical_simulator import GraphicalSimulator


def single_level_simulator():
    """
    Simple single-level simulator that implements the same model as in
    https://bayesflow.org/main/_examples/Linear_Regression_Starter.html

    The graph is `prior -> likelihood`. The meta function draws the number of
    observations `N` (5 to 14, upper bound of `np.random.randint` exclusive)
    once per `sample` call; `likelihood` receives it as an argument.

    Returns
    -------
    GraphicalSimulator
        A simulator with the `prior` and `likelihood` nodes registered.
    """

    def prior():
        # beta[0] is the intercept, beta[1] the slope used in `likelihood`
        beta = np.random.normal([2, 0], [3, 1])
        sigma = np.random.gamma(1, 1)

        return {"beta": beta, "sigma": sigma}

    def likelihood(beta, sigma, N):
        x = np.random.normal(0, 1, size=N)
        y = np.random.normal(beta[0] + beta[1] * x, sigma, size=N)

        return {"x": x, "y": y}

    def meta():
        N = np.random.randint(5, 15)

        return {"N": N}

    simulator = GraphicalSimulator(meta_fn=meta)

    simulator.add_node("prior", sample_fn=prior)
    simulator.add_node("likelihood", sample_fn=likelihood)

    simulator.add_edge("prior", "likelihood")

    return simulator
# --- bayesflow/experimental/graphical_simulator/example_simulators/two_level_simulator.py ---
import numpy as np

from ..graphical_simulator import GraphicalSimulator


def two_level_simulator(repeated_roots=False):
    r"""
    Build a small hierarchical example: hyperparameters feed local
    parameters, which together with one shared parameter generate `y`.

        hypers
          |
        locals    shared
            \      /
             \    /
               y

    Parameters
    ----------
    repeated_roots : bool, default False
        When true, the `hypers` root node is registered with 5 repetitions
        instead of the default single repetition.

    Returns
    -------
    GraphicalSimulator
        Simulator with `locals` repeated 6 times and `y` repeated 10 times.
    """

    def sample_hypers():
        return {
            "hyper_mean": np.random.normal(),
            "hyper_std": np.abs(np.random.normal()),
        }

    def sample_locals(hyper_mean, hyper_std):
        return {"local_mean": np.random.normal(hyper_mean, hyper_std)}

    def sample_shared():
        return {"shared_std": np.abs(np.random.normal())}

    def sample_y(local_mean, shared_std):
        return {"y": np.random.normal(local_mean, shared_std)}

    # (name, sample function, repetitions), in registration order
    node_table = (
        ("hypers", sample_hypers, 5 if repeated_roots else 1),
        ("locals", sample_locals, 6),
        ("shared", sample_shared, 1),
        ("y", sample_y, 10),
    )
    edge_table = (
        ("hypers", "locals"),
        ("locals", "y"),
        ("shared", "y"),
    )

    simulator = GraphicalSimulator()

    for name, fn, n_reps in node_table:
        simulator.add_node(name, sample_fn=fn, reps=n_reps)

    for parent, child in edge_table:
        simulator.add_edge(parent, child)

    return simulator
# --- bayesflow/experimental/graphical_simulator/graphical_simulator.py ---
import inspect
import itertools
from collections.abc import Callable
from typing import Any, Optional

import networkx as nx
import numpy as np

from bayesflow.simulators import Simulator
from bayesflow.types import Shape
from bayesflow.utils.decorators import allow_batch_size


class GraphicalSimulator(Simulator):
    """
    A graph-based simulator that generates samples by traversing a DAG
    and calling user-defined sample functions at each node.

    Parameters
    ----------
    meta_fn : Optional[Callable[[], dict[str, Any]]]
        A callable that returns a dictionary of meta data.
        This meta data can be used to dynamically vary the number of sample repetitions (`reps`)
        for nodes added via `add_node`. Its entries are also offered as keyword arguments to
        every sample function that declares a parameter of the same name, and are included
        in the output of `sample`.
    """

    def __init__(self, meta_fn: Optional[Callable[[], dict[str, Any]]] = None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.graph = nx.DiGraph()
        self.meta_fn = meta_fn

    def add_node(self, node: str, sample_fn: Callable[..., dict[str, Any]], reps: int | str = 1):
        """
        Register a node.

        Parameters
        ----------
        node : str
            Name of the node.
        sample_fn : Callable[..., dict[str, Any]]
            Function returning a dict of variables. It is called with those parent variables
            and meta entries whose names match its parameter names.
        reps : int or str, default 1
            Number of repetitions per parent sample, or the key of a meta entry holding it.
        """
        self.graph.add_node(node, sample_fn=sample_fn, reps=reps)

    def add_edge(self, from_node: str, to_node: str):
        """Declare that `to_node` depends on the variables sampled at `from_node`."""
        self.graph.add_edge(from_node, to_node)

    @allow_batch_size
    def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]:
        """
        Generates samples by topologically traversing the DAG.
        For each node, the sample function is called based on parent values.

        Parameters
        ----------
        batch_shape : Shape
            The shape of the batch to sample. Typically, a tuple indicating the number of samples,
            but an int can also be passed.
        **kwargs
            Currently unused

        Returns
        -------
        dict[str, np.ndarray]
            One array per sampled variable, plus the meta entries as returned by `meta_fn`.

        Raises
        ------
        KeyError
            If a node has no sample function (e.g. it was only created through `add_edge`)
            or its `reps` key is missing from the meta data.
        """
        _ = kwargs  # Simulator class requires **kwargs, which are unused here
        meta_dict = self.meta_fn() if self.meta_fn else {}

        ordered_nodes = list(nx.topological_sort(self.graph))

        # parents, sample function and repetitions do not depend on the batch index,
        # so they are resolved once instead of inside the batch loop
        node_plan = {node: self._node_plan(node, meta_dict) for node in ordered_nodes}

        # Initialize samples container for each node
        samples_by_node = {node: np.empty(batch_shape, dtype="object") for node in ordered_nodes}

        for batch_idx in np.ndindex(batch_shape):
            for node in ordered_nodes:
                parent_nodes, sample_fn, reps = node_plan[node]
                node_idx_key = f"__{node}_idx"

                if not parent_nodes:
                    # root node: generate independent samples; meta data is offered to
                    # root sample functions exactly as it is to non-root ones
                    node_samples = [
                        {"__batch_idx": batch_idx, node_idx_key: i} | self._call_sample_fn(sample_fn, meta_dict)
                        for i in range(reps)
                    ]
                else:
                    # non-root node: depends on parent samples
                    node_samples = []
                    parent_samples = [samples_by_node[p][batch_idx] for p in parent_nodes]

                    # index_prefix drops combinations whose parents stem from different
                    # repetitions of a shared ancestor (diamond-shaped graphs)
                    for merged in merge_lists_of_dicts(parent_samples, index_prefix="__"):
                        index_entries = {k: v for k, v in merged.items() if k.startswith("__")}
                        variable_entries = {k: v for k, v in merged.items() if not k.startswith("__")}

                        sample_fn_input = variable_entries | meta_dict
                        # the node's own index key is appended last; `_index_layout` relies on that
                        node_samples.extend(
                            index_entries | {node_idx_key: i} | self._call_sample_fn(sample_fn, sample_fn_input)
                            for i in range(reps)
                        )

                samples_by_node[node][batch_idx] = node_samples

        # collect outputs
        output_dict = {}
        for node in ordered_nodes:
            output_dict.update(self._collect_output(samples_by_node[node]))

        output_dict.update(meta_dict)

        return output_dict

    def _node_plan(self, node, meta_dict):
        """Return `(parent_nodes, sample_fn, reps)` for `node`, with `reps` resolved to a number."""
        attributes = self.graph.nodes[node]

        if "sample_fn" not in attributes:
            raise KeyError(f"Node '{node}' has no sample function; register it with `add_node`.")

        reps_field = attributes["reps"]
        if isinstance(reps_field, (int, np.integer)):
            reps = int(reps_field)
        elif reps_field in meta_dict:
            reps = meta_dict[reps_field]
        else:
            raise KeyError(f"Node '{node}' uses reps={reps_field!r}, which `meta_fn` did not provide.")

        return list(self.graph.predecessors(node)), attributes["sample_fn"], reps

    def _index_layout(self, samples):
        """
        Return the index levels (ancestors in topological order, then the node itself)
        and the number of repetitions per level, read from the first batch element.
        """
        first_entry = samples.flat[0]

        # the node's own index key is the last "__"-prefixed key of every sample dict
        index_keys = [k for k in first_entry[0].keys() if k.startswith("__")]
        node = index_keys[-1].removeprefix("__").removesuffix("_idx")

        levels = [*sorted_ancestors(self.graph, node), node]
        reps = {level: max(s[f"__{level}_idx"] for s in first_entry) + 1 for level in levels}

        return levels, reps

    def _collect_output(self, samples):
        """Stack the per-sample dicts of one node into one array per variable."""
        levels, reps = self._index_layout(samples)

        # levels with a single repetition do not get an axis of their own
        kept_levels = [level for level in levels if reps[level] != 1]
        index_keys = [f"__{level}_idx" for level in kept_levels]
        leading_shape = (*samples.shape, *(reps[level] for level in kept_levels))

        variable_names = self._variable_names(samples)
        first_sample = samples.flat[0][0]
        output_dict = {
            variable: np.empty((*leading_shape, *np.atleast_1d(first_sample[variable]).shape))
            for variable in variable_names
        }

        for batch_idx in np.ndindex(samples.shape):
            for sample in samples[batch_idx]:
                # the target cell is the same for all variables of a sample
                idx = (*batch_idx, *(sample[key] for key in index_keys))
                for variable in variable_names:
                    output_dict[variable][idx] = sample[variable]

        return output_dict

    def _variable_names(self, samples):
        """Names of the sampled variables, i.e. all keys without the internal "__" prefix."""
        return [k for k in samples.flat[0][0].keys() if not k.startswith("__")]

    def _output_shape(self, samples, variable):
        """Shape of the output array of `variable`: batch, repeated levels, variable shape."""
        levels, reps = self._index_layout(samples)

        # start with batch shape
        output_shape = list(samples.shape)

        # add ancestor and node reps (single repetitions are squeezed out)
        output_shape.extend(reps[level] for level in levels if reps[level] != 1)

        # add variable shape
        output_shape.extend(np.atleast_1d(samples.flat[0][0][variable]).shape)

        return tuple(output_shape)

    def _call_sample_fn(self, sample_fn, args):
        """Call `sample_fn` with the subset of `args` that matches its parameter names."""
        fn_args = inspect.signature(sample_fn).parameters
        accepted_args = {k: v for k, v in args.items() if k in fn_args}

        return sample_fn(**accepted_args)


def sorted_ancestors(graph, node):
    """Return the ancestors of `node` in topological order."""
    # compute the ancestor set once instead of once per node of the graph
    ancestors = nx.ancestors(graph, node)
    return [n for n in nx.topological_sort(graph) if n in ancestors]


def merge_lists_of_dicts(nested_list: list[list[dict]], index_prefix: str | None = None) -> list[dict]:
    """
    Merges all combinations of dictionaries from a list of lists.
    Equivalent to a Cartesian product of dicts, then flattening.

    Parameters
    ----------
    nested_list : list[list[dict]]
        One list of dicts per source; one dict of every list is merged per combination.
        Later dicts override earlier ones on duplicate keys.
    index_prefix : str, optional
        If given, a combination is dropped when two of its dicts hold different values
        under the same key starting with this prefix. Keys must be strings in that case.
        With the default `None`, every combination is kept.

    Examples:
        >>> merge_lists_of_dicts([[{"a": 1, "b": 2}], [{"c": 3}, {"d": 4}]])
        [{'a': 1, 'b': 2, 'c': 3}, {'a': 1, 'b': 2, 'd': 4}]
        >>> merge_lists_of_dicts([[{"__i": 0, "a": 1}, {"__i": 1, "a": 2}], [{"__i": 0, "b": 3}]], index_prefix="__")
        [{'__i': 0, 'a': 1, 'b': 3}]
    """

    merged_dicts = []

    for combo in itertools.product(*nested_list):
        merged = {}
        consistent = True

        for d in combo:
            if index_prefix is not None and any(
                k.startswith(index_prefix) and k in merged and merged[k] != v for k, v in d.items()
            ):
                consistent = False
                break
            merged.update(d)

        if consistent:
            merged_dicts.append(merged)

    return merged_dicts


# --- pyproject.toml: the dependency list gains "networkx>=3.4.2" ---


# --- tests/test_simulators/conftest.py (fixtures appended after the existing `simulator` fixture) ---
# NOTE(review): `pytest` is presumably imported at the top of conftest.py, outside the visible hunk.
@pytest.fixture()
def single_level_simulator():
    from bayesflow.experimental.graphical_simulator.example_simulators import single_level_simulator

    return single_level_simulator()


@pytest.fixture()
def two_level_simulator():
    from bayesflow.experimental.graphical_simulator.example_simulators import two_level_simulator

    return two_level_simulator()


@pytest.fixture()
def two_level_repeated_roots_simulator():
    from bayesflow.experimental.graphical_simulator.example_simulators import two_level_simulator

    return two_level_simulator(repeated_roots=True)
# --- tests/test_simulators/conftest.py (last fixture appended by this change) ---
@pytest.fixture()
def crossed_design_irt_simulator():
    from bayesflow.experimental.graphical_simulator.example_simulators import crossed_design_irt_simulator

    return crossed_design_irt_simulator()


# --- tests/test_simulators/test_graphical_simulator.py ---
import numpy as np

import bayesflow as bf


def _check_simulator(simulator):
    """Assert the fixture is a GraphicalSimulator whose `sample` returns a dict."""
    assert isinstance(simulator, bf.experimental.graphical_simulator.GraphicalSimulator)
    assert isinstance(simulator.sample(5), dict)


def _check_shapes(samples, expected_shapes):
    """Assert `np.shape(samples[name])` for every entry of `expected_shapes`, in order."""
    for name, shape in expected_shapes.items():
        assert np.shape(samples[name]) == shape


def test_single_level_simulator(single_level_simulator):
    # prior -> likelihood

    _check_simulator(single_level_simulator)

    samples = single_level_simulator.sample(12)

    assert set(samples.keys()) == {"N", "beta", "sigma", "x", "y"}
    assert 5 <= samples["N"] < 15

    _check_shapes(
        samples,
        {
            # prior node
            "beta": (12, 2),  # num_samples, beta_dim
            "sigma": (12, 1),  # num_samples, sigma_dim
            # likelihood node
            "x": (12, samples["N"]),
            "y": (12, samples["N"]),
        },
    )


def test_two_level_simulator(two_level_simulator):
    #     hypers
    #       |
    #     locals    shared
    #         \      /
    #          \    /
    #            y

    _check_simulator(two_level_simulator)

    samples = two_level_simulator.sample(15)

    assert set(samples.keys()) == {"hyper_mean", "hyper_std", "local_mean", "shared_std", "y"}

    _check_shapes(
        samples,
        {
            # hypers node
            "hyper_mean": (15, 1),
            "hyper_std": (15, 1),
            # locals node
            "local_mean": (15, 6, 1),
            # shared node
            "shared_std": (15, 1),
            # y node
            "y": (15, 6, 10, 1),
        },
    )


def test_two_level_repeated_roots_simulator(two_level_repeated_roots_simulator):
    #     hypers
    #       |
    #     locals    shared
    #         \      /
    #          \    /
    #            y

    _check_simulator(two_level_repeated_roots_simulator)

    samples = two_level_repeated_roots_simulator.sample(15)

    assert set(samples.keys()) == {"hyper_mean", "hyper_std", "local_mean", "shared_std", "y"}

    _check_shapes(
        samples,
        {
            # hypers node
            "hyper_mean": (15, 5, 1),
            "hyper_std": (15, 5, 1),
            # locals node
            "local_mean": (15, 5, 6, 1),
            # shared node
            "shared_std": (15, 1),
            # y node
            "y": (15, 5, 6, 10, 1),
        },
    )


def test_crossed_design_irt_simulator(crossed_design_irt_simulator):
    #         schools
    #        /       \
    #     exams    students
    #       |          |
    #   questions      |
    #        \        /
    #       observations

    _check_simulator(crossed_design_irt_simulator)

    samples = crossed_design_irt_simulator.sample(22)

    variable_keys = {
        "mu_exam_mean",
        "sigma_exam_mean",
        "mu_exam_std",
        "sigma_exam_std",
        "exam_mean",
        "exam_std",
        "question_difficulty",
        "student_ability",
        "obs",
    }
    meta_keys = {
        "num_exams",  # np.random.randint(2, 4)
        "num_questions",  # np.random.randint(10, 21)
        "num_students",  # np.random.randint(100, 201)
    }

    assert set(samples.keys()) == variable_keys | meta_keys

    n_exams = samples["num_exams"]
    n_questions = samples["num_questions"]
    n_students = samples["num_students"]

    _check_shapes(
        samples,
        {
            # schools node
            "mu_exam_mean": (22, 1),
            "sigma_exam_mean": (22, 1),
            "mu_exam_std": (22, 1),
            "sigma_exam_std": (22, 1),
            # exams node
            "exam_mean": (22, n_exams, 1),
            "exam_std": (22, n_exams, 1),
            # questions node
            "question_difficulty": (22, n_exams, n_questions, 1),
            # students node
            "student_ability": (22, n_students, 1),
            # observations node
            "obs": (22, n_exams, n_students, n_questions, 1),
        },
    )