Make diffusion model conditioning more flexible #521
base: dev
Conversation
Codecov Report: Attention: Patch coverage is
Thanks for the PR, I think this is a reasonable idea for advanced use cases. As this is another instance of multi-input networks (even though it's inside the inference network this time, and does not involve the adapter), we might want to include this in the discussion in #517. We might also think about:
Tagging @LarsKue for comment as well.
One very general approach would be to break free from the fixed names, such as "inference_variables", and actually allow for:
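A minimal sketch of what such free-form names could look like (all function names here are hypothetical and only illustrate the idea, they are not the actual BayesFlow API):

```python
import numpy as np

def call_subnet(subnet, tensors):
    """Forward an arbitrary dict of named tensors to the subnet, instead of
    relying on fixed keys such as 'inference_variables'."""
    return subnet(**tensors)

# toy subnet that accepts arbitrary keyword inputs and concatenates them
def toy_subnet(**kw):
    return np.concatenate([kw[k] for k in sorted(kw)], axis=-1)

tensors = {"theta": np.zeros((2, 3)), "context": np.ones((2, 5))}
out = call_subnet(toy_subnet, tensors)
print(out.shape)  # (2, 8)
```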
I agree, @vpratz. I added the same logic to both models. I am open to suggestions on how the tensors should be passed to the subnet. Regarding @stefanradev93's comment: I think this flexibility is only needed by advanced users, so maybe we should not follow this general approach for now, as the fixed names help users get started with BayesFlow.
Absolutely, this is definitely a >2.1.x idea.
For now, is it okay as it is, @stefanradev93?
```diff
@@ -197,6 +200,35 @@ def convert_prediction_to_x(
         return (z + sigma_t**2 * pred) / alpha_t
     raise ValueError(f"Unknown prediction type {self._prediction_type}.")

 def _subnet_input(
```
I think the name and docstring do not match the behavior: they do not indicate that the input is already passed through the subnet and that the subnet's output is returned. Only assembling the inputs, without passing them through, might be the option that is easier to read. What do you think?
I fixed the docstring and the name. I think only assembling the inputs makes the rest of the code harder to read, and this way it is the same for the different free-form approximators.
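The two designs under discussion can be contrasted with a small sketch (helper names hypothetical, a toy stand-in for the subnet):

```python
import numpy as np

def _assemble_subnet_input(x, t, conditions):
    # option (a): only gather the inputs; every call site must then
    # apply the subnet itself
    return np.concatenate([x, t, conditions], axis=-1)

def _apply_subnet(subnet, x, t, conditions):
    # option (b): assemble and apply in one step, so the call looks
    # identical across the different free-form approximators
    return subnet(_assemble_subnet_input(x, t, conditions))

subnet = lambda h: h * 2.0  # stand-in for an MLP
x = np.ones((4, 3))         # noisy parameters
t = np.ones((4, 1))         # diffusion time
c = np.ones((4, 2))         # conditions

out = _apply_subnet(subnet, x, t, c)
print(out.shape)  # (4, 6)
```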
I think we need to document somewhere how this can be used (i.e., which inputs are passed to the network in each case). It would also be good to have a test in place for the new behavior.
Thanks @vpratz for the suggestions. I added the documentation. Regarding the test: as we do not have a network at the moment that can handle multiple inputs, I do not know of a useful test for this case.
@arrjon Thanks for the changes!
I introduced a new keyword, `concatenated_input`, to the `subnet_kwargs` in the diffusion model. This keyword controls how inputs, such as parameters, noise, and condition, are fed into the model's subnet. Previously, the model assumed all inputs were 1D vectors and concatenated them directly for the default MLP subnet. However, for more flexible architectures, such as subnets designed to preserve or induce spatial structure, this assumption doesn't hold. Now all inputs can also be passed separately to the subnet.