Add keep_conditions argument to continuous_approximator.sample #523


Closed
wants to merge 7 commits into from
21 changes: 20 additions & 1 deletion bayesflow/approximators/continuous_approximator.py
@@ -422,6 +422,7 @@ def sample(
        num_samples: int,
        conditions: Mapping[str, np.ndarray],
        split: bool = False,
        keep_conditions: bool = False,
        **kwargs,
    ) -> dict[str, np.ndarray]:
"""
Expand All @@ -437,13 +438,22 @@ def sample(
split : bool, default=False
Whether to split the output arrays along the last axis and return one column vector per target variable
samples.
keep_conditions : bool, default=False
If True, the returned dict will include each of the original
conditioning variables, **repeated** along the sample axis so that
they align 1:1 with the generated samples. Each condition array
will have shape ``(num_datasets, num_samples, *condition_variable_shape)``.

By default conditions are not included in the returned dict.
**kwargs : dict
Additional keyword arguments for the adapter and sampling process.

Returns
-------
dict[str, np.ndarray]
Dictionary containing generated samples with the same keys as `conditions`.
Dictionary containing generated samples and optionally the corresponding conditions.

Dictionary values are arrays of shape ``(num_datasets, num_samples, *variable_shape)``.
"""
        # Adapt, optionally standardize and convert conditions to tensor
        conditions = self._prepare_data(conditions, **kwargs)
@@ -465,6 +475,15 @@

        if split:
            samples = split_arrays(samples, axis=-1)

        if keep_conditions:
            conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
Collaborator
The conditions that are repeated and kept are not the original ones, but the prepared (self._prepare_data) and adapted ones. Is this intentional? What are the downstream functions that this should be the input for, and which input would they expect?

Collaborator

As the same code is used below, do you think it would make sense to move this into a helper function?

Collaborator Author (@han-ol, Jun 28, 2025)

Thanks for taking a look!

Yes, exactly. My thinking was that this way we can avoid the difficulty of conditions that are not (batch_size, ...) tensors but rather non-batchable conditions. One example is in the linear regression notebook, where the sample size is just an integer that goes into the conditions.

When we repeat the output of self._prepare_data, we don't have to deal with the distinction between what needs to be repeated and what does not.
Repeating the prepared conditions is also what happens in self._sample to generate the conditions passed along to self.inference_network. This might help with robustness against special adapter transformations.
However, the current implementation relies on the adapter inverse being faithful in all special cases, and I am not sure of that at the moment.

Do you think this is unsafe in some way? I need to think about it some more.
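In plain NumPy, the repeat-after-prepare idea looks roughly like this (a sketch; the dict keys and shapes are illustrative, not actual BayesFlow internals):

```python
import numpy as np

num_datasets, num_samples = 3, 5

# Hypothetical output of self._prepare_data: every condition, including an
# originally non-batchable sample size, now carries a leading batch axis.
conditions = {
    "x": np.random.rand(num_datasets, 2),
    "n": np.full((num_datasets, 1), 10.0),
}

# Repeat each condition along a new sample axis so it aligns 1:1 with
# samples of shape (num_datasets, num_samples, ...).
repeated = {
    key: np.repeat(np.expand_dims(value, axis=1), num_samples, axis=1)
    for key, value in conditions.items()
}

assert repeated["x"].shape == (num_datasets, num_samples, 2)
```

Because the repetition happens after preparation, no per-key guessing about batchability is needed.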

Collaborator Author

Downstream functions would include, for example, approximator.log_prob, but also any other post-processing users might want to do with samples from the joint posterior predictive.

Collaborator

Thanks for the explanation, I missed the inverse=True in the adapter call and misinterpreted what happens. I'm still a bit confused, though.

What would be your current assumption regarding the adapter calls? That conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs) gives the same output as the conditions initially passed by the user? Or that there is a difference (i.e., the adapter is not perfectly invertible) that we try to exploit?

Contributor

I am a bit confused as well. I think users should have the choice of whether they want the conditions output before or after the adapter call. Likely, before the adapter call will have more use cases.

Collaborator Author (@han-ol, Jun 30, 2025)

I admit it's a bit confusing, sorry! I definitely agree, we always want the output to be the repeated pre-adapter conditions. The question is only how to get there.

Two options:

  1. repeat all array values of condition dictionaries while leaving non-batched conditions intact.
  2. apply the adapter to conditions (which broadcasts the non-batched conditions to batch_shape), then apply the inverse adapter. This way only dict values that are repeated in approximator._sample survive, which can then be repeated safely.

The second sounds more complicated, but it parallels how conditions are actually repeated and passed to the inference network while sampling, and avoids having to guess what needs to be repeated and what does not.

The implementation I have proposed represents option 2.
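For contrast, option 1 could be sketched as a naive heuristic like the following (`naive_repeat` is a hypothetical helper, not part of this PR):

```python
import numpy as np

def naive_repeat(conditions, num_samples):
    """Option 1 sketch: repeat array-valued conditions along a new sample
    axis, and pass everything else (e.g. a plain-int sample size) through
    unchanged."""
    repeated = {}
    for key, value in conditions.items():
        if isinstance(value, np.ndarray):
            repeated[key] = np.repeat(np.expand_dims(value, axis=1), num_samples, axis=1)
        else:
            repeated[key] = value  # non-batchable condition, left intact
    return repeated

conditions = {"x": np.ones((2, 3)), "N": 100}
repeated = naive_repeat(conditions, num_samples=4)
assert repeated["x"].shape == (2, 4, 3)  # (num_datasets, num_samples, ...)
assert repeated["N"] == 100
```

The weakness of this heuristic is exactly the first uncertainty below: the is-it-an-array check may guess wrong for exotic adapter inputs.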

My uncertainties in choosing which option is better are

  1. how hard is it to reliably guess what needs to be repeated? Are there special adapters where the naive guess (just not repeating anything that isn't an array) would fail?
  2. is the inverse adapter sufficiently reliable?

EDIT: I am now also confused. I am misrepresenting the order of what the code actually does.
EDIT2: I fixed the order.

Collaborator Author

After talking directly, Valentin and I think the feature to keep conditions in a repeated way cannot be provided to users in general. Ignoring the flexibility of real adapters is bound to produce hard-to-debug problems for users.
Both of the uncertainties I mentioned have negative answers, so both options are inadequate.

Uncertainty 1 (how hard is it to guess which conditions need to be repeated): input to adapters can be too diverse to reliably guess how to repeat them. This takes option 1 off the table.

Uncertainty 2 (is the inverse adapter cycle-consistent in general): We can think of adapters that fail this requirement by losing information in the forward pass. The chain

$$\text{repeat} \circ \text{inverse-adapter} \circ \text{adapter}$$

then produces incorrect repeated conditions.
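A toy example of such a lossy adapter (both functions are hypothetical, just to illustrate the failure mode):

```python
import numpy as np

def adapter_forward(raw):
    # Lossy forward pass: two raw conditions are merged into one feature.
    return {"summary": raw["a"] + raw["b"]}

def adapter_inverse(adapted):
    # Best-effort inverse: the individual contributions cannot be recovered,
    # so "b" is filled with zeros.
    return {"a": adapted["summary"], "b": np.zeros_like(adapted["summary"])}

raw = {"a": np.array([1.0, 2.0]), "b": np.array([3.0, 4.0])}
roundtrip = adapter_inverse(adapter_forward(raw))

# The cycle is not consistent: repeating `roundtrip` would hand the user
# conditions that differ from what they originally passed in.
assert not np.allclose(roundtrip["b"], raw["b"])
```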

            conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
            repeated_conditions = keras.tree.map_structure(
                lambda value: np.repeat(np.expand_dims(value, axis=1), num_samples, axis=1), conditions
            )
            samples = repeated_conditions | samples

        return samples

    def _prepare_data(
21 changes: 19 additions & 2 deletions bayesflow/approximators/point_approximator.py
@@ -89,6 +89,7 @@ def sample(
        num_samples: int,
        conditions: Mapping[str, np.ndarray],
        split: bool = False,
        keep_conditions: bool = False,
        **kwargs,
    ) -> dict[str, dict[str, np.ndarray]]:
"""
Expand All @@ -107,6 +108,14 @@ def sample(
split : bool, optional
If True, the sampled arrays are split along the last axis, by default False.
Currently not supported for :py:class:`PointApproximator` .
keep_conditions : bool, default=False
If True, the returned dict will include each of the original
conditioning variables, **repeated** along the sample axis so that
they align 1:1 with the generated samples. Each condition array
will have shape ``(num_datasets, num_samples, *condition_variable_shape)``.

By default conditions are not included in the returned dict.

**kwargs
Additional keyword arguments passed to underlying processing functions.

Expand All @@ -115,11 +124,11 @@ def sample(
samples : dict[str, np.ndarray or dict[str, np.ndarray]]
Samples for all inference variables and all parametric scoring rules in a nested dictionary.

1. Each first-level key is the name of an inference variable.
1. Each first-level key is the name of an inference variable or condition.
2. (If there are multiple parametric scores, each second-level key is the name of such a score.)

Each output (i.e., dictionary value that is not itself a dictionary) is an array
of shape (num_datasets, num_samples, variable_block_size).
of shape ``(num_datasets, num_samples, *variable_shape)``.
"""
        # Adapt, optionally standardize and convert conditions to tensor.
        conditions = self._prepare_data(conditions, **kwargs)
@@ -141,6 +150,14 @@
        # Squeeze sample dictionary if there's only one key-value pair.
        samples = self._squeeze_parametric_score_major_dict(samples)

        if keep_conditions:
            conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
            conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs)
            repeated_conditions = keras.tree.map_structure(
                lambda value: np.repeat(np.expand_dims(value, axis=1), num_samples, axis=1), conditions
            )
            samples = repeated_conditions | samples

        return samples

    def log_prob(
4 changes: 0 additions & 4 deletions tests/test_approximators/test_fit.py
@@ -3,7 +3,6 @@
import pytest
import io
from contextlib import redirect_stdout
from tests.utils import check_approximator_multivariate_normal_score


@pytest.mark.skip(reason="not implemented")
@@ -20,9 +19,6 @@ def test_fit(amortizer, dataset):


def test_loss_progress(approximator, train_dataset, validation_dataset):
    # as long as MultivariateNormalScore is unstable, skip fit progress test
    check_approximator_multivariate_normal_score(approximator)

    approximator.compile(optimizer="AdamW")
    num_epochs = 3

4 changes: 1 addition & 3 deletions tests/test_approximators/test_log_prob.py
@@ -1,12 +1,10 @@
import keras
import numpy as np
from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
from tests.utils import check_combination_simulator_adapter


def test_approximator_log_prob(approximator, simulator, batch_size, adapter):
    check_combination_simulator_adapter(simulator, adapter)
    # as long as MultivariateNormalScore is unstable, skip
    check_approximator_multivariate_normal_score(approximator)

    num_batches = 4
    data = simulator.sample((num_batches * batch_size,))
41 changes: 38 additions & 3 deletions tests/test_approximators/test_sample.py
@@ -1,11 +1,10 @@
import numpy as np
import keras
from tests.utils import check_combination_simulator_adapter, check_approximator_multivariate_normal_score
from tests.utils import check_combination_simulator_adapter


def test_approximator_sample(approximator, simulator, batch_size, adapter):
    check_combination_simulator_adapter(simulator, adapter)
    # as long as MultivariateNormalScore is unstable, skip
    check_approximator_multivariate_normal_score(approximator)

    num_batches = 4
    data = simulator.sample((num_batches * batch_size,))
@@ -18,3 +17,39 @@ def test_approximator_sample(approximator, simulator, batch_size, adapter):
    samples = approximator.sample(num_samples=2, conditions=data)

    assert isinstance(samples, dict)


def test_approximator_sample_keep_conditions(approximator, simulator, batch_size, adapter):
    check_combination_simulator_adapter(simulator, adapter)

    num_batches = 4
    data = simulator.sample((num_batches * batch_size,))

    batch = adapter(data)
    batch = keras.tree.map_structure(keras.ops.convert_to_tensor, batch)
    batch_shapes = keras.tree.map_structure(keras.ops.shape, batch)
    approximator.build(batch_shapes)

    num_samples = 2
    samples_and_conditions = approximator.sample(num_samples=num_samples, conditions=data, keep_conditions=True)

    assert isinstance(samples_and_conditions, dict)

    # remove inference_variables from sample output and apply adapter
    inference_variables_keys = approximator.sample(num_samples=num_samples, conditions=data).keys()
    for key in inference_variables_keys:
        samples_and_conditions.pop(key)
    adapted_conditions = adapter(samples_and_conditions, strict=False)

    assert any(k in adapted_conditions for k in approximator.CONDITION_KEYS), (
        f"adapter(approximator.sample(..., keep_conditions=True)) must return at least one of "
        f"{approximator.CONDITION_KEYS!r}. Keys are {adapted_conditions.keys()}."
    )

    for key, value in adapted_conditions.items():
        assert value.shape[:2] == (num_batches * batch_size, num_samples), (
            f"{key} should have shape ({num_batches * batch_size}, {num_samples}, ...) but has {value.shape}."
        )

        if key in approximator.CONDITION_KEYS:
            assert np.all(np.ptp(value, axis=1) == 0), "Not all values are the same along axis 1"