Add keep_conditions argument to continuous_approximator.sample #523

Closed
wants to merge 7 commits

Conversation

@han-ol (Collaborator) commented Jun 27, 2025

To compute quantities that depend on both conditions and inference variables, it is useful to keep the repeated conditions corresponding to the samples generated by an approximator.

For convenience, I propose adding a keep_conditions flag to the approximators that repeats the conditions and returns them along with the newly generated samples.

To-Do:

  • add keep_conditions for ContinuousApproximator.sample
  • add keep_conditions for PointApproximator.sample
  • tests
  • add keep_conditions for ModelComparisonApproximator.sample (removed, see EDIT below)

EDIT: ModelComparisonApproximator doesn't have sample, so I removed that to-do item.


codecov bot commented Jun 27, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

| Files with missing lines | Coverage | Δ |
|---|---|---|
| bayesflow/approximators/continuous_approximator.py | 90.64% <100.00%> | +0.23% ⬆️ |
| bayesflow/approximators/point_approximator.py | 94.68% <100.00%> | +0.29% ⬆️ |

... and 1 file with indirect coverage changes

@han-ol han-ol self-assigned this Jun 28, 2025
@han-ol (Collaborator, Author) commented Jun 28, 2025

OK, after your review and any resulting changes, this should be ready to merge.

The convenience this PR provides is that instead of the following tricky commands

```python
posterior_samples = workflow.sample(conditions=val_sims, num_samples=num_post_samples)

axis = 1
rep_val_sims = keras.tree.map_structure(
    lambda tensor: np.repeat(np.expand_dims(tensor, axis=axis), num_post_samples, axis=axis),
    val_sims,
)
posterior_samples_with_conditions = rep_val_sims | posterior_samples
```

users can directly get

```python
posterior_samples_with_conditions = workflow.sample(
    conditions=val_sims, num_samples=num_post_samples, keep_conditions=True
)
```

Both result in shapes like this:

```python
{'parameters': (64, 100, 2), 'observables': (64, 100, 160, 1)}
```

Please see docstrings for a precise description of keep_conditions.
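With keep_conditions=True, a quantity that depends on both conditions and inference variables can then be computed per posterior draw without any manual broadcasting. The reduction below is only a placeholder; the key names and shapes follow the example above:

```python
samples = workflow.sample(conditions=val_sims, num_samples=num_post_samples, keep_conditions=True)

# 'parameters': (64, 100, 2), 'observables': (64, 100, 160, 1)
obs_mean = samples["observables"].mean(axis=2)                    # (64, 100, 1)
per_draw_stat = obs_mean[..., 0] - samples["parameters"][..., 0]  # (64, 100)
```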

@han-ol han-ol requested review from stefanradev93 and vpratz and removed request for stefanradev93 June 28, 2025 14:51
@han-ol han-ol marked this pull request as ready for review June 28, 2025 14:52
@han-ol (Collaborator, Author) commented Jun 28, 2025

When writing tests for this, I came across a few test skip instructions for multivariate normal approximators that can be re-enabled, so I added that change to this PR.

```
@@ -465,6 +475,15 @@ def sample(

        if split:
            samples = split_arrays(samples, axis=-1)

        if keep_conditions:
            conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
```
Collaborator

The conditions that are repeated and kept are not the original ones, but the prepared (self._prepare_data) and adapted ones. Is this intentional? What are the downstream functions this should serve as input for, and what input would they expect?

Collaborator

As the same code is used below, do you think it would make sense to move this into a helper function?
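A minimal sketch of what such a helper could look like, folding the numpy conversion together with the repetition (the name and exact signature are hypothetical):

```python
import keras
import numpy as np


def _repeat_conditions_as_numpy(conditions: dict, num_samples: int, axis: int = 1) -> dict:
    """Convert prepared conditions to numpy and repeat them along a new sample axis."""
    conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
    return keras.tree.map_structure(
        lambda tensor: np.repeat(np.expand_dims(tensor, axis=axis), num_samples, axis=axis),
        conditions,
    )
```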

@han-ol (Collaborator, Author) commented Jun 28, 2025

Thanks for taking a look!

Yes, exactly. My thinking was that this way we can avoid the difficulty of conditions that are not (batch_size, ...) tensors but rather non-batchable conditions. One example of this is the linear regression notebook, where the sample size is just an integer that goes into the conditions.

When we repeat the output of self._prepare_data, we don't have to deal with the distinction between what needs to be repeated and what does not.
Repeating the prepared conditions is also what happens in self._sample to generate the conditions passed along to self.inference_network. This might help with robustness against special adapter transformations.
However, the current implementation relies on the adapter inverse being faithful in all special cases, and I am not sure of that at the moment.
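Roughly, as pseudocode rather than the actual diff (method signatures simplified), the flow is:

```python
# pseudocode sketch of the proposed flow
conditions = self._prepare_data(conditions)              # adapter applied, non-batched values broadcast
samples = self._sample(num_samples=num_samples, **conditions)

if keep_conditions:
    conditions = keras.tree.map_structure(keras.ops.convert_to_numpy, conditions)
    conditions = self.adapter(conditions, inverse=True, strict=False)  # back to user-facing variables
    conditions = keras.tree.map_structure(
        lambda t: np.repeat(np.expand_dims(t, axis=1), num_samples, axis=1), conditions
    )
    samples = conditions | samples
```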

Do you think this is unsafe in some way? I need to think about it some more.

@han-ol (Collaborator, Author)

The downstream functions would be, for example, approximator.log_prob, but also any other post-processing users might want to do with samples from the joint posterior predictive.

Collaborator

Thanks for the explanation; I missed the inverse=True in the adapter call and misinterpreted what happens. I'm still a bit confused, though.

What would be your current assumption regarding the adapter calls? That conditions = self.adapter(conditions, inverse=True, strict=False, **kwargs) gives the same output as the conditions initially passed by the user? Or that there is a difference (i.e., the adapter is not perfectly invertible) that we try to exploit?

Contributor

I am a bit confused as well. I think users should have the choice of whether the returned conditions are from before or after the adapter call. Likely, the pre-adapter conditions will cover more use cases.

@han-ol (Collaborator, Author) commented Jun 30, 2025

I admit it's a bit confusing, sorry! I definitely agree that we always want the output to be the repeated pre-adapter conditions. The question is only how to get there.

Two options:

  1. Repeat all array values of the condition dictionaries while leaving non-batched conditions intact (a sketch of this naive approach follows at the end of this comment).
  2. Apply the adapter to the conditions (which broadcasts the non-batched conditions to batch_shape), then apply the inverse adapter. This way, only the dict values that are repeated in approximator._sample survive, and these can then be repeated safely.

The second sounds more complicated, but it parallels how conditions are actually repeated and passed to the inference network during sampling, and it avoids having to guess what needs to be repeated and what does not.

The implementation I have proposed represents option 2.

My uncertainties in choosing which option is better are:

  1. How hard is it to reliably guess what needs to be repeated? Are there special adapters where the naive guess (just not repeating anything that isn't an array) would fail?
  2. Is the inverse adapter sufficiently reliable?

EDIT: I am now also confused. I am misrepresenting the order of what the code actually does.
EDIT2: I fixed the order.
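For concreteness, a sketch of the naive guess behind option 1 (illustrative only; real condition structures can be nested and more varied than a flat dict):

```python
import numpy as np


def naive_repeat(conditions: dict, num_samples: int, axis: int = 1) -> dict:
    """Option 1 sketch: repeat only array-valued conditions, pass everything else through."""
    out = {}
    for key, value in conditions.items():
        if isinstance(value, np.ndarray):
            out[key] = np.repeat(np.expand_dims(value, axis=axis), num_samples, axis=axis)
        else:
            out[key] = value  # e.g., an integer sample size stays untouched
    return out
```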

@han-ol (Collaborator, Author)

After talking directly, Valentin and I think the feature of keeping conditions in repeated form cannot be provided to users in general. Ignoring the flexibility of real adapters is bound to produce hard-to-debug problems for users.
Both of the uncertainties I mentioned have negative answers, so both options are inadequate.

Uncertainty 1 (how hard is it to guess which conditions need to be repeated): the input to adapters can be too diverse to reliably guess how to repeat it. This takes option 1 off the table.

Uncertainty 2 (is the inverse adapter cycle-consistent in general): We can think of adapters that fail this requirement by essentially losing information in the forward pass. The chain

$$\text{repeat} \circ \text{inverse-adapter} \circ \text{adapter}$$

then produces incorrect repeated conditions.
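A toy illustration of such an information-losing adapter (purely illustrative, not one of bayesflow's actual transforms):

```python
import numpy as np


def adapter_forward(conditions: dict) -> dict:
    # the forward pass keeps only a summary and discards the raw observations
    return {"summary": conditions["x"].mean(axis=-1, keepdims=True)}


def adapter_inverse(adapted: dict) -> dict:
    # the mean cannot be undone; the best the inverse can do is return the summary
    return {"x": adapted["summary"]}


x = np.random.normal(size=(4, 10))
roundtrip = adapter_inverse(adapter_forward({"x": x}))
print(roundtrip["x"].shape)  # (4, 1) instead of (4, 10): repeating this yields incorrect conditions
```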

@han-ol (Collaborator, Author) commented Jun 30, 2025

As discussed in the comments (#523 (comment)), the proposed feature cannot be delivered for arbitrary adapters, and I will thus close the PR tomorrow.

How to proceed:

To get a similar level of convenience without depending on a reliable inverse adapter, we can solve this inside log_prob.
The adapted conditions can be repeated inside log_prob whenever the inference variables have a sample dimension.
Since the inverse adapter does not need to be applied to the conditions, this circumvents the inherent difficulty that prevents this PR from working out.
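A rough sketch of that idea, assuming numpy arrays and the (batch, num_samples, ...) layout used above (the helper name is hypothetical):

```python
import numpy as np


def repeat_adapted_conditions(conditions: dict, inference_variables: np.ndarray, axis: int = 1) -> dict:
    """Repeat already-adapted conditions along the sample dimension of the inference variables,
    so that conditions and samples can be evaluated jointly, e.g. inside log_prob."""
    num_samples = inference_variables.shape[axis]
    return {
        key: np.repeat(np.expand_dims(value, axis=axis), num_samples, axis=axis)
        for key, value in conditions.items()
    }
```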

@paul-buerkner (Contributor)

Thank you for thinking about it more thoroughly. I agree with your assessment. I believe such a feature would be (down the line) better suited for packages that build on bayesflow and support a restricted class of models with known variable structure, for which such a feature could be reliably implemented.

@paul-buerkner (Contributor)

One comment about log_prob: I would not solve this inside log_prob for now, until we know that this is actually a sensible diagnostic. That is, we would need to repeat some of the experiments in Modrák with log_prob to see whether it behaves similarly to the (analytic) joint log-likelihood. Only if the results are encouraging should we implement it somewhat natively (because an implementation from us implies that we encourage users to use it).
