Make diffusion model conditioning more flexible #521

Merged · 24 commits · Jul 24, 2025
Changes from 17 commits
77 changes: 60 additions & 17 deletions bayesflow/networks/consistency_models/consistency_model.py
@@ -4,7 +4,7 @@
import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import find_network, layer_kwargs, weighted_mean
from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
from bayesflow.utils.serialization import deserialize, serializable, serialize

from ..inference_network import InferenceNetwork
@@ -67,6 +67,11 @@ def __init__(
Final number of discretization steps
subnet_kwargs: dict[str, any], optional
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
concatenate_subnet_input: bool, optional
Flag for advanced users to control whether all inputs to the subnet should be concatenated
into a single vector or passed as separate arguments. If set to False, the subnet
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
and optional 'conditions'. Default is True.
**kwargs : dict, optional, default: {}
Additional keyword arguments
"""
@@ -77,6 +82,7 @@
subnet_kwargs = subnet_kwargs or {}
if subnet == "mlp":
subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

self.subnet = find_network(subnet, **subnet_kwargs)
self.output_projector = keras.layers.Dense(
@@ -119,6 +125,7 @@ def get_config(self):
"eps": self.eps,
"s0": self.s0,
"s1": self.s1,
"concatenate_subnet_input": self._concatenate_subnet_input,
# we do not need to store subnet_kwargs
}

@@ -161,18 +168,29 @@ def build(self, xz_shape, conditions_shape=None):

input_shape = list(xz_shape)

# time vector
input_shape[-1] += 1
if self._concatenate_subnet_input:
# construct time vector
input_shape[-1] += 1
if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]
input_shape = tuple(input_shape)

if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]
self.subnet.build(input_shape)
out_shape = self.subnet.compute_output_shape(input_shape)
else:
# Multiple separate inputs
main_input_shape = xz_shape
time_shape = xz_shape[:-1] + (1,) # same batch/sequence dims, 1 feature

input_shape = tuple(input_shape)
# Build subnet with multiple input shapes
input_shape = [main_input_shape, time_shape]
if conditions_shape is not None:
input_shape.append(conditions_shape)

self.subnet.build(input_shape)
self.subnet.build(input_shape) # Pass list of shapes
out_shape = self.subnet.compute_output_shape(input_shape)

input_shape = self.subnet.compute_output_shape(input_shape)
self.output_projector.build(input_shape)
self.output_projector.build(out_shape)

# Choose coefficient according to [2] Section 3.3
self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
@@ -256,6 +274,35 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
x = self.consistency_function(x_n, t, conditions=conditions, training=training)
return x

def _apply_subnet(
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
) -> Tensor:
"""
Prepares the input and passes it to the subnet, either by concatenating the latent variable `x`,
the time `t`, and optional conditions into a single tensor, or by passing them as separate arguments.

Parameters
----------
x : Tensor
The parameter tensor, typically of shape (..., D), but can vary.
t : Tensor
The time tensor, typically of shape (..., 1).
conditions : Tensor, optional
The optional conditioning tensor (e.g. parameters).
training : bool, optional
The training mode flag, which can be used to control behavior during training.

Returns
-------
Tensor
The output tensor from the subnet.
"""
if self._concatenate_subnet_input:
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
return self.subnet(xtc, training=training)
else:
return self.subnet(x=x, t=t, conditions=conditions, training=training)

def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
"""Compute consistency function.

@@ -271,12 +318,8 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
Whether internal layers (e.g., dropout) should behave in train or inference mode.
"""

if conditions is not None:
xtc = ops.concatenate([x, t, conditions], axis=-1)
else:
xtc = ops.concatenate([x, t], axis=-1)

f = self.output_projector(self.subnet(xtc, training=training))
subnet_out = self._apply_subnet(x, t, conditions, training=training)
f = self.output_projector(subnet_out)

# Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
# Thus, we can do a cross product with the time vector which is (batch_size, 1) for
@@ -316,8 +359,8 @@ def compute_metrics(

log_p = ops.log(p)
times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
t1 = ops.take(discretized_time, times)[..., None]
t2 = ops.take(discretized_time, times + 1)[..., None]
t1 = expand_right_as(ops.take(discretized_time, times), x)
t2 = expand_right_as(ops.take(discretized_time, times + 1), x)

# generate noise vector
noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
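For illustration, here is a minimal sketch (not part of this diff) of a custom subnet that satisfies the separate-input contract used when `concatenate_subnet_input=False`: it accepts `x`, `t`, and an optional `conditions` keyword, and its `build`/`compute_output_shape` handle the list of shapes constructed in the `build` method above. The class name, layer widths, activation, and summation-based fusion are illustrative assumptions, not BayesFlow API.

```python
import keras


class SeparateInputSubnet(keras.layers.Layer):
    """Illustrative subnet taking 'x', 't', and optional 'conditions' as separate arguments."""

    def __init__(self, width: int = 256, **kwargs):
        super().__init__(**kwargs)
        self.x_embed = keras.layers.Dense(width)
        self.t_embed = keras.layers.Dense(width)
        self.cond_embed = keras.layers.Dense(width)
        self.body = keras.layers.Dense(width, activation="silu")

    def build(self, input_shape):
        # input_shape is the list built above: [xz_shape, time_shape] or
        # [xz_shape, time_shape, conditions_shape]
        self.x_embed.build(input_shape[0])
        self.t_embed.build(input_shape[1])
        if len(input_shape) > 2:
            self.cond_embed.build(input_shape[2])
        self.body.build(tuple(input_shape[0][:-1]) + (self.x_embed.units,))

    def compute_output_shape(self, input_shape):
        return tuple(input_shape[0][:-1]) + (self.body.units,)

    def call(self, x, t, conditions=None, training=False):
        # Fuse the separate embeddings by summation instead of concatenation.
        h = self.x_embed(x) + self.t_embed(t)
        if conditions is not None:
            h = h + self.cond_embed(conditions)
        return self.body(h)
```

Summation is just one possible fusion; the point of the flag is that the subnet is free to treat `x`, `t`, and `conditions` differently (e.g. FiLM-style modulation or cross-attention) instead of receiving a single concatenated vector.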
86 changes: 67 additions & 19 deletions bayesflow/networks/diffusion_model/diffusion_model.py
@@ -85,6 +85,12 @@
Additional keyword arguments passed to the noise schedule constructor. Default is None.
integrate_kwargs : dict[str, any], optional
Configuration dictionary for integration during training or inference. Default is None.
concatenate_subnet_input: bool, optional
Flag for advanced users to control whether all inputs to the subnet should be concatenated
into a single vector or passed as separate arguments. If set to False, the subnet
must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio),
and optional 'conditions'. Default is True.

**kwargs
Additional keyword arguments passed to the base class and internal components.
"""
@@ -116,6 +122,7 @@ def __init__(
if subnet == "mlp":
subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
self.subnet = find_network(subnet, **subnet_kwargs)
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")

@@ -128,15 +135,28 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
self.output_projector.units = xz_shape[-1]
input_shape = list(xz_shape)

# construct time vector
input_shape[-1] += 1
if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]
if self._concatenate_subnet_input:
# construct time vector
input_shape[-1] += 1
if conditions_shape is not None:
input_shape[-1] += conditions_shape[-1]
input_shape = tuple(input_shape)

self.subnet.build(input_shape)
out_shape = self.subnet.compute_output_shape(input_shape)
else:
# Multiple separate inputs
main_input_shape = xz_shape
time_shape = xz_shape[:-1] + (1,) # same batch/sequence dims, 1 feature

# Build subnet with multiple input shapes
input_shape = [main_input_shape, time_shape]
if conditions_shape is not None:
input_shape.append(conditions_shape)

input_shape = tuple(input_shape)
self.subnet.build(input_shape) # Pass list of shapes
out_shape = self.subnet.compute_output_shape(input_shape)

self.subnet.build(input_shape)
out_shape = self.subnet.compute_output_shape(input_shape)
self.output_projector.build(out_shape)

def get_config(self):
Expand All @@ -149,6 +169,8 @@ def get_config(self):
"prediction_type": self._prediction_type,
"loss_type": self._loss_type,
"integrate_kwargs": self.integrate_kwargs,
"concatenate_subnet_input": self._concatenate_subnet_input,
# we do not need to store subnet_kwargs
}
return base_config | serialize(config)

@@ -197,6 +219,35 @@ def convert_prediction_to_x(
return (z + sigma_t**2 * pred) / alpha_t
raise ValueError(f"Unknown prediction type {self._prediction_type}.")

def _apply_subnet(
self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None, training: bool = False
) -> Tensor:
"""
Prepares the input and passes it to the subnet, either by concatenating the noisy variable `xz`,
the log signal-to-noise ratio `log_snr`, and optional conditions into a single tensor, or by
passing them as separate arguments.

Parameters
----------
xz : Tensor
The noisy input tensor for the diffusion model, typically of shape (..., D), but can vary.
log_snr : Tensor
The log signal-to-noise ratio tensor, typically of shape (..., 1).
conditions : Tensor, optional
The optional conditioning tensor (e.g. parameters).
training : bool, optional
The training mode flag, which can be used to control behavior during training.

Returns
-------
Tensor
The output tensor from the subnet.
"""
if self._concatenate_subnet_input:
xtc = tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
return self.subnet(xtc, training=training)
else:
return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)

def velocity(
self,
xz: Tensor,
@@ -221,7 +272,7 @@ def velocity(
If True, computes the velocity for the stochastic formulation (SDE).
If False, uses the deterministic formulation (ODE).
conditions : Tensor, optional
Optional conditional inputs to the network, such as conditioning variables
Conditional inputs to the network, such as conditioning variables
or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
training : bool, optional
Whether the model is in training mode. Affects behavior of dropout, batch norm,
@@ -238,12 +289,10 @@
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

if conditions is None:
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t)], axis=-1)
else:
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)

pred = self.output_projector(self.subnet(xtc, training=training), training=training)
subnet_out = self._apply_subnet(
xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
)
pred = self.output_projector(subnet_out, training=training)

x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

@@ -461,11 +510,10 @@ def compute_metrics(
diffused_x = alpha_t * x + sigma_t * eps_t

# calculate output of the network
if conditions is None:
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
else:
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
pred = self.output_projector(self.subnet(xtc, training=training), training=training)
subnet_out = self._apply_subnet(
diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
)
pred = self.output_projector(subnet_out, training=training)

x_pred = self.convert_prediction_to_x(
pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
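A hedged usage sketch (also not from this diff): it assumes `DiffusionModel` is exported from `bayesflow.networks`, that `find_network` accepts a custom layer class for `subnet`, and it reuses the illustrative `SeparateInputSubnet` sketched after the first file above.

```python
from bayesflow.networks import DiffusionModel

# Hypothetical wiring: the flag is read from **kwargs in the constructor above,
# and _apply_subnet then passes the noisy variable as 'x', the transformed log
# signal-to-noise ratio as 't', and the conditions as 'conditions'.
diffusion_net = DiffusionModel(subnet=SeparateInputSubnet, concatenate_subnet_input=False)
```

The `ConsistencyModel` constructor reads the same flag from `**kwargs`, so the identical pattern applies there, with `t` being the time instead of the log SNR.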