
Make diffusion model conditioning more flexible #521


Open · wants to merge 10 commits into base: dev

Conversation

@arrjon (Member) commented Jun 24, 2025

I introduced a new keyword `concatenated_input` in the `subnet_kwargs` of the diffusion model. This keyword controls how inputs, such as parameters, noise, and conditions, are fed into the model's subnet.

Previously, the model assumed all inputs were 1D vectors and concatenated them directly for the default MLP subnet. However, for more flexible architectures, such as subnets designed to preserve or induce spatial structure, this assumption doesn't hold. Now we can instead pass all inputs separately to the subnet.
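
Conceptually, the switch behaves like the following sketch (the helper name is hypothetical, not the actual implementation):

```python
import keras

def assemble_subnet_input(xz, time, conditions, concatenate=True):
    """Sketch of the two input modes controlled by the new keyword."""
    if concatenate:
        # previous/default behavior: flatten everything into one vector
        # for the default MLP subnet
        return keras.ops.concatenate([xz, time, conditions], axis=-1)
    # new behavior: hand the tensors over separately, e.g. for subnets
    # that preserve or induce spatial structure
    return xz, time, conditions
```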

arrjon requested review from stefanradev93 and vpratz on Jun 24, 2025, 13:41
arrjon self-assigned this on Jun 24, 2025

@vpratz (Collaborator) commented Jun 24, 2025

Thanks for the PR; I think this is a reasonable idea for advanced use cases. As this is another instance of multi-input networks (even though it's inside the inference network this time and does not involve the adapter), we might want to include this in the discussion in #517. We might also think about:

  • how to pass inputs: by name (as a dictionary) or by position (as a tuple, like you propose in the PR); see the sketch below
  • consistency across our multi-step models: if we offer this possibility here, we might want to do the same in flow matching and consistency models
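
For illustration, the dictionary-style contract could look like this (a hedged sketch; the class and argument names are made up):

```python
import numpy as np
import keras

class NamedInputSubnet(keras.Model):
    """Toy subnet that takes its inputs by name instead of by position."""

    def __init__(self, width=8, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(width)

    def call(self, xz, time, conditions):
        # names make the contract explicit; no ordering convention needed
        return self.dense(keras.ops.concatenate([xz, time, conditions], axis=-1))

net = NamedInputSubnet()
out = net(
    xz=np.zeros((2, 4), dtype="float32"),
    time=np.zeros((2, 1), dtype="float32"),
    conditions=np.zeros((2, 3), dtype="float32"),
)
# the positional variant would instead document a fixed (xz, time, conditions) order
```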

Tagging @LarsKue for comment as well.

@stefanradev93 (Contributor) commented Jun 24, 2025

One very general approach would be to break free from the fixed names, such as "inference_variables", and actually allow for:

  • marking different simulator outputs as target variables, summary variables, or inference conditions
  • selecting a strategy for how different outputs of the same type are handled (e.g., concatenated, passed as a tuple, or passed as keyword arguments)

This can be handled with another abstraction, such as SimulatorOutput with a flexible scheme; a rough sketch follows below.
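
A rough sketch of what such an abstraction could look like (purely illustrative; `SimulatorOutput` and its fields are hypothetical):

```python
from dataclasses import dataclass
from typing import Literal

@dataclass
class SimulatorOutput:
    """Marks a simulator output and how same-type outputs are combined."""
    name: str
    role: Literal["target", "summary", "condition"]
    strategy: Literal["concatenate", "tuple", "kwargs"] = "concatenate"

outputs = [
    SimulatorOutput("theta", role="target"),
    SimulatorOutput("x", role="summary"),
    SimulatorOutput("design", role="condition", strategy="kwargs"),
]
```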

@arrjon (Member, Author) commented Jun 24, 2025

> Consistency across our multi-step models: if we offer this possibility here, we might want to do the same in flow matching and consistency models.

I agree @vpratz. I added the same logic to both models. I am open to suggestions on how the tensors should be passed to the subnet.

Regarding @stefanradev93's comment: I think this flexibility is only needed by advanced users, so maybe we should not follow this general approach for now, as the fixed names help users get started with BayesFlow.

@stefanradev93 (Contributor) commented
> > Consistency across our multi-step models: if we offer this possibility here, we might want to do the same in flow matching and consistency models.
>
> I agree @vpratz. I added the same logic to both models. I am open to suggestions on how the tensors should be passed to the subnet.
>
> Regarding @stefanradev93's comment: I think this flexibility is only needed by advanced users, so maybe we should not follow this general approach for now, as the fixed names help users get started with BayesFlow.

Absolutely, this is definitely an idea for a later 2.x release.

@arrjon (Member, Author) commented Jul 8, 2025

Is it okay as it is for now, @stefanradev93?

```diff
@@ -197,6 +200,35 @@ def convert_prediction_to_x(
            return (z + sigma_t**2 * pred) / alpha_t
        raise ValueError(f"Unknown prediction type {self._prediction_type}.")

    def _subnet_input(
```
Collaborator:
I think the name and docstring do not match the behavior: they do not indicate that the input is already passed through the subnet and that its output is returned. Only assembling the inputs, without applying the subnet, might be the option that is easier to read. What do you think?

Member Author:
I fixed the docstring and the name. I think only assembling would make the rest of the code harder to read, and this way it is the same as for the different free-form approximators.
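
To make the two options concrete (a hedged sketch; the names and signatures are illustrative):

```python
# Option A: only assemble the inputs; every caller applies the subnet itself.
def _subnet_input(self, xz, time, conditions):
    ...  # return the assembled tensor(s)

# Option B (assemble and apply in one step, returning the subnet output),
# which keeps the call sites short and mirrors the free-form approximators:
def _apply_subnet(self, xz, time, conditions, training=False):
    ...  # return self.subnet(...)
```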

@vpratz (Collaborator) commented Jul 8, 2025

I think we need to document somewhere how this can be used (i.e., which inputs are passed to the network if `concatenate_subnet_input` is `False`), as this currently only exists in the code itself. I'd suggest we pass the inputs by name, as this eases communication (only names, no order).

It would be good to have a test in place for the `concatenate_subnet_input=False` case.
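
A sketch of what that documentation could state on the subnet side (hypothetical signature; the argument names are assumptions):

```python
def call(self, xz, time, conditions=None, training=False):
    """Subnet interface used when `concatenate_subnet_input` is False.

    Instead of one concatenated vector, the subnet receives its inputs
    separately and by name:
        xz         : the (noisy) target variables
        time       : the diffusion time / noise level
        conditions : optional conditioning variables, or None
    """
    ...
```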

@arrjon (Member, Author) commented Jul 9, 2025

Thanks @vpratz for the suggestions. I added the documentation.

Regarding the test: as we do not currently have a network which can handle multiple inputs, I do not know a useful test for the `concatenate_subnet_input=False` case. Any suggestions?

@vpratz (Collaborator) commented Jul 11, 2025

@arrjon Thanks for the changes!
I would propose adding a simple wrapper network that can handle this case to the test suite, similar to other dummy networks we have. It could just take the separate inputs, concatenate them, and pass them to any other network as usual. The main point for me is that we are able to notice if we accidentally break the functionality, so a basic dedicated test should be sufficient.
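
For example (a hedged sketch; the class name and input names are made up), such a dummy network could look like:

```python
import keras

class ConcatenateSubnetWrapper(keras.Model):
    """Dummy test network: takes the separate inputs, concatenates them,
    and delegates to an ordinary single-input network."""

    def __init__(self, inner, **kwargs):
        super().__init__(**kwargs)
        self.inner = inner

    def call(self, xz, time, conditions=None, training=False):
        tensors = [xz, time] if conditions is None else [xz, time, conditions]
        return self.inner(keras.ops.concatenate(tensors, axis=-1), training=training)

# in a test, this makes any standard MLP usable with
# concatenate_subnet_input=False, so accidental breakage would be noticed
```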
