diff --git a/bayesflow/approximators/continuous_approximator.py b/bayesflow/approximators/continuous_approximator.py index 46fe4bb0d..36001d835 100644 --- a/bayesflow/approximators/continuous_approximator.py +++ b/bayesflow/approximators/continuous_approximator.py @@ -422,6 +422,7 @@ def sample( num_samples: int, conditions: Mapping[str, np.ndarray], split: bool = False, + keep_conditions: bool = False, **kwargs, ) -> dict[str, np.ndarray]: """ @@ -437,13 +438,22 @@ def sample( split : bool, default=False Whether to split the output arrays along the last axis and return one column vector per target variable samples. + keep_conditions : bool, default=False + If True, the returned dict will include each of the original + conditioning variables, **repeated** along the sample axis so that + they align 1:1 with the generated samples. Each condition array + will have shape ``(num_datasets, num_samples, *condition_variable_shape)``. + + By default conditions are not included in the returned dict. **kwargs : dict Additional keyword arguments for the adapter and sampling process. Returns ------- dict[str, np.ndarray] - Dictionary containing generated samples with the same keys as `conditions`. + Dictionary containing generated samples and optionally the corresponding conditions. + + Dictionary values are arrays of shape ``(num_datasets, num_samples, *variable_shape)``. 
""" # Adapt, optionally standardize and convert conditions to tensor conditions = self._prepare_data(conditions, **kwargs) @@ -465,6 +475,15 @@ def sample( if split: samples = split_arrays(samples, axis=-1) + + if keep_conditions: + conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions) + conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs) + repeated_conditions = keras.tree.map_structure( + lambda value: np.repeat(np.expand_dims(value, axis=1), num_samples, axis=1), conditions + ) + samples = repeated_conditions | samples + return samples def _prepare_data( diff --git a/bayesflow/approximators/point_approximator.py b/bayesflow/approximators/point_approximator.py index 1318185d2..edca50211 100644 --- a/bayesflow/approximators/point_approximator.py +++ b/bayesflow/approximators/point_approximator.py @@ -89,6 +89,7 @@ def sample( num_samples: int, conditions: Mapping[str, np.ndarray], split: bool = False, + keep_conditions: bool = False, **kwargs, ) -> dict[str, dict[str, np.ndarray]]: """ @@ -107,6 +108,14 @@ def sample( split : bool, optional If True, the sampled arrays are split along the last axis, by default False. Currently not supported for :py:class:`PointApproximator` . + keep_conditions : bool, default=False + If True, the returned dict will include each of the original + conditioning variables, **repeated** along the sample axis so that + they align 1:1 with the generated samples. Each condition array + will have shape ``(num_datasets, num_samples, *condition_variable_shape)``. + + By default conditions are not included in the returned dict. + **kwargs Additional keyword arguments passed to underlying processing functions. @@ -115,11 +124,11 @@ def sample( samples : dict[str, np.ndarray or dict[str, np.ndarray]] Samples for all inference variables and all parametric scoring rules in a nested dictionary. - 1. Each first-level key is the name of an inference variable. + 1. 
Each first-level key is the name of an inference variable or condition. 2. (If there are multiple parametric scores, each second-level key is the name of such a score.) Each output (i.e., dictionary value that is not itself a dictionary) is an array - of shape (num_datasets, num_samples, variable_block_size). + of shape ``(num_datasets, num_samples, *variable_shape)``. """ # Adapt, optionally standardize and convert conditions to tensor. conditions = self._prepare_data(conditions, **kwargs) @@ -141,6 +150,14 @@ def sample( # Squeeze sample dictionary if there's only one key-value pair. samples = self._squeeze_parametric_score_major_dict(samples) + if keep_conditions: + conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions) + conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs) + repeated_conditions = keras.tree.map_structure( + lambda value: np.repeat(np.expand_dims(value, axis=1), num_samples, axis=1), conditions + ) + samples = repeated_conditions | samples + return samples def log_prob( diff --git a/tests/test_approximators/test_fit.py b/tests/test_approximators/test_fit.py index b561efb77..27d4716c4 100644 --- a/tests/test_approximators/test_fit.py +++ b/tests/test_approximators/test_fit.py @@ -3,7 +3,6 @@ import pytest import io from contextlib import redirect_stdout -from tests.utils import check_approximator_multivariate_normal_score @pytest.mark.skip(reason="not implemented") @@ -20,9 +19,6 @@ def test_fit(amortizer, dataset): def test_loss_progress(approximator, train_dataset, validation_dataset): - # as long as MultivariateNormalScore is unstable, skip fit progress test - check_approximator_multivariate_normal_score(approximator) - approximator.compile(optimizer="AdamW") num_epochs = 3 diff --git a/tests/test_approximators/test_log_prob.py b/tests/test_approximators/test_log_prob.py index 8cfbb2fe6..9c96cdeb6 100644 --- a/tests/test_approximators/test_log_prob.py +++ b/tests/test_approximators/test_log_prob.py 
@@ -1,12 +1,10 @@
 import keras
 import numpy as np
-from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
+from tests.utils import check_combination_simulator_adapter
 
 
 def test_approximator_log_prob(approximator, simulator, batch_size, adapter):
     check_combination_simulator_adapter(simulator, adapter)
-    # as long as MultivariateNormalScore is unstable, skip
-    check_approximator_multivariate_normal_score(approximator)
 
     num_batches = 4
     data = simulator.sample((num_batches * batch_size,))
diff --git a/tests/test_approximators/test_sample.py b/tests/test_approximators/test_sample.py
index d7c2a3bcf..639e6cf72 100644
--- a/tests/test_approximators/test_sample.py
+++ b/tests/test_approximators/test_sample.py
@@ -1,11 +1,10 @@
+import numpy as np
 import keras
-from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
+from tests.utils import check_combination_simulator_adapter
 
 
 def test_approximator_sample(approximator, simulator, batch_size, adapter):
     check_combination_simulator_adapter(simulator, adapter)
-    # as long as MultivariateNormalScore is unstable, skip
-    check_approximator_multivariate_normal_score(approximator)
 
     num_batches = 4
     data = simulator.sample((num_batches * batch_size,))
@@ -18,3 +17,41 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
     samples = approximator.sample(num_samples=2, conditions=data)
 
     assert isinstance(samples, dict)
+
+
+def test_approximator_sample_keep_conditions(approximator, simulator, batch_size, adapter):
+    """Check that ``keep_conditions=True`` returns conditions repeated along the sample axis."""
+    check_combination_simulator_adapter(simulator, adapter)
+
+    num_batches = 4
+    data = simulator.sample((num_batches * batch_size,))
+
+    batch = adapter(data)
+    batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
+    batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
+    approximator.build(batch_shapes)
+
+    num_samples = 2
+    samples_and_conditions = approximator.sample(num_samples=num_samples, conditions=data, keep_conditions=True)
+
+    assert isinstance(samples_and_conditions, dict)
+
+    # remove inference_variables from sample output and apply adapter
+    inference_variables_keys = approximator.sample(num_samples=num_samples, conditions=data).keys()
+    for key in inference_variables_keys:
+        samples_and_conditions.pop(key)
+    adapted_conditions = adapter(samples_and_conditions, strict=False)
+
+    assert any(k in adapted_conditions for k in approximator.CONDITION_KEYS), (
+        f"adapter(approximator.sample(..., keep_conditions=True)) must return at least one of "
+        f"{approximator.CONDITION_KEYS!r}. Keys are {adapted_conditions.keys()}."
+    )
+
+    for key, value in adapted_conditions.items():
+        assert value.shape[:2] == (num_batches * batch_size, num_samples), (
+            f"{key} should have shape ({num_batches * batch_size}, {num_samples}, ...) but has {value.shape}."
+        )
+
+        if key in approximator.CONDITION_KEYS:
+            # a peak-to-peak range of zero along axis 1 means every repeated copy is identical
+            assert np.all(np.ptp(value, axis=1) == 0), "Not all values are the same along axis 1"