From 20b54ceec20a49e28e17516ff998e2b61d9d8b08 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Fri, 23 May 2025 16:21:11 +0200 Subject: [PATCH 01/20] initial commit From 02cc91529fcc2eb4650cd114f33084603a7043d1 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Mon, 2 Jun 2025 17:31:48 +0200 Subject: [PATCH 02/20] add networkx to project dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 902ffc1d2..b3c36c02a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ requires-python = ">= 3.10, < 3.13" dependencies = [ "keras >= 3.9", "matplotlib", + "networkx>=3.4.2", "numpy >= 1.24, <2.0", "pandas", "scipy", From be9f3903a81a9cc31a93b61d3c5ce28b00a257ed Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sun, 15 Jun 2025 15:26:23 +0200 Subject: [PATCH 03/20] initial implementation GraphicalSimulator --- .../graphical_simulator/__init__.py | 0 .../graphical_simulator/example_simulators.py | 52 ++++++++ .../graphical_simulator.py | 115 ++++++++++++++++++ 3 files changed, 167 insertions(+) create mode 100644 bayesflow/experimental/graphical_simulator/__init__.py create mode 100644 bayesflow/experimental/graphical_simulator/example_simulators.py create mode 100644 bayesflow/experimental/graphical_simulator/graphical_simulator.py diff --git a/bayesflow/experimental/graphical_simulator/__init__.py b/bayesflow/experimental/graphical_simulator/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bayesflow/experimental/graphical_simulator/example_simulators.py b/bayesflow/experimental/graphical_simulator/example_simulators.py new file mode 100644 index 000000000..7dbab9bc0 --- /dev/null +++ b/bayesflow/experimental/graphical_simulator/example_simulators.py @@ -0,0 +1,52 @@ +import numpy as np +from .graphical_simmulator import GraphicalSimulator +from bayesflow.utils import batched_call + + +def test_batched_call(): + return batched_call(sample_fn, (10, 2), flatten=True) + pass + + +def sample_fn(): + return {"a": 3, "b": 6} + + +def twolevel_simulator(): + def sample_hypers(): + hyper_mean = np.random.normal() + hyper_std = np.abs(np.random.normal()) + + return {"hyper_mean": float(hyper_mean), "hyper_std": float(hyper_std)} + + def sample_locals(hyper_mean, hyper_std): + local_mean = np.random.normal(hyper_mean, hyper_std) + + return {"local_mean": float(local_mean)} + + def sample_shared(): + shared_std = np.abs(np.random.normal()) + + return {"shared_std": shared_std} + + def sample_y(local_mean, shared_std): + y = np.random.normal(local_mean, shared_std) + + return {"y": float(y)} + + simulator = GraphicalSimulator() + simulator.add_node("hypers", sampling_fn=sample_hypers, reps=1) + + simulator.add_node( + "locals", + sampling_fn=sample_locals, + reps=6, + ) + simulator.add_node("shared", sampling_fn=sample_shared, reps=1) + simulator.add_node("y", sampling_fn=sample_y, reps=10) + + simulator.add_edge("hypers", "locals") + simulator.add_edge("locals", "y") + simulator.add_edge("shared", "y") + + return simulator diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py new file mode 100644 index 000000000..ac5de559a --- /dev/null +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -0,0 +1,115 @@ +import inspect 
+import itertools +from collections.abc import Callable +from typing import Any, Optional + +import networkx as nx +import numpy as np + +from bayesflow.simulators import Simulator +from bayesflow.types import Shape + + +class GraphicalSimulator(Simulator): + """ + A graph-based simulator that generates samples by traversing a DAG + and calling user-defined sampling functions at each node. + + Parameters + ---------- + meta_fn : Optional[Callable[[], dict[str, Any]]] + A callable that returns a dictionary of meta data. + This meta data can be used to dynamically vary the number of sampling repetitions (`reps`) + for nodes added via `add_node`. + """ + + def __init__(self, meta_fn: Optional[Callable[[], dict[str, Any]]] = None, *args, **kwargs): + super().__init__(*args, **kwargs) + self.graph = nx.DiGraph() + self.meta_fn = meta_fn + + def add_node(self, node: str, sampling_fn: Callable[..., dict[str, Any]], reps: int | str = 1): + self.graph.add_node(node, sampling_fn=sampling_fn, reps=reps) + + def add_edge(self, from_node: str, to_node: str): + self.graph.add_edge(from_node, to_node) + + def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: + """ + Generates samples by topologically traversing the DAG. + For each node, the sampling function is called based on parent values. + + Parameters + ---------- + batch_shape : Shape + The shape of the batch to sample. Typically, a tuple indicating the number of samples, + but an int can also be passed. + **kwargs + Unused + """ + _ = kwargs # Simulator class requires **kwargs, which are unused here + meta_dict = self.meta_fn() if self.meta_fn else {} + + # Initialize samples containers for each node + for node in self.graph.nodes: + self.graph.nodes[node]["samples"] = np.empty(batch_shape, dtype="object") + + for batch_idx in np.ndindex(batch_shape): + for node in nx.topological_sort(self.graph): + node_samples = [] + + parent_nodes = list(self.graph.predecessors(node)) + sampling_fn = self.graph.nodes[node]["sampling_fn"] + reps_field = self.graph.nodes[node]["reps"] + reps = reps_field if isinstance(reps_field, int) else meta_dict[reps_field] + + if not parent_nodes: + # root node: generate independent samples + node_samples = [ + {"__batch_idx": batch_idx, f"__{node}_idx": i} | sampling_fn() for i in range(1, reps + 1) + ] + else: + # non-root node: depends on parent samples + parent_samples = [self.graph.nodes[p]["samples"][batch_idx] for p in parent_nodes] + merged_dicts = merge_lists_of_dicts(parent_samples) + + for merged in merged_dicts: + index_entries = filter_indices(merged) + variable_entries = filter_variables(merged) + + node_samples.extend( + [ + index_entries | {f"__{node}_idx": i} | call_sampling_fn(sampling_fn, variable_entries) + for i in range(1, reps + 1) + ] + ) + + self.graph.nodes[node]["samples"][batch_idx] = node_samples + + return {"a": np.zeros(3)} + + +def merge_lists_of_dicts(nested_list: list[list[dict]]) -> list[dict]: + """ + Merges all combinations of dictionaries from a list of lists. + Equivalent to a Cartesian product of dicts, then flattening. 
+ """ + + all_combinations = itertools.product(*nested_list) + return [{k: v for d in combo for k, v in d.items()} for combo in all_combinations] + + +def call_sampling_fn(sampling_fn: Callable, inputs: dict) -> dict[str, Any]: + num_args = len(inspect.signature(sampling_fn).parameters) + if num_args == 0: + return sampling_fn() + else: + return sampling_fn(**inputs) + + +def filter_indices(d: dict) -> dict[str, Any]: + return {k: v for k, v in d.items() if k.startswith("__")} + + +def filter_variables(d: dict) -> dict[str, Any]: + return {k: v for k, v in d.items() if not k.startswith("__")} From 4194062b7f1f61d82fa4b737de7705da008fad70 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Fri, 20 Jun 2025 01:10:08 +0200 Subject: [PATCH 04/20] samples method of GraphicalSimulator now returns a dict of appropriately shaped numpy arrays --- .../graphical_simulator/example_simulators.py | 12 +- .../graphical_simulator.py | 103 +++++++++++++----- 2 files changed, 79 insertions(+), 36 deletions(-) diff --git a/bayesflow/experimental/graphical_simulator/example_simulators.py b/bayesflow/experimental/graphical_simulator/example_simulators.py index 7dbab9bc0..23f1ebfeb 100644 --- a/bayesflow/experimental/graphical_simulator/example_simulators.py +++ b/bayesflow/experimental/graphical_simulator/example_simulators.py @@ -1,15 +1,5 @@ import numpy as np -from .graphical_simmulator import GraphicalSimulator -from bayesflow.utils import batched_call - - -def test_batched_call(): - return batched_call(sample_fn, (10, 2), flatten=True) - pass - - -def sample_fn(): - return {"a": 3, "b": 6} +from .graphical_simulator import GraphicalSimulator def twolevel_simulator(): diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index ac5de559a..82a4bd3a4 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -1,4 +1,3 @@ -import inspect import itertools from collections.abc import Callable from typing import Any, Optional @@ -8,6 +7,7 @@ from bayesflow.simulators import Simulator from bayesflow.types import Shape +from bayesflow.utils.decorators import allow_batch_size class GraphicalSimulator(Simulator): @@ -34,6 +34,7 @@ def add_node(self, node: str, sampling_fn: Callable[..., dict[str, Any]], reps: def add_edge(self, from_node: str, to_node: str): self.graph.add_edge(from_node, to_node) + @allow_batch_size def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: """ Generates samples by topologically traversing the DAG. 
@@ -49,10 +50,11 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: """ _ = kwargs # Simulator class requires **kwargs, which are unused here meta_dict = self.meta_fn() if self.meta_fn else {} + samples_by_node = {} - # Initialize samples containers for each node + # Initialize samples container for each node for node in self.graph.nodes: - self.graph.nodes[node]["samples"] = np.empty(batch_shape, dtype="object") + samples_by_node[node] = np.empty(batch_shape, dtype="object") for batch_idx in np.ndindex(batch_shape): for node in nx.topological_sort(self.graph): @@ -70,46 +72,97 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: ] else: # non-root node: depends on parent samples - parent_samples = [self.graph.nodes[p]["samples"][batch_idx] for p in parent_nodes] + parent_samples = [samples_by_node[p][batch_idx] for p in parent_nodes] merged_dicts = merge_lists_of_dicts(parent_samples) for merged in merged_dicts: - index_entries = filter_indices(merged) - variable_entries = filter_variables(merged) + index_entries = {k: v for k, v in merged.items() if k.startswith("__")} + variable_entries = {k: v for k, v in merged.items() if not k.startswith("__")} node_samples.extend( [ - index_entries | {f"__{node}_idx": i} | call_sampling_fn(sampling_fn, variable_entries) + index_entries | {f"__{node}_idx": i} | sampling_fn(**variable_entries) for i in range(1, reps + 1) ] ) - self.graph.nodes[node]["samples"][batch_idx] = node_samples + samples_by_node[node][batch_idx] = node_samples - return {"a": np.zeros(3)} + output_dict = {} + for node in nx.topological_sort(self.graph): + output_dict.update(self._collect_output(samples_by_node[node])) + + return output_dict + + def _collect_output(self, samples): + output_dict = {} + + index_entries = [k for k in samples.flat[0][0].keys() if k.startswith("__")] + node = index_entries[-1].removeprefix("__").removesuffix("_idx") + ancestors = non_root_ancestors(self.graph, node) + variable_names = self._variable_names(samples) + + for variable in variable_names: + output_shape = self._output_shape(samples, variable) + output_dict[variable] = np.empty(output_shape) + + for batch_idx in np.ndindex(samples.shape): + for sample in samples[batch_idx]: + idx = tuple( + [*batch_idx] + + [sample[f"__{a}_idx"] - 1 for a in ancestors] + + [sample[f"__{node}_idx"] - 1] # - 1 for 0-based indexing + ) + output_dict[variable][idx] = sample[variable] + + return output_dict + + def _variable_names(self, samples): + return [k for k in samples.flat[0][0].keys() if not k.startswith("__")] + + def _output_shape(self, samples, variable): + index_entries = [k for k in samples.flat[0][0].keys() if k.startswith("__")] + node = index_entries[-1].removeprefix("__").removesuffix("_idx") + + # start with batch shape + batch_shape = samples.shape + output_shape = [*batch_shape] + ancestors = non_root_ancestors(self.graph, node) + + # add reps of non root ancestors + for ancestor in ancestors: + reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) + output_shape.append(reps) + + # add node reps + if not is_root_node(self.graph, node): + node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0]) + output_shape.append(node_reps) + + # add variable shape + variable_shape = np.atleast_1d(samples.flat[0][0][variable]).shape + output_shape.extend(variable_shape) + + return tuple(output_shape) + + +def non_root_ancestors(graph, node): + return [n for n in nx.topological_sort(graph) if n in nx.ancestors(graph, node) and not is_root_node(graph, 
n)] + + +def is_root_node(graph, node): + return len(list(graph.predecessors(node))) == 0 def merge_lists_of_dicts(nested_list: list[list[dict]]) -> list[dict]: """ Merges all combinations of dictionaries from a list of lists. Equivalent to a Cartesian product of dicts, then flattening. + + Examples: + >>> merge_lists_of_dicts([[{"a": 1, "b": 2}], [{"c": 3}, {"d": 4}]]) + [{'a': 1, 'b': 2, 'c': 3}, {'a': 1, 'b': 2, 'd': 4}] """ all_combinations = itertools.product(*nested_list) return [{k: v for d in combo for k, v in d.items()} for combo in all_combinations] - - -def call_sampling_fn(sampling_fn: Callable, inputs: dict) -> dict[str, Any]: - num_args = len(inspect.signature(sampling_fn).parameters) - if num_args == 0: - return sampling_fn() - else: - return sampling_fn(**inputs) - - -def filter_indices(d: dict) -> dict[str, Any]: - return {k: v for k, v in d.items() if k.startswith("__")} - - -def filter_variables(d: dict) -> dict[str, Any]: - return {k: v for k, v in d.items() if not k.startswith("__")} From 60e589a243221c64b13113819ed7bd6600453001 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Fri, 20 Jun 2025 02:09:08 +0200 Subject: [PATCH 05/20] add irt_simulator and threelevel_simulator --- .../graphical_simulator/example_simulators.py | 136 ++++++++++++++++++ 1 file changed, 136 insertions(+) diff --git a/bayesflow/experimental/graphical_simulator/example_simulators.py b/bayesflow/experimental/graphical_simulator/example_simulators.py index 23f1ebfeb..9801b5d3c 100644 --- a/bayesflow/experimental/graphical_simulator/example_simulators.py +++ b/bayesflow/experimental/graphical_simulator/example_simulators.py @@ -2,6 +2,93 @@ from .graphical_simulator import GraphicalSimulator +def irt_simulator(): + # schools have different exam difficulties + def sample_school(): + mu_exam_mean = np.random.normal(loc=1.1, scale=0.2) + sigma_exam_mean = abs(np.random.normal(loc=0, scale=1)) + + # hierarchical mu/sigma for the exam difficulty standard deviation (logscale) + mu_exam_std = np.random.normal(loc=0.5, scale=0.3) + sigma_exam_std = abs(np.random.normal(loc=0, scale=1)) + + return dict( + mu_exam_mean=mu_exam_mean, + sigma_exam_mean=sigma_exam_mean, + mu_exam_std=mu_exam_std, + sigma_exam_std=sigma_exam_std, + ) + + # exams have different question difficulties + def sample_exam(mu_exam_mean, sigma_exam_mean, mu_exam_std, sigma_exam_std): + # mean question difficulty for an exam + exam_mean = np.random.normal(loc=mu_exam_mean, scale=sigma_exam_mean) + + # standard deviation of question difficulty + log_exam_std = np.random.normal(loc=mu_exam_std, scale=sigma_exam_std) + exam_std = float(np.exp(log_exam_std)) + + return dict(exam_mean=exam_mean, exam_std=exam_std) + + # realizations of individual question difficulties + def sample_question(exam_mean, exam_std): + question_difficulty = np.random.normal(loc=exam_mean, scale=exam_std) + + return dict(question_difficulty=question_difficulty) + + # realizations of individual student abilities + def sample_student(**kwargs): + student_ability = np.random.normal(loc=0, scale=1) + + return dict(student_ability=student_ability) + + # realizations of individual observations + def sample_observation(question_difficulty, student_ability): + theta = np.exp(question_difficulty + student_ability) / (1 + np.exp(question_difficulty + student_ability)) + + obs = np.random.binomial(n=1, p=theta) + + return dict(obs=obs) + + def meta_fn(): + return { + "num_exams": np.random.randint(2, 4), + 
"num_questions": np.random.randint(10, 21), + "num_students": np.random.randint(100, 201), + } + + simulator = GraphicalSimulator(meta_fn=meta_fn) + simulator.add_node( + "schools", + sampling_fn=sample_school, + ) + simulator.add_node( + "exams", + sampling_fn=sample_exam, + reps="num_exams", + ) + simulator.add_node( + "questions", + sampling_fn=sample_question, + reps="num_questions", + ) + simulator.add_node( + "students", + sampling_fn=sample_student, + reps="num_students", + ) + + simulator.add_node("observations", sampling_fn=sample_observation) + + simulator.add_edge("schools", "exams") + simulator.add_edge("schools", "students") + simulator.add_edge("exams", "questions") + simulator.add_edge("questions", "observations") + simulator.add_edge("students", "observations") + + return simulator + + def twolevel_simulator(): def sample_hypers(): hyper_mean = np.random.normal() @@ -40,3 +127,52 @@ def sample_y(local_mean, shared_std): simulator.add_edge("shared", "y") return simulator + + +def threelevel_simulator(): + def sample_level_1(): + level_1_mean = np.random.normal() + + return {"level_1_mean": float(level_1_mean)} + + def sample_level_2(level_1_mean): + level_2_mean = np.random.normal(level_1_mean, 1) + + return {"level_2_mean": float(level_2_mean)} + + def sample_level_3(level_2_mean): + level_3_mean = np.random.normal(level_2_mean, 1) + + return {"level_3_mean": float(level_3_mean)} + + def sample_shared(): + shared_std = np.abs(np.random.normal()) + + return {"shared_std": shared_std} + + def sample_y(level_3_mean, shared_std): + y = np.random.normal(level_3_mean, shared_std, size=10) + + return {"y": y} + + simulator = GraphicalSimulator() + simulator.add_node("level1", sampling_fn=sample_level_1) + simulator.add_node( + "level2", + sampling_fn=sample_level_2, + reps=10, + ) + simulator.add_node( + "level3", + sampling_fn=sample_level_3, + reps=20, + ) + simulator.add_node("shared", sampling_fn=sample_shared) + simulator.add_node("y", sampling_fn=sample_y, reps=10) + + simulator.add_edge("level1", "level2") + simulator.add_edge("level2", "level3") + simulator.add_edge("level3", "y") + simulator.add_edge("shared", "y") + + return simulator From d1624ee15767db75d7dc4906d1a1a8ffb5ca796d Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Fri, 20 Jun 2025 13:30:52 +0200 Subject: [PATCH 06/20] enable sampling_fn with no arguments for non root nodes, change output dimensionality rules for sample method --- .../graphical_simulator.py | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index 82a4bd3a4..95b12e10f 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -1,3 +1,4 @@ +import inspect import itertools from collections.abc import Callable from typing import Any, Optional @@ -68,7 +69,8 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: if not parent_nodes: # root node: generate independent samples node_samples = [ - {"__batch_idx": batch_idx, f"__{node}_idx": i} | sampling_fn() for i in range(1, reps + 1) + {"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sampling_fn(sampling_fn, {}) + for i in range(1, reps + 1) ] else: # non-root node: depends on parent samples @@ -79,9 +81,12 @@ def sample(self, batch_shape: Shape, 
**kwargs) -> dict[str, np.ndarray]: index_entries = {k: v for k, v in merged.items() if k.startswith("__")} variable_entries = {k: v for k, v in merged.items() if not k.startswith("__")} + sampling_fn_input = variable_entries | meta_dict node_samples.extend( [ - index_entries | {f"__{node}_idx": i} | sampling_fn(**variable_entries) + index_entries + | {f"__{node}_idx": i} + | self._call_sampling_fn(sampling_fn, sampling_fn_input) for i in range(1, reps + 1) ] ) @@ -92,6 +97,8 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: for node in nx.topological_sort(self.graph): output_dict.update(self._collect_output(samples_by_node[node])) + output_dict.update(meta_dict) + return output_dict def _collect_output(self, samples): @@ -99,6 +106,7 @@ def _collect_output(self, samples): index_entries = [k for k in samples.flat[0][0].keys() if k.startswith("__")] node = index_entries[-1].removeprefix("__").removesuffix("_idx") + node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0]) ancestors = non_root_ancestors(self.graph, node) variable_names = self._variable_names(samples) @@ -108,12 +116,13 @@ def _collect_output(self, samples): for batch_idx in np.ndindex(samples.shape): for sample in samples[batch_idx]: - idx = tuple( - [*batch_idx] - + [sample[f"__{a}_idx"] - 1 for a in ancestors] - + [sample[f"__{node}_idx"] - 1] # - 1 for 0-based indexing - ) - output_dict[variable][idx] = sample[variable] + idx = [*batch_idx] + for ancestor in ancestors: + idx.append(sample[f"__{ancestor}_idx"] - 1) + if not is_root_node(self.graph, node): + if node_reps != 1: + idx.append(sample[f"__{node}_idx"] - 1) # -1 for 0-based indexing + output_dict[variable][tuple(idx)] = sample[variable] return output_dict @@ -137,7 +146,8 @@ def _output_shape(self, samples, variable): # add node reps if not is_root_node(self.graph, node): node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0]) - output_shape.append(node_reps) + if node_reps != 1: + output_shape.append(node_reps) # add variable shape variable_shape = np.atleast_1d(samples.flat[0][0][variable]).shape @@ -145,6 +155,13 @@ def _output_shape(self, samples, variable): return tuple(output_shape) + def _call_sampling_fn(self, sampling_fn, args): + signature = inspect.signature(sampling_fn) + fn_args = signature.parameters + accepted_args = {k: v for k, v in args.items() if k in fn_args} + + return sampling_fn(**accepted_args) + def non_root_ancestors(graph, node): return [n for n in nx.topological_sort(graph) if n in nx.ancestors(graph, node) and not is_root_node(graph, n)] From ae105f65a852b29ebd56d4a40d1353fc44addce8 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Fri, 20 Jun 2025 13:31:20 +0200 Subject: [PATCH 07/20] add one and three-level example simulators --- .../graphical_simulator/example_simulators.py | 32 +++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/bayesflow/experimental/graphical_simulator/example_simulators.py b/bayesflow/experimental/graphical_simulator/example_simulators.py index 9801b5d3c..78c1fb68a 100644 --- a/bayesflow/experimental/graphical_simulator/example_simulators.py +++ b/bayesflow/experimental/graphical_simulator/example_simulators.py @@ -10,7 +10,7 @@ def sample_school(): # hierarchical mu/sigma for the exam difficulty standard deviation (logscale) mu_exam_std = np.random.normal(loc=0.5, scale=0.3) - sigma_exam_std = abs(np.random.normal(loc=0, scale=1)) + sigma_exam_std = abs(np.random.normal(loc=0, scale=0.5)) 
return dict( mu_exam_mean=mu_exam_mean, @@ -37,7 +37,7 @@ def sample_question(exam_mean, exam_std): return dict(question_difficulty=question_difficulty) # realizations of individual student abilities - def sample_student(**kwargs): + def sample_student(): student_ability = np.random.normal(loc=0, scale=1) return dict(student_ability=student_ability) @@ -89,6 +89,34 @@ def meta_fn(): return simulator +def onelevel_simulator(): + def prior(): + beta = np.random.normal([2, 0], [3, 1]) + sigma = np.random.gamma(1, 1) + + return {"beta": beta, "sigma": sigma} + + def likelihood(beta, sigma, N): + x = np.random.normal(0, 1, size=N) + y = np.random.normal(beta[0] + beta[1] * x, sigma, size=N) + + return {"x": x, "y": y} + + def meta(): + N = np.random.randint(5, 15) + + return {"N": N} + + simulator = GraphicalSimulator(meta_fn=meta) + + simulator.add_node("prior", sampling_fn=prior) + simulator.add_node("likelihood", sampling_fn=likelihood) + + simulator.add_edge("prior", "likelihood") + + return simulator + + def twolevel_simulator(): def sample_hypers(): hyper_mean = np.random.normal() From d8ac4fd08fff18b0b5e26710b159b1d3d9a8de08 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 28 Jun 2025 00:22:54 +0200 Subject: [PATCH 08/20] allow root node repetitions --- .../graphical_simulator.py | 46 ++++++++++++------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index 95b12e10f..d1ea10121 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -93,6 +93,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: samples_by_node[node][batch_idx] = node_samples + # collect outputs output_dict = {} for node in nx.topological_sort(self.graph): output_dict.update(self._collect_output(samples_by_node[node])) @@ -104,12 +105,20 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: def _collect_output(self, samples): output_dict = {} + # retrieve node and ancestors from internal sample representation index_entries = [k for k in samples.flat[0][0].keys() if k.startswith("__")] node = index_entries[-1].removeprefix("__").removesuffix("_idx") - node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0]) - ancestors = non_root_ancestors(self.graph, node) + ancestors = sorted_ancestors(self.graph, node) + + # build dict of node repetitions + reps = {} + for ancestor in ancestors: + reps[ancestor] = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) + reps[node] = max(s[f"__{node}_idx"] for s in samples.flat[0]) + variable_names = self._variable_names(samples) + # collect output for each variable for variable in variable_names: output_shape = self._output_shape(samples, variable) output_dict[variable] = np.empty(output_shape) @@ -117,11 +126,16 @@ def _collect_output(self, samples): for batch_idx in np.ndindex(samples.shape): for sample in samples[batch_idx]: idx = [*batch_idx] + + # add index elements for ancestors for ancestor in ancestors: - idx.append(sample[f"__{ancestor}_idx"] - 1) - if not is_root_node(self.graph, node): - if node_reps != 1: - idx.append(sample[f"__{node}_idx"] - 1) # -1 for 0-based indexing + if reps[ancestor] != 1: + idx.append(sample[f"__{ancestor}_idx"] - 1) # -1 for 0-based indexing + + # add index elements for node + if reps[node] != 
1: + idx.append(sample[f"__{node}_idx"] - 1) # -1 for 0-based indexing + output_dict[variable][tuple(idx)] = sample[variable] return output_dict @@ -136,19 +150,19 @@ def _output_shape(self, samples, variable): # start with batch shape batch_shape = samples.shape output_shape = [*batch_shape] - ancestors = non_root_ancestors(self.graph, node) + ancestors = sorted_ancestors(self.graph, node) - # add reps of non root ancestors + # add ancestor reps for ancestor in ancestors: - reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) - output_shape.append(reps) - - # add node reps - if not is_root_node(self.graph, node): - node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0]) + node_reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) if node_reps != 1: output_shape.append(node_reps) + # add node reps + node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0]) + if node_reps != 1: + output_shape.append(node_reps) + # add variable shape variable_shape = np.atleast_1d(samples.flat[0][0][variable]).shape output_shape.extend(variable_shape) @@ -163,8 +177,8 @@ def _call_sampling_fn(self, sampling_fn, args): return sampling_fn(**accepted_args) -def non_root_ancestors(graph, node): - return [n for n in nx.topological_sort(graph) if n in nx.ancestors(graph, node) and not is_root_node(graph, n)] +def sorted_ancestors(graph, node): + return [n for n in nx.topological_sort(graph) if n in nx.ancestors(graph, node)] def is_root_node(graph, node): From 56f76816a87a7b876ef80198131c631cfd7668de Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 28 Jun 2025 13:15:10 +0200 Subject: [PATCH 09/20] export GraphicalSimulator --- bayesflow/experimental/__init__.py | 4 ++-- bayesflow/experimental/graphical_simulator/__init__.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/bayesflow/experimental/__init__.py b/bayesflow/experimental/__init__.py index 6c0b6828f..3482c3628 100644 --- a/bayesflow/experimental/__init__.py +++ b/bayesflow/experimental/__init__.py @@ -2,11 +2,11 @@ Unstable or largely untested networks, proceed with caution. 
""" +from ..utils._docs import _add_imports_to_all from .cif import CIF from .continuous_time_consistency_model import ContinuousTimeConsistencyModel from .diffusion_model import DiffusionModel from .free_form_flow import FreeFormFlow - -from ..utils._docs import _add_imports_to_all +from .graphical_simulator import GraphicalSimulator _add_imports_to_all(include_modules=["diffusion_model"]) diff --git a/bayesflow/experimental/graphical_simulator/__init__.py b/bayesflow/experimental/graphical_simulator/__init__.py index e69de29bb..d6bd92196 100644 --- a/bayesflow/experimental/graphical_simulator/__init__.py +++ b/bayesflow/experimental/graphical_simulator/__init__.py @@ -0,0 +1 @@ +from .graphical_simulator import GraphicalSimulator From e59b2b22b19ac49ee6cfaf63cc0135323572ede4 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 28 Jun 2025 13:39:44 +0200 Subject: [PATCH 10/20] rename sampling_fn argument to sample_fn in GraphicalSimulator.add_node method --- .../graphical_simulator/example_simulators.py | 29 ++++++++++--------- .../graphical_simulator.py | 6 ++-- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/bayesflow/experimental/graphical_simulator/example_simulators.py b/bayesflow/experimental/graphical_simulator/example_simulators.py index 78c1fb68a..ded044ae3 100644 --- a/bayesflow/experimental/graphical_simulator/example_simulators.py +++ b/bayesflow/experimental/graphical_simulator/example_simulators.py @@ -1,4 +1,5 @@ import numpy as np + from .graphical_simulator import GraphicalSimulator @@ -60,21 +61,21 @@ def meta_fn(): simulator = GraphicalSimulator(meta_fn=meta_fn) simulator.add_node( "schools", - sampling_fn=sample_school, + sample_fn=sample_school, ) simulator.add_node( "exams", - sampling_fn=sample_exam, + sample_fn=sample_exam, reps="num_exams", ) simulator.add_node( "questions", - sampling_fn=sample_question, + sample_fn=sample_question, reps="num_questions", ) simulator.add_node( "students", - sampling_fn=sample_student, + sample_fn=sample_student, reps="num_students", ) @@ -109,8 +110,8 @@ def meta(): simulator = GraphicalSimulator(meta_fn=meta) - simulator.add_node("prior", sampling_fn=prior) - simulator.add_node("likelihood", sampling_fn=likelihood) + simulator.add_node("prior", sample_fn=prior) + simulator.add_node("likelihood", sample_fn=likelihood) simulator.add_edge("prior", "likelihood") @@ -140,15 +141,15 @@ def sample_y(local_mean, shared_std): return {"y": float(y)} simulator = GraphicalSimulator() - simulator.add_node("hypers", sampling_fn=sample_hypers, reps=1) + simulator.add_node("hypers", sample_fn=sample_hypers, reps=5) simulator.add_node( "locals", sampling_fn=sample_locals, reps=6, ) - simulator.add_node("shared", sampling_fn=sample_shared, reps=1) - simulator.add_node("y", sampling_fn=sample_y, reps=10) + simulator.add_node("shared", sample_fn=sample_shared, reps=1) + simulator.add_node("y", sample_fn=sample_y, reps=10) simulator.add_edge("hypers", "locals") simulator.add_edge("locals", "y") @@ -184,19 +185,19 @@ def sample_y(level_3_mean, shared_std): return {"y": y} simulator = GraphicalSimulator() - simulator.add_node("level1", sampling_fn=sample_level_1) + simulator.add_node("level1", sample_fn=sample_level_1) simulator.add_node( "level2", - sampling_fn=sample_level_2, + sample_fn=sample_level_2, reps=10, ) simulator.add_node( "level3", - sampling_fn=sample_level_3, + sample_fn=sample_level_3, reps=20, ) - simulator.add_node("shared", sampling_fn=sample_shared) - 
simulator.add_node("y", sampling_fn=sample_y, reps=10) + simulator.add_node("shared", sample_fn=sample_shared) + simulator.add_node("y", sample_fn=sample_y, reps=10) simulator.add_edge("level1", "level2") simulator.add_edge("level2", "level3") diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index d1ea10121..fa58c5e3a 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -29,8 +29,8 @@ def __init__(self, meta_fn: Optional[Callable[[], dict[str, Any]]] = None, *args self.graph = nx.DiGraph() self.meta_fn = meta_fn - def add_node(self, node: str, sampling_fn: Callable[..., dict[str, Any]], reps: int | str = 1): - self.graph.add_node(node, sampling_fn=sampling_fn, reps=reps) + def add_node(self, node: str, sample_fn: Callable[..., dict[str, Any]], reps: int | str = 1): + self.graph.add_node(node, sample_fn=sample_fn, reps=reps) def add_edge(self, from_node: str, to_node: str): self.graph.add_edge(from_node, to_node) @@ -62,7 +62,7 @@ def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: node_samples = [] parent_nodes = list(self.graph.predecessors(node)) - sampling_fn = self.graph.nodes[node]["sampling_fn"] + sampling_fn = self.graph.nodes[node]["sample_fn"] reps_field = self.graph.nodes[node]["reps"] reps = reps_field if isinstance(reps_field, int) else meta_dict[reps_field] From 9284333e91aa7a5e325a66e4991f84f68ac5722c Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 28 Jun 2025 17:19:59 +0200 Subject: [PATCH 11/20] move example simulators to own submodule --- .../graphical_simulator/__init__.py | 1 + .../graphical_simulator/example_simulators.py | 207 ------------------ .../example_simulators/__init__.py | 4 + .../example_simulators/irt.py | 87 ++++++++ .../example_simulators/single_level.py | 36 +++ .../example_simulators/two_level.py | 57 +++++ .../two_level_repeated_roots.py | 56 +++++ .../graphical_simulator.py | 2 +- 8 files changed, 242 insertions(+), 208 deletions(-) delete mode 100644 bayesflow/experimental/graphical_simulator/example_simulators.py create mode 100644 bayesflow/experimental/graphical_simulator/example_simulators/__init__.py create mode 100644 bayesflow/experimental/graphical_simulator/example_simulators/irt.py create mode 100644 bayesflow/experimental/graphical_simulator/example_simulators/single_level.py create mode 100644 bayesflow/experimental/graphical_simulator/example_simulators/two_level.py create mode 100644 bayesflow/experimental/graphical_simulator/example_simulators/two_level_repeated_roots.py diff --git a/bayesflow/experimental/graphical_simulator/__init__.py b/bayesflow/experimental/graphical_simulator/__init__.py index d6bd92196..caaae84a5 100644 --- a/bayesflow/experimental/graphical_simulator/__init__.py +++ b/bayesflow/experimental/graphical_simulator/__init__.py @@ -1 +1,2 @@ from .graphical_simulator import GraphicalSimulator +from . 
import example_simulators diff --git a/bayesflow/experimental/graphical_simulator/example_simulators.py b/bayesflow/experimental/graphical_simulator/example_simulators.py deleted file mode 100644 index ded044ae3..000000000 --- a/bayesflow/experimental/graphical_simulator/example_simulators.py +++ /dev/null @@ -1,207 +0,0 @@ -import numpy as np - -from .graphical_simulator import GraphicalSimulator - - -def irt_simulator(): - # schools have different exam difficulties - def sample_school(): - mu_exam_mean = np.random.normal(loc=1.1, scale=0.2) - sigma_exam_mean = abs(np.random.normal(loc=0, scale=1)) - - # hierarchical mu/sigma for the exam difficulty standard deviation (logscale) - mu_exam_std = np.random.normal(loc=0.5, scale=0.3) - sigma_exam_std = abs(np.random.normal(loc=0, scale=0.5)) - - return dict( - mu_exam_mean=mu_exam_mean, - sigma_exam_mean=sigma_exam_mean, - mu_exam_std=mu_exam_std, - sigma_exam_std=sigma_exam_std, - ) - - # exams have different question difficulties - def sample_exam(mu_exam_mean, sigma_exam_mean, mu_exam_std, sigma_exam_std): - # mean question difficulty for an exam - exam_mean = np.random.normal(loc=mu_exam_mean, scale=sigma_exam_mean) - - # standard deviation of question difficulty - log_exam_std = np.random.normal(loc=mu_exam_std, scale=sigma_exam_std) - exam_std = float(np.exp(log_exam_std)) - - return dict(exam_mean=exam_mean, exam_std=exam_std) - - # realizations of individual question difficulties - def sample_question(exam_mean, exam_std): - question_difficulty = np.random.normal(loc=exam_mean, scale=exam_std) - - return dict(question_difficulty=question_difficulty) - - # realizations of individual student abilities - def sample_student(): - student_ability = np.random.normal(loc=0, scale=1) - - return dict(student_ability=student_ability) - - # realizations of individual observations - def sample_observation(question_difficulty, student_ability): - theta = np.exp(question_difficulty + student_ability) / (1 + np.exp(question_difficulty + student_ability)) - - obs = np.random.binomial(n=1, p=theta) - - return dict(obs=obs) - - def meta_fn(): - return { - "num_exams": np.random.randint(2, 4), - "num_questions": np.random.randint(10, 21), - "num_students": np.random.randint(100, 201), - } - - simulator = GraphicalSimulator(meta_fn=meta_fn) - simulator.add_node( - "schools", - sample_fn=sample_school, - ) - simulator.add_node( - "exams", - sample_fn=sample_exam, - reps="num_exams", - ) - simulator.add_node( - "questions", - sample_fn=sample_question, - reps="num_questions", - ) - simulator.add_node( - "students", - sample_fn=sample_student, - reps="num_students", - ) - - simulator.add_node("observations", sampling_fn=sample_observation) - - simulator.add_edge("schools", "exams") - simulator.add_edge("schools", "students") - simulator.add_edge("exams", "questions") - simulator.add_edge("questions", "observations") - simulator.add_edge("students", "observations") - - return simulator - - -def onelevel_simulator(): - def prior(): - beta = np.random.normal([2, 0], [3, 1]) - sigma = np.random.gamma(1, 1) - - return {"beta": beta, "sigma": sigma} - - def likelihood(beta, sigma, N): - x = np.random.normal(0, 1, size=N) - y = np.random.normal(beta[0] + beta[1] * x, sigma, size=N) - - return {"x": x, "y": y} - - def meta(): - N = np.random.randint(5, 15) - - return {"N": N} - - simulator = GraphicalSimulator(meta_fn=meta) - - simulator.add_node("prior", sample_fn=prior) - simulator.add_node("likelihood", sample_fn=likelihood) - - simulator.add_edge("prior", 
"likelihood") - - return simulator - - -def twolevel_simulator(): - def sample_hypers(): - hyper_mean = np.random.normal() - hyper_std = np.abs(np.random.normal()) - - return {"hyper_mean": float(hyper_mean), "hyper_std": float(hyper_std)} - - def sample_locals(hyper_mean, hyper_std): - local_mean = np.random.normal(hyper_mean, hyper_std) - - return {"local_mean": float(local_mean)} - - def sample_shared(): - shared_std = np.abs(np.random.normal()) - - return {"shared_std": shared_std} - - def sample_y(local_mean, shared_std): - y = np.random.normal(local_mean, shared_std) - - return {"y": float(y)} - - simulator = GraphicalSimulator() - simulator.add_node("hypers", sample_fn=sample_hypers, reps=5) - - simulator.add_node( - "locals", - sampling_fn=sample_locals, - reps=6, - ) - simulator.add_node("shared", sample_fn=sample_shared, reps=1) - simulator.add_node("y", sample_fn=sample_y, reps=10) - - simulator.add_edge("hypers", "locals") - simulator.add_edge("locals", "y") - simulator.add_edge("shared", "y") - - return simulator - - -def threelevel_simulator(): - def sample_level_1(): - level_1_mean = np.random.normal() - - return {"level_1_mean": float(level_1_mean)} - - def sample_level_2(level_1_mean): - level_2_mean = np.random.normal(level_1_mean, 1) - - return {"level_2_mean": float(level_2_mean)} - - def sample_level_3(level_2_mean): - level_3_mean = np.random.normal(level_2_mean, 1) - - return {"level_3_mean": float(level_3_mean)} - - def sample_shared(): - shared_std = np.abs(np.random.normal()) - - return {"shared_std": shared_std} - - def sample_y(level_3_mean, shared_std): - y = np.random.normal(level_3_mean, shared_std, size=10) - - return {"y": y} - - simulator = GraphicalSimulator() - simulator.add_node("level1", sample_fn=sample_level_1) - simulator.add_node( - "level2", - sample_fn=sample_level_2, - reps=10, - ) - simulator.add_node( - "level3", - sample_fn=sample_level_3, - reps=20, - ) - simulator.add_node("shared", sample_fn=sample_shared) - simulator.add_node("y", sample_fn=sample_y, reps=10) - - simulator.add_edge("level1", "level2") - simulator.add_edge("level2", "level3") - simulator.add_edge("level3", "y") - simulator.add_edge("shared", "y") - - return simulator diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/__init__.py b/bayesflow/experimental/graphical_simulator/example_simulators/__init__.py new file mode 100644 index 000000000..ea8e4607a --- /dev/null +++ b/bayesflow/experimental/graphical_simulator/example_simulators/__init__.py @@ -0,0 +1,4 @@ +from .single_level import single_level +from .two_level import two_level +from .two_level_repeated_roots import two_level_repeated_roots +from .irt import irt diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/irt.py b/bayesflow/experimental/graphical_simulator/example_simulators/irt.py new file mode 100644 index 000000000..21a64122d --- /dev/null +++ b/bayesflow/experimental/graphical_simulator/example_simulators/irt.py @@ -0,0 +1,87 @@ +import numpy as np + +from ..graphical_simulator import GraphicalSimulator + + +def irt(): + r""" + Item Response Theory (IRT) model implemented as a graphical simultor. 
+ + schools + / \ + exams students + | | + questions | + \ / + observations + """ + + # schools have different exam difficulties + def sample_school(): + mu_exam_mean = np.random.normal(loc=1.1, scale=0.2) + sigma_exam_mean = abs(np.random.normal(loc=0, scale=1)) + + # hierarchical mu/sigma for the exam difficulty standard deviation (logscale) + mu_exam_std = np.random.normal(loc=0.5, scale=0.3) + sigma_exam_std = abs(np.random.normal(loc=0, scale=0.5)) + + return dict( + mu_exam_mean=mu_exam_mean, + sigma_exam_mean=sigma_exam_mean, + mu_exam_std=mu_exam_std, + sigma_exam_std=sigma_exam_std, + ) + + # exams have different question difficulties + def sample_exam(mu_exam_mean, sigma_exam_mean, mu_exam_std, sigma_exam_std): + # mean question difficulty for an exam + exam_mean = np.random.normal(loc=mu_exam_mean, scale=sigma_exam_mean) + + # standard deviation of question difficulty + log_exam_std = np.random.normal(loc=mu_exam_std, scale=sigma_exam_std) + exam_std = float(np.exp(log_exam_std)) + + return dict(exam_mean=exam_mean, exam_std=exam_std) + + # realizations of individual question difficulties + def sample_question(exam_mean, exam_std): + question_difficulty = np.random.normal(loc=exam_mean, scale=exam_std) + + return dict(question_difficulty=question_difficulty) + + # realizations of individual student abilities + def sample_student(): + student_ability = np.random.normal(loc=0, scale=1) + + return dict(student_ability=student_ability) + + # realizations of individual observations + def sample_observation(question_difficulty, student_ability): + theta = np.exp(question_difficulty + student_ability) / (1 + np.exp(question_difficulty + student_ability)) + + obs = np.random.binomial(n=1, p=theta) + + return dict(obs=obs) + + def meta_fn(): + return { + "num_exams": np.random.randint(2, 4), + "num_questions": np.random.randint(10, 21), + "num_students": np.random.randint(100, 201), + } + + simulator = GraphicalSimulator(meta_fn=meta_fn) + + simulator.add_node("schools", sample_fn=sample_school) + simulator.add_node("exams", sample_fn=sample_exam, reps="num_exams") + simulator.add_node("questions", sample_fn=sample_question, reps="num_questions") + simulator.add_node("students", sample_fn=sample_student, reps="num_students") + simulator.add_node("observations", sample_fn=sample_observation) + + simulator.add_edge("schools", "exams") + simulator.add_edge("schools", "students") + simulator.add_edge("exams", "questions") + simulator.add_edge("questions", "observations") + simulator.add_edge("students", "observations") + + return simulator diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/single_level.py b/bayesflow/experimental/graphical_simulator/example_simulators/single_level.py new file mode 100644 index 000000000..49be680ae --- /dev/null +++ b/bayesflow/experimental/graphical_simulator/example_simulators/single_level.py @@ -0,0 +1,36 @@ +import numpy as np + +from ..graphical_simulator import GraphicalSimulator + + +def single_level(): + """ + Simple single-level simulator that implements the same model as in + https://bayesflow.org/main/_examples/Linear_Regression_Starter.html + """ + + def prior(): + beta = np.random.normal([2, 0], [3, 1]) + sigma = np.random.gamma(1, 1) + + return {"beta": beta, "sigma": sigma} + + def likelihood(beta, sigma, N): + x = np.random.normal(0, 1, size=N) + y = np.random.normal(beta[0] + beta[1] * x, sigma, size=N) + + return {"x": x, "y": y} + + def meta(): + N = np.random.randint(5, 15) + + return {"N": N} + + simulator = 
GraphicalSimulator(meta_fn=meta) + + simulator.add_node("prior", sample_fn=prior) + simulator.add_node("likelihood", sample_fn=likelihood) + + simulator.add_edge("prior", "likelihood") + + return simulator diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/two_level.py b/bayesflow/experimental/graphical_simulator/example_simulators/two_level.py new file mode 100644 index 000000000..296569068 --- /dev/null +++ b/bayesflow/experimental/graphical_simulator/example_simulators/two_level.py @@ -0,0 +1,57 @@ +import numpy as np + +from ..graphical_simulator import GraphicalSimulator + + +def two_level(): + r""" + Simple hierarchical model with two levels of parameters: hyperparameters + and local parameters, along with a shared parameter: + + hypers + | + locals shared + \ / + \ / + y + + """ + + def sample_hypers(): + hyper_mean = np.random.normal() + hyper_std = np.abs(np.random.normal()) + + return {"hyper_mean": hyper_mean, "hyper_std": hyper_std} + + def sample_locals(hyper_mean, hyper_std): + local_mean = np.random.normal(hyper_mean, hyper_std) + + return {"local_mean": local_mean} + + def sample_shared(): + shared_std = np.abs(np.random.normal()) + + return {"shared_std": shared_std} + + def sample_y(local_mean, shared_std): + y = np.random.normal(local_mean, shared_std) + + return {"y": y} + + simulator = GraphicalSimulator() + simulator.add_node("hypers", sample_fn=sample_hypers) + + simulator.add_node( + "locals", + sample_fn=sample_locals, + reps=6, + ) + + simulator.add_node("shared", sample_fn=sample_shared) + simulator.add_node("y", sample_fn=sample_y, reps=10) + + simulator.add_edge("hypers", "locals") + simulator.add_edge("locals", "y") + simulator.add_edge("shared", "y") + + return simulator diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/two_level_repeated_roots.py b/bayesflow/experimental/graphical_simulator/example_simulators/two_level_repeated_roots.py new file mode 100644 index 000000000..9b907156b --- /dev/null +++ b/bayesflow/experimental/graphical_simulator/example_simulators/two_level_repeated_roots.py @@ -0,0 +1,56 @@ +import numpy as np + +from ..graphical_simulator import GraphicalSimulator + + +def two_level_repeated_roots(): + r""" + Same as two_level(), but the hypers root node is repeated 5 times. 
+ + hypers + | + locals shared + \ / + \ / + y + + """ + + def sample_hypers(): + hyper_mean = np.random.normal() + hyper_std = np.abs(np.random.normal()) + + return {"hyper_mean": hyper_mean, "hyper_std": hyper_std} + + def sample_locals(hyper_mean, hyper_std): + local_mean = np.random.normal(hyper_mean, hyper_std) + + return {"local_mean": local_mean} + + def sample_shared(): + shared_std = np.abs(np.random.normal()) + + return {"shared_std": shared_std} + + def sample_y(local_mean, shared_std): + y = np.random.normal(local_mean, shared_std) + + return {"y": y} + + simulator = GraphicalSimulator() + simulator.add_node("hypers", sample_fn=sample_hypers, reps=5) + + simulator.add_node( + "locals", + sample_fn=sample_locals, + reps=6, + ) + + simulator.add_node("shared", sample_fn=sample_shared) + simulator.add_node("y", sample_fn=sample_y, reps=10) + + simulator.add_edge("hypers", "locals") + simulator.add_edge("locals", "y") + simulator.add_edge("shared", "y") + + return simulator diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index fa58c5e3a..17a251fcc 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -36,7 +36,7 @@ def add_edge(self, from_node: str, to_node: str): self.graph.add_edge(from_node, to_node) @allow_batch_size - def sample(self, batch_shape: Shape, **kwargs) -> dict[str, np.ndarray]: + def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]: """ Generates samples by topologically traversing the DAG. For each node, the sampling function is called based on parent values. From b75c5f5d9de81c94c712a65fc7997199da88c308 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 28 Jun 2025 17:20:37 +0200 Subject: [PATCH 12/20] add unit tests for single level graphical model --- tests/test_simulators/conftest.py | 7 +++++++ .../test_graphical_simulator.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) create mode 100644 tests/test_simulators/test_graphical_simulator.py diff --git a/tests/test_simulators/conftest.py b/tests/test_simulators/conftest.py index 7dcc22c12..410b08c84 100644 --- a/tests/test_simulators/conftest.py +++ b/tests/test_simulators/conftest.py @@ -247,3 +247,10 @@ def fixed_mu(): ) def simulator(request): return request.getfixturevalue(request.param) + + +@pytest.fixture() +def single_level_simulator(): + from bayesflow.experimental.graphical_simulator.example_simulators import single_level + + return single_level() diff --git a/tests/test_simulators/test_graphical_simulator.py b/tests/test_simulators/test_graphical_simulator.py new file mode 100644 index 000000000..6ec8c3e7e --- /dev/null +++ b/tests/test_simulators/test_graphical_simulator.py @@ -0,0 +1,18 @@ +import numpy as np + +import bayesflow as bf + + +def test_single_level_simulator(single_level_simulator): + assert isinstance(single_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator) + assert isinstance(single_level_simulator.sample(5), dict) + + samples = single_level_simulator.sample((12,)) + expected_keys = ["N", "beta", "sigma", "x", "y"] + + assert set(samples.keys()) == set(expected_keys) + assert 5 <= samples["N"] < 15 + assert np.shape(samples["beta"]) == (12, 2) # num_samples, beta_dim + assert np.shape(samples["sigma"]) == (12, 1) # num_samples, sigma_dim + assert np.shape(samples["x"]) 
== (12, samples["N"]) + assert np.shape(samples["y"]) == (12, samples["N"]) From 1243f8c9ae89667a4287b012bad05612b5132f0f Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sat, 28 Jun 2025 19:31:49 +0200 Subject: [PATCH 13/20] add unit tests for two_level and irt graphical simulators --- tests/test_simulators/conftest.py | 21 +++ .../test_graphical_simulator.py | 124 +++++++++++++++++- 2 files changed, 144 insertions(+), 1 deletion(-) diff --git a/tests/test_simulators/conftest.py b/tests/test_simulators/conftest.py index 410b08c84..5b6a8ca25 100644 --- a/tests/test_simulators/conftest.py +++ b/tests/test_simulators/conftest.py @@ -254,3 +254,24 @@ def single_level_simulator(): from bayesflow.experimental.graphical_simulator.example_simulators import single_level return single_level() + + +@pytest.fixture() +def two_level_simulator(): + from bayesflow.experimental.graphical_simulator.example_simulators import two_level + + return two_level() + + +@pytest.fixture() +def two_level_repeated_roots_simulator(): + from bayesflow.experimental.graphical_simulator.example_simulators import two_level_repeated_roots + + return two_level_repeated_roots() + + +@pytest.fixture() +def irt_simulator(): + from bayesflow.experimental.graphical_simulator.example_simulators import irt + + return irt() diff --git a/tests/test_simulators/test_graphical_simulator.py b/tests/test_simulators/test_graphical_simulator.py index 6ec8c3e7e..f3d711768 100644 --- a/tests/test_simulators/test_graphical_simulator.py +++ b/tests/test_simulators/test_graphical_simulator.py @@ -4,15 +4,137 @@ def test_single_level_simulator(single_level_simulator): + # prior -> likelihood assert isinstance(single_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator) assert isinstance(single_level_simulator.sample(5), dict) - samples = single_level_simulator.sample((12,)) + samples = single_level_simulator.sample(12) expected_keys = ["N", "beta", "sigma", "x", "y"] assert set(samples.keys()) == set(expected_keys) assert 5 <= samples["N"] < 15 + + # prior node assert np.shape(samples["beta"]) == (12, 2) # num_samples, beta_dim assert np.shape(samples["sigma"]) == (12, 1) # num_samples, sigma_dim + + # likelihood node assert np.shape(samples["x"]) == (12, samples["N"]) assert np.shape(samples["y"]) == (12, samples["N"]) + + +def test_two_level_simulator(two_level_simulator): + # hypers + # | + # locals shared + # \ / + # \ / + # y + + assert isinstance(two_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator) + assert isinstance(two_level_simulator.sample(5), dict) + + samples = two_level_simulator.sample(15) + expected_keys = ["hyper_mean", "hyper_std", "local_mean", "shared_std", "y"] + + assert set(samples.keys()) == set(expected_keys) + + # hypers node + assert np.shape(samples["hyper_mean"]) == (15, 1) + assert np.shape(samples["hyper_std"]) == (15, 1) + + # locals node + assert np.shape(samples["local_mean"]) == (15, 6, 1) + + # shared node + assert np.shape(samples["shared_std"]) == (15, 1) + + # y node + assert np.shape(samples["y"]) == (15, 6, 10, 1) + + +def test_two_level_repeated_roots_simulator(two_level_repeated_roots_simulator): + # hypers + # | + # locals shared + # \ / + # \ / + # y + + simulator = two_level_repeated_roots_simulator + assert isinstance(simulator, bf.experimental.graphical_simulator.GraphicalSimulator) + assert isinstance(simulator.sample(5), dict) + + samples = simulator.sample(15) + expected_keys = ["hyper_mean", 
"hyper_std", "local_mean", "shared_std", "y"] + + assert set(samples.keys()) == set(expected_keys) + + # hypers node + assert np.shape(samples["hyper_mean"]) == (15, 5, 1) + assert np.shape(samples["hyper_std"]) == (15, 5, 1) + + # locals node + assert np.shape(samples["local_mean"]) == (15, 5, 6, 1) + + # shared node + assert np.shape(samples["shared_std"]) == (15, 1) + + # y node + assert np.shape(samples["y"]) == (15, 5, 6, 10, 1) + + +def test_irt_simulator(irt_simulator): + # schools + # / \ + # exams students + # | | + # questions | + # \ / + # observations + + assert isinstance(irt_simulator, bf.experimental.graphical_simulator.GraphicalSimulator) + assert isinstance(irt_simulator.sample(5), dict) + + samples = irt_simulator.sample(22) + expected_keys = [ + "mu_exam_mean", + "sigma_exam_mean", + "mu_exam_std", + "sigma_exam_std", + "exam_mean", + "exam_std", + "question_difficulty", + "student_ability", + "obs", + "num_exams", # np.random.randint(2, 4) + "num_questions", # np.random.randint(10, 21) + "num_students", # np.random.randint(100, 201) + ] + + assert set(samples.keys()) == set(expected_keys) + + # schools node + assert np.shape(samples["mu_exam_mean"]) == (22, 1) + assert np.shape(samples["sigma_exam_mean"]) == (22, 1) + assert np.shape(samples["mu_exam_std"]) == (22, 1) + assert np.shape(samples["sigma_exam_std"]) == (22, 1) + + # exams node + assert np.shape(samples["exam_mean"]) == (22, samples["num_exams"], 1) + assert np.shape(samples["exam_std"]) == (22, samples["num_exams"], 1) + + # questions node + assert np.shape(samples["question_difficulty"]) == (22, samples["num_exams"], samples["num_questions"], 1) + + # students node + assert np.shape(samples["student_ability"]) == (22, samples["num_students"], 1) + + # observations node + assert np.shape(samples["obs"]) == ( + 22, + samples["num_exams"], + samples["num_students"], + samples["num_questions"], + 1, + ) From 55b6dfd5e23fe1454f3f5ac394c310d129c2d7bd Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sun, 29 Jun 2025 14:00:30 +0200 Subject: [PATCH 14/20] rename GraphicalSimulator._call_sampling_fn to _call_sample_fn to reflect renamed sample method argument --- .../graphical_simulator/graphical_simulator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index 17a251fcc..ec457c020 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -69,7 +69,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]: if not parent_nodes: # root node: generate independent samples node_samples = [ - {"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sampling_fn(sampling_fn, {}) + {"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sample_fn(sampling_fn, {}) for i in range(1, reps + 1) ] else: @@ -86,7 +86,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]: [ index_entries | {f"__{node}_idx": i} - | self._call_sampling_fn(sampling_fn, sampling_fn_input) + | self._call_sample_fn(sampling_fn, sampling_fn_input) for i in range(1, reps + 1) ] ) @@ -169,12 +169,12 @@ def _output_shape(self, samples, variable): return tuple(output_shape) - def _call_sampling_fn(self, sampling_fn, args): - signature = inspect.signature(sampling_fn) + def 
_call_sample_fn(self, sample_fn, args): + signature = inspect.signature(sample_fn) fn_args = signature.parameters accepted_args = {k: v for k, v in args.items() if k in fn_args} - return sampling_fn(**accepted_args) + return sample_fn(**accepted_args) def sorted_ancestors(graph, node): From 4035c169f4f378fc7d7df0a4e4e03f91dd63338a Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sun, 29 Jun 2025 14:27:15 +0200 Subject: [PATCH 15/20] rename examples in graphical_simulator.example_simulators --- .../example_simulators/__init__.py | 7 +-- ...irt.py => crossed_design_irt_simulator.py} | 4 +- ...gle_level.py => single_level_simulator.py} | 2 +- .../two_level_repeated_roots.py | 56 ------------------- .../{two_level.py => two_level_simulator.py} | 12 +++- tests/test_simulators/conftest.py | 18 +++--- .../test_graphical_simulator.py | 24 ++++---- 7 files changed, 39 insertions(+), 84 deletions(-) rename bayesflow/experimental/graphical_simulator/example_simulators/{irt.py => crossed_design_irt_simulator.py} (98%) rename bayesflow/experimental/graphical_simulator/example_simulators/{single_level.py => single_level_simulator.py} (96%) delete mode 100644 bayesflow/experimental/graphical_simulator/example_simulators/two_level_repeated_roots.py rename bayesflow/experimental/graphical_simulator/example_simulators/{two_level.py => two_level_simulator.py} (81%) diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/__init__.py b/bayesflow/experimental/graphical_simulator/example_simulators/__init__.py index ea8e4607a..05caded82 100644 --- a/bayesflow/experimental/graphical_simulator/example_simulators/__init__.py +++ b/bayesflow/experimental/graphical_simulator/example_simulators/__init__.py @@ -1,4 +1,3 @@ -from .single_level import single_level -from .two_level import two_level -from .two_level_repeated_roots import two_level_repeated_roots -from .irt import irt +from .single_level_simulator import single_level_simulator +from .two_level_simulator import two_level_simulator +from .crossed_design_irt_simulator import crossed_design_irt_simulator diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/irt.py b/bayesflow/experimental/graphical_simulator/example_simulators/crossed_design_irt_simulator.py similarity index 98% rename from bayesflow/experimental/graphical_simulator/example_simulators/irt.py rename to bayesflow/experimental/graphical_simulator/example_simulators/crossed_design_irt_simulator.py index 21a64122d..70fa8ae12 100644 --- a/bayesflow/experimental/graphical_simulator/example_simulators/irt.py +++ b/bayesflow/experimental/graphical_simulator/example_simulators/crossed_design_irt_simulator.py @@ -3,9 +3,9 @@ from ..graphical_simulator import GraphicalSimulator -def irt(): +def crossed_design_irt_simulator(): r""" - Item Response Theory (IRT) model implemented as a graphical simultor. + Item Response Theory (IRT) model implemented as a graphical simulator. 
schools / \ diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/single_level.py b/bayesflow/experimental/graphical_simulator/example_simulators/single_level_simulator.py similarity index 96% rename from bayesflow/experimental/graphical_simulator/example_simulators/single_level.py rename to bayesflow/experimental/graphical_simulator/example_simulators/single_level_simulator.py index 49be680ae..af26920eb 100644 --- a/bayesflow/experimental/graphical_simulator/example_simulators/single_level.py +++ b/bayesflow/experimental/graphical_simulator/example_simulators/single_level_simulator.py @@ -3,7 +3,7 @@ from ..graphical_simulator import GraphicalSimulator -def single_level(): +def single_level_simulator(): """ Simple single-level simulator that implements the same model as in https://bayesflow.org/main/_examples/Linear_Regression_Starter.html diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/two_level_repeated_roots.py b/bayesflow/experimental/graphical_simulator/example_simulators/two_level_repeated_roots.py deleted file mode 100644 index 9b907156b..000000000 --- a/bayesflow/experimental/graphical_simulator/example_simulators/two_level_repeated_roots.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np - -from ..graphical_simulator import GraphicalSimulator - - -def two_level_repeated_roots(): - r""" - Same as two_level(), but the hypers root node is repeated 5 times. - - hypers - | - locals shared - \ / - \ / - y - - """ - - def sample_hypers(): - hyper_mean = np.random.normal() - hyper_std = np.abs(np.random.normal()) - - return {"hyper_mean": hyper_mean, "hyper_std": hyper_std} - - def sample_locals(hyper_mean, hyper_std): - local_mean = np.random.normal(hyper_mean, hyper_std) - - return {"local_mean": local_mean} - - def sample_shared(): - shared_std = np.abs(np.random.normal()) - - return {"shared_std": shared_std} - - def sample_y(local_mean, shared_std): - y = np.random.normal(local_mean, shared_std) - - return {"y": y} - - simulator = GraphicalSimulator() - simulator.add_node("hypers", sample_fn=sample_hypers, reps=5) - - simulator.add_node( - "locals", - sample_fn=sample_locals, - reps=6, - ) - - simulator.add_node("shared", sample_fn=sample_shared) - simulator.add_node("y", sample_fn=sample_y, reps=10) - - simulator.add_edge("hypers", "locals") - simulator.add_edge("locals", "y") - simulator.add_edge("shared", "y") - - return simulator diff --git a/bayesflow/experimental/graphical_simulator/example_simulators/two_level.py b/bayesflow/experimental/graphical_simulator/example_simulators/two_level_simulator.py similarity index 81% rename from bayesflow/experimental/graphical_simulator/example_simulators/two_level.py rename to bayesflow/experimental/graphical_simulator/example_simulators/two_level_simulator.py index 296569068..45bdc149e 100644 --- a/bayesflow/experimental/graphical_simulator/example_simulators/two_level.py +++ b/bayesflow/experimental/graphical_simulator/example_simulators/two_level_simulator.py @@ -3,7 +3,7 @@ from ..graphical_simulator import GraphicalSimulator -def two_level(): +def two_level_simulator(repeated_roots=False): r""" Simple hierarchical model with two levels of parameters: hyperparameters and local parameters, along with a shared parameter: @@ -15,6 +15,10 @@ def two_level(): \ / y + Parameters + ---------- + repeated_roots : bool, default false. 
+ """ def sample_hypers(): @@ -39,7 +43,11 @@ def sample_y(local_mean, shared_std): return {"y": y} simulator = GraphicalSimulator() - simulator.add_node("hypers", sample_fn=sample_hypers) + + if not repeated_roots: + simulator.add_node("hypers", sample_fn=sample_hypers) + else: + simulator.add_node("hypers", sample_fn=sample_hypers, reps=5) simulator.add_node( "locals", diff --git a/tests/test_simulators/conftest.py b/tests/test_simulators/conftest.py index 5b6a8ca25..9a0ae09f2 100644 --- a/tests/test_simulators/conftest.py +++ b/tests/test_simulators/conftest.py @@ -251,27 +251,27 @@ def simulator(request): @pytest.fixture() def single_level_simulator(): - from bayesflow.experimental.graphical_simulator.example_simulators import single_level + from bayesflow.experimental.graphical_simulator.example_simulators import single_level_simulator - return single_level() + return single_level_simulator() @pytest.fixture() def two_level_simulator(): - from bayesflow.experimental.graphical_simulator.example_simulators import two_level + from bayesflow.experimental.graphical_simulator.example_simulators import two_level_simulator - return two_level() + return two_level_simulator() @pytest.fixture() def two_level_repeated_roots_simulator(): - from bayesflow.experimental.graphical_simulator.example_simulators import two_level_repeated_roots + from bayesflow.experimental.graphical_simulator.example_simulators import two_level_simulator - return two_level_repeated_roots() + return two_level_simulator(repeated_roots=True) @pytest.fixture() -def irt_simulator(): - from bayesflow.experimental.graphical_simulator.example_simulators import irt +def crossed_design_irt_simulator(): + from bayesflow.experimental.graphical_simulator.example_simulators import crossed_design_irt_simulator - return irt() + return crossed_design_irt_simulator() diff --git a/tests/test_simulators/test_graphical_simulator.py b/tests/test_simulators/test_graphical_simulator.py index f3d711768..0707af71b 100644 --- a/tests/test_simulators/test_graphical_simulator.py +++ b/tests/test_simulators/test_graphical_simulator.py @@ -5,10 +5,12 @@ def test_single_level_simulator(single_level_simulator): # prior -> likelihood - assert isinstance(single_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator) - assert isinstance(single_level_simulator.sample(5), dict) - samples = single_level_simulator.sample(12) + simulator = single_level_simulator + assert isinstance(simulator, bf.experimental.graphical_simulator.GraphicalSimulator) + assert isinstance(simulator.sample(5), dict) + + samples = simulator.sample(12) expected_keys = ["N", "beta", "sigma", "x", "y"] assert set(samples.keys()) == set(expected_keys) @@ -31,10 +33,11 @@ def test_two_level_simulator(two_level_simulator): # \ / # y - assert isinstance(two_level_simulator, bf.experimental.graphical_simulator.GraphicalSimulator) - assert isinstance(two_level_simulator.sample(5), dict) + simulator = two_level_simulator + assert isinstance(simulator, bf.experimental.graphical_simulator.GraphicalSimulator) + assert isinstance(simulator.sample(5), dict) - samples = two_level_simulator.sample(15) + samples = simulator.sample(15) expected_keys = ["hyper_mean", "hyper_std", "local_mean", "shared_std", "y"] assert set(samples.keys()) == set(expected_keys) @@ -84,7 +87,7 @@ def test_two_level_repeated_roots_simulator(two_level_repeated_roots_simulator): assert np.shape(samples["y"]) == (15, 5, 6, 10, 1) -def test_irt_simulator(irt_simulator): +def 
test_crossed_design_irt_simulator(crossed_design_irt_simulator): # schools # / \ # exams students @@ -93,10 +96,11 @@ def test_irt_simulator(irt_simulator): # \ / # observations - assert isinstance(irt_simulator, bf.experimental.graphical_simulator.GraphicalSimulator) - assert isinstance(irt_simulator.sample(5), dict) + simulator = crossed_design_irt_simulator + assert isinstance(simulator, bf.experimental.graphical_simulator.GraphicalSimulator) + assert isinstance(simulator.sample(5), dict) - samples = irt_simulator.sample(22) + samples = simulator.sample(22) expected_keys = [ "mu_exam_mean", "sigma_exam_mean", From b5a653c08881c82b83f135fe9129c2068e95a233 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sun, 29 Jun 2025 14:28:07 +0200 Subject: [PATCH 16/20] update description of **kwargs parameter in GraphicalSimulator.sample docstring --- .../experimental/graphical_simulator/graphical_simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index ec457c020..67df8bbe3 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -47,7 +47,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]: The shape of the batch to sample. Typically, a tuple indicating the number of samples, but an int can also be passed. **kwargs - Unused + Currently unused """ _ = kwargs # Simulator class requires **kwargs, which are unused here meta_dict = self.meta_fn() if self.meta_fn else {} From a3c1fb6d19d7565e3579d32507083ea0fa2afd7a Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Sun, 29 Jun 2025 14:28:52 +0200 Subject: [PATCH 17/20] remove unused is_root_node function --- .../experimental/graphical_simulator/graphical_simulator.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index 67df8bbe3..0a604d439 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -181,10 +181,6 @@ def sorted_ancestors(graph, node): return [n for n in nx.topological_sort(graph) if n in nx.ancestors(graph, node)] -def is_root_node(graph, node): - return len(list(graph.predecessors(node))) == 0 - - def merge_lists_of_dicts(nested_list: list[list[dict]]) -> list[dict]: """ Merges all combinations of dictionaries from a list of lists. 
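For orientation, a minimal usage sketch of the graph API exercised by the example simulators and tests in the patches above. The node names and sample functions below are invented for illustration; GraphicalSimulator is reached the same way the tests reach it, and the shape comment only restates the conventions the tests assert.

import bayesflow as bf
import numpy as np


def sample_prior():
    # root node: no parents, returns a dict of named variables
    return {"mu": float(np.random.normal())}


def sample_obs(mu):
    # child node: parent variables are passed to it by name
    return {"y": float(np.random.normal(mu))}


simulator = bf.experimental.graphical_simulator.GraphicalSimulator()
simulator.add_node("prior", sample_fn=sample_prior)       # reps defaults to 1
simulator.add_node("obs", sample_fn=sample_obs, reps=10)  # 10 draws per prior draw
simulator.add_edge("prior", "obs")

samples = simulator.sample(8)  # an int batch shape works, as in the tests
# following the shape conventions asserted in the tests: mu -> (8, 1), y -> (8, 10, 1)
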
From c5044c1dee4bf25815c994f94089ea3822d741d4 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Mon, 7 Jul 2025 00:07:35 +0200 Subject: [PATCH 18/20] use 0-based index for internal representation --- .../graphical_simulator/graphical_simulator.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index 0a604d439..1dedc12de 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -70,7 +70,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]: # root node: generate independent samples node_samples = [ {"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sample_fn(sampling_fn, {}) - for i in range(1, reps + 1) + for i in range(reps) ] else: # non-root node: depends on parent samples @@ -87,7 +87,7 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]: index_entries | {f"__{node}_idx": i} | self._call_sample_fn(sampling_fn, sampling_fn_input) - for i in range(1, reps + 1) + for i in range(reps) ] ) @@ -113,8 +113,8 @@ def _collect_output(self, samples): # build dict of node repetitions reps = {} for ancestor in ancestors: - reps[ancestor] = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) - reps[node] = max(s[f"__{node}_idx"] for s in samples.flat[0]) + reps[ancestor] = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) + 1 + reps[node] = max(s[f"__{node}_idx"] for s in samples.flat[0]) + 1 variable_names = self._variable_names(samples) @@ -130,11 +130,11 @@ def _collect_output(self, samples): # add index elements for ancestors for ancestor in ancestors: if reps[ancestor] != 1: - idx.append(sample[f"__{ancestor}_idx"] - 1) # -1 for 0-based indexing + idx.append(sample[f"__{ancestor}_idx"]) # add index elements for node if reps[node] != 1: - idx.append(sample[f"__{node}_idx"] - 1) # -1 for 0-based indexing + idx.append(sample[f"__{node}_idx"]) output_dict[variable][tuple(idx)] = sample[variable] @@ -154,12 +154,12 @@ def _output_shape(self, samples, variable): # add ancestor reps for ancestor in ancestors: - node_reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) + node_reps = max(s[f"__{ancestor}_idx"] for s in samples.flat[0]) + 1 if node_reps != 1: output_shape.append(node_reps) # add node reps - node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0]) + node_reps = max(s[f"__{node}_idx"] for s in samples.flat[0]) + 1 if node_reps != 1: output_shape.append(node_reps) From 6f70cd274297e8ef14852dee2fea42ee087236f6 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Thu, 24 Jul 2025 14:51:04 +0200 Subject: [PATCH 19/20] rename sampling_fn variables to sample_fn, precompute topological sort, remove nesting in node_samples computation --- .../graphical_simulator.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index 1dedc12de..1e2854cd9 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -14,13 +14,13 @@ class GraphicalSimulator(Simulator): """ A graph-based simulator that 
generates samples by traversing a DAG - and calling user-defined sampling functions at each node. + and calling user-defined sample functions at each node. Parameters ---------- meta_fn : Optional[Callable[[], dict[str, Any]]] A callable that returns a dictionary of meta data. - This meta data can be used to dynamically vary the number of sampling repetitions (`reps`) + This meta data can be used to dynamically vary the number of sample repetitions (`reps`) for nodes added via `add_node`. """ @@ -39,7 +39,7 @@ def add_edge(self, from_node: str, to_node: str): def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]: """ Generates samples by topologically traversing the DAG. - For each node, the sampling function is called based on parent values. + For each node, the sample function is called based on parent values. Parameters ---------- @@ -57,19 +57,21 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]: for node in self.graph.nodes: samples_by_node[node] = np.empty(batch_shape, dtype="object") + ordered_nodes = list(nx.topological_sort(self.graph)) + for batch_idx in np.ndindex(batch_shape): - for node in nx.topological_sort(self.graph): + for node in ordered_nodes: node_samples = [] parent_nodes = list(self.graph.predecessors(node)) - sampling_fn = self.graph.nodes[node]["sample_fn"] + sample_fn = self.graph.nodes[node]["sample_fn"] reps_field = self.graph.nodes[node]["reps"] reps = reps_field if isinstance(reps_field, int) else meta_dict[reps_field] if not parent_nodes: # root node: generate independent samples node_samples = [ - {"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sample_fn(sampling_fn, {}) + {"__batch_idx": batch_idx, f"__{node}_idx": i} | self._call_sample_fn(sample_fn, {}) for i in range(reps) ] else: @@ -81,15 +83,12 @@ def sample(self, batch_shape: Shape | int, **kwargs) -> dict[str, np.ndarray]: index_entries = {k: v for k, v in merged.items() if k.startswith("__")} variable_entries = {k: v for k, v in merged.items() if not k.startswith("__")} - sampling_fn_input = variable_entries | meta_dict - node_samples.extend( - [ - index_entries - | {f"__{node}_idx": i} - | self._call_sample_fn(sampling_fn, sampling_fn_input) - for i in range(reps) - ] - ) + sample_fn_input = variable_entries | meta_dict + samples = [ + index_entries | {f"__{node}_idx": i} | self._call_sample_fn(sample_fn, sample_fn_input) + for i in range(reps) + ] + node_samples.extend(samples) samples_by_node[node][batch_idx] = node_samples From e7331d6e21b9419ec53244b95b1f5c1bb9b267f8 Mon Sep 17 00:00:00 2001 From: Daniel Habermann <133031176+daniel-habermann@users.noreply.github.com> Date: Thu, 24 Jul 2025 15:07:55 +0200 Subject: [PATCH 20/20] [*batch_shape] to list(batch_shape) --- .../experimental/graphical_simulator/graphical_simulator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/experimental/graphical_simulator/graphical_simulator.py b/bayesflow/experimental/graphical_simulator/graphical_simulator.py index 1e2854cd9..22e5b8c43 100644 --- a/bayesflow/experimental/graphical_simulator/graphical_simulator.py +++ b/bayesflow/experimental/graphical_simulator/graphical_simulator.py @@ -148,7 +148,7 @@ def _output_shape(self, samples, variable): # start with batch shape batch_shape = samples.shape - output_shape = [*batch_shape] + output_shape = list(batch_shape) ancestors = sorted_ancestors(self.graph, node) # add ancestor reps
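
The meta_fn hook documented on GraphicalSimulator lets a string-valued reps be looked up in the dictionary returned by meta_fn at the start of each sample() call, and that dictionary is also merged into the inputs of non-root sample functions; this is the pattern the crossed-design IRT test above exercises via num_exams, num_questions, and num_students. A minimal sketch under the same assumptions as the earlier sketch, with purely illustrative names:

import bayesflow as bf
import numpy as np


def meta():
    # drawn once per sample() call; its keys can serve as string-valued `reps`
    return {"num_groups": int(np.random.randint(2, 5))}


def sample_group():
    return {"group_mean": float(np.random.normal())}


def sample_y(group_mean):
    # _call_sample_fn filters inputs by signature, so only group_mean reaches this function
    return {"y": float(np.random.normal(group_mean))}


simulator = bf.experimental.graphical_simulator.GraphicalSimulator(meta_fn=meta)
simulator.add_node("groups", sample_fn=sample_group, reps="num_groups")  # resolved from meta()
simulator.add_node("y", sample_fn=sample_y, reps=10)
simulator.add_edge("groups", "y")

samples = simulator.sample(4)
# by the conventions of the crossed-design IRT test, one would expect
# group_mean -> (4, samples["num_groups"], 1) and y -> (4, samples["num_groups"], 10, 1)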