Conversation

arrjon
Member

@arrjon arrjon commented Sep 8, 2025

This pull request introduces compositional sampling support to the BayesFlow framework, enabling diffusion models to handle multiple compositional conditions efficiently. The main changes span the continuous approximator, diffusion model, and inference network modules, adding new methods and refactoring existing ones to support compositional structures in sampling, inference, and diffusion processes.

Larger changes include:

  • Added a new compositional_sample method to ContinuousApproximator, which generates samples with compositional structure and handles flattening, reshaping, and prior score computation for multiple compositional conditions. A supporting internal method, _compositional_sample, was also introduced.
  • In DiffusionModel, implemented compositional diffusion support including:
    • New compositional_bridge and compositional_velocity methods for compositional score calculation (see the sketch after this list).
    • _compute_individual_scores helper for handling multiple compositional conditions.
    • _inverse_compositional method for inverse compositional diffusion sampling.
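
For context, the compositional score behind these methods follows from factorizing a posterior over multiple independent conditions: p(theta | y_1..n) is proportional to p(theta)^(1 - n) * prod_i p(theta | y_i). Below is a minimal NumPy sketch of this identity (function names are hypothetical and do not match the PR's internals; in diffusion time the identity only holds approximately, which is what the density bridge accounts for):

import numpy as np

def compositional_score(individual_scores, prior_score, n_conditions):
    # individual_scores: shape (n_conditions, ...), holding grad log p(theta | y_i)
    # prior_score: grad log p(theta), with the trailing shape of a single score
    # Taking log-gradients of the factorization above gives:
    # (1 - n) * prior score + sum of the individual scores.
    return (1.0 - n_conditions) * prior_score + np.sum(individual_scores, axis=0)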

The idea is that the workflow now has a compositional_sample method, which expects conditions of shape (n_datasets, n_conditions, ...). We can then perform compositional sampling with diffusion models.
compositional_sample allows setting a mini_batch_size for memory-efficient computation of the compositional score. This does not work with the jax backend, however, since jax does not like stochasticity in its integrators that cannot be precomputed. We could support only fixed step sizes here, though?
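
To make the jax limitation concrete: with a mini_batch_size, every integration step draws a fresh random subset of conditions and rescales the partial sum, roughly as in the following sketch (names hypothetical; score_fn stands in for one evaluation of the network). It is this per-step randomness that cannot be precomputed for jax:

import numpy as np

def minibatch_condition_scores(theta, conditions, score_fn, mini_batch_size):
    n = len(conditions)
    # fresh random subset of conditions at every integration step
    idx = np.random.choice(n, size=mini_batch_size, replace=False)
    partial_sum = sum(score_fn(theta, conditions[i]) for i in idx)
    # rescale so the subsample is an unbiased estimate of the full sum over conditions
    return partial_sum * (n / mini_batch_size)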

To compute the compositional score we need access to the score of the prior. Here we need to handle the adapter carefully so that we compute the correct score. In the current draft, I am not sure I computed the prior score correctly. Some ideas would be great: currently it fails for jax because the adapter converts things to numpy back and forth, but it works for torch.
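
As a concrete example of what compute_prior_score has to provide, and why the adapter matters: for a standard normal prior the score is analytic, but if the adapter standardizes the parameters before they reach the network, the score has to be transformed by the change of variables. A sketch assuming an affine adapter transform (all names hypothetical):

import numpy as np

def prior_score(theta):
    # analytic score of a standard normal prior: grad_theta log N(theta; 0, I) = -theta
    return -theta

def prior_score_in_network_space(z, mu, sigma):
    # if the adapter standardizes, z = (theta - mu) / sigma, the change of
    # variables gives grad_z log p(z) = sigma * grad_theta log p(theta)
    theta = mu + sigma * z
    return sigma * prior_score(theta)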

@arrjon arrjon self-assigned this Sep 8, 2025
@arrjon arrjon requested a review from stefanradev93 September 8, 2025 15:13

@stefanradev93
Contributor

Hi Jonas, this is great! One general design question that I would like to discuss is whether to add the new capabilities to the existing classes or to inherit from the existing classes and add the new methods there, e.g., as in CompositionalApproximator, CompositionalDiffusionModel, etc. The latter has the advantage that the existing interfaces remain more compact, but it introduces the need for new classes. @vpratz @paul-buerkner Since the interface is already working well (except for JAX), I think it's a good time to discuss.

@paul-buerkner
Contributor

Where can I see examples of its use, and how would it alternatively look if the structure were different?

@arrjon
Member Author

arrjon commented Sep 16, 2025

So at the moment, the compositional part is only relevant during inference. You train a diffusion model, and then you can do the following:

# training_data.shape = (n_datasets, ...), so no conditions
# sim_data.shape = (n_datasets, n_conditions, ...)

workflow.approximator.inference_network.integrate_kwargs = {
    'method': 'euler_maruyama',
    'steps': 200,
    'mini_batch_size': 2,  # how many conditions are used per step to estimate the compositional score
    'compositional_d1': 0.05,  # density bridge
}

posterior_samples = workflow.compositional_sample(
    num_samples=100,
    conditions={'sim_data': test_data_comp_trials['sim_data']},
    compute_prior_score=prior_score,
)
# posterior_samples.shape = (n_datasets, n_samples, n_parameters)

This implementation is based on the compositional approach described here.

Defining CompositionalApproximator and CompositionalDiffusionModel would essentially only change the code organization: you would have to specify the correct approximator already during training (even though there is no difference), or load the trained diffusion model into the CompositionalApproximator after training before you can do inference (which would mean defining a new workflow and loading the model, so not too difficult). A nice point about CompositionalDiffusionModel would be that we could set default settings which differ from the standard diffusion model, e.g., using a stochastic sampler rather than a deterministic one.
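
For the discussion, a minimal sketch of what the subclass variant might look like (purely illustrative; the import path and constructor signature are assumptions):

from bayesflow.networks import DiffusionModel  # import path assumed

class CompositionalDiffusionModel(DiffusionModel):
    # Hypothetical subclass: same network, but with default settings suited
    # to compositional sampling, e.g. a stochastic sampler by default.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.integrate_kwargs = {'method': 'euler_maruyama', 'steps': 200}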

@paul-buerkner
Contributor

Thank you. That makes sense. Is there any practical use case where we would want to use the same diffusion model for both standard and compositional sampling?

@arrjon
Member Author

arrjon commented Sep 16, 2025

I think the usual case is that you know from the beginning that you want a compositional model. Only in rare cases, where you get new data after you have trained a network, might you consider switching from standard diffusion to compositional diffusion.

However, at the moment a CompositionalApproximator might be overkill as we only have a single inference network suitable for this task.

@paul-buerkner
Contributor

So you would suggest a new diffusion model class but not a new approximator class. I would personally be fine with that. That said, is there anything else beyond the approximator.compositional_sample method that would have to be added to the approximator class? If not, I think keeping the existing approximator is fine and checking inside its compositional sampling method whether the employed inference network supports compositional sampling.
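
A rough sketch of that check (the attribute used for detection is an assumption; any marker of compositional support would do):

class ContinuousApproximator:
    # ... existing methods ...

    def compositional_sample(self, num_samples, conditions, compute_prior_score, **kwargs):
        # only proceed if the employed inference network implements
        # compositional sampling (currently only the diffusion model)
        if not hasattr(self.inference_network, '_inverse_compositional'):
            raise NotImplementedError(
                'The employed inference network does not support compositional sampling.'
            )
        return self._compositional_sample(num_samples, conditions, compute_prior_score, **kwargs)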

@stefanradev93
Contributor

@arrjon Can you post a minimal interface example for the latest version (model definition and sampling) to discuss with the others?
