Make diffusion model conditioning more flexible #521

Open · wants to merge 13 commits into dev
50 changes: 41 additions & 9 deletions bayesflow/networks/consistency_models/consistency_model.py
@@ -4,7 +4,7 @@
import numpy as np

from bayesflow.types import Tensor
from bayesflow.utils import find_network, layer_kwargs, weighted_mean
from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
from bayesflow.utils.serialization import deserialize, serializable, serialize

from ..inference_network import InferenceNetwork
@@ -67,6 +67,11 @@ def __init__(
Final number of discretization steps
subnet_kwargs: dict[str, any], optional
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
concatenate_subnet_input: bool, optional
Flag for advanced users to control whether all inputs to the subnet should be concatenated
into a single vector or passed as separate arguments. If set to False, the subnet
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
and optional 'conditions'. Default is True.
**kwargs : dict, optional, default: {}
Additional keyword arguments
"""
@@ -77,6 +82,7 @@ def __init__(
subnet_kwargs = subnet_kwargs or {}
if subnet == "mlp":
subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

self.subnet = find_network(subnet, **subnet_kwargs)
self.output_projector = keras.layers.Dense(
@@ -119,6 +125,7 @@ def get_config(self):
"eps": self.eps,
"s0": self.s0,
"s1": self.s1,
"concatenate_subnet_input": self._concatenate_subnet_input,
# we do not need to store subnet_kwargs
}

@@ -256,6 +263,35 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
x = self.consistency_function(x_n, t, conditions=conditions, training=training)
return x

def _apply_subnet(
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
"""
Prepares the input for the subnet, either by concatenating the latent variable `x`,
the time `t`, and optional conditions into a single tensor, or by passing them as separate arguments.

Parameters
----------
x : Tensor
The parameter tensor, typically of shape (..., D), but can vary.
t : Tensor
The time tensor, typically of shape (..., 1).
conditions : Tensor, optional
The optional conditioning tensor (e.g. parameters).
training : bool, optional
The training mode flag, which can be used to control behavior during training.

Returns
-------
Tensor
The output tensor from the subnet.
"""
if self._concatenate_subnet_input:
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
return self.subnet(xtc, training=training)
else:
return self.subnet(x=x, t=t, conditions=conditions, training=training)

def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
"""Compute consistency function.

@@ -271,12 +307,8 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
Whether internal layers (e.g., dropout) should behave in train or inference mode.
"""

if conditions is not None:
xtc = ops.concatenate([x, t, conditions], axis=-1)
else:
xtc = ops.concatenate([x, t], axis=-1)

f = self.output_projector(self.subnet(xtc, training=training))
subnet_out = self._apply_subnet(x, t, conditions, training=training)
f = self.output_projector(subnet_out)

# Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
# Thus, we can do a cross product with the time vector which is (batch_size, 1) for
@@ -316,8 +348,8 @@ def compute_metrics(

log_p = ops.log(p)
times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
t1 = ops.take(discretized_time, times)[..., None]
t2 = ops.take(discretized_time, times + 1)[..., None]
t1 = expand_right_as(ops.take(discretized_time, times), x)
t2 = expand_right_as(ops.take(discretized_time, times + 1), x)

# generate noise vector
noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
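To illustrate the new `concatenate_subnet_input=False` path in `ConsistencyModel`: a custom subnet only needs to expose a `call(x, t, conditions=None, training=False)` signature. The sketch below is not part of this diff; the class name `TimeEmbeddingSubnet`, the layer widths, and the activation are illustrative assumptions.

```python
import keras
from bayesflow.networks import MLP, ConsistencyModel


class TimeEmbeddingSubnet(keras.Layer):
    """Embeds the scalar time before mixing it with the parameters and conditions."""

    def __init__(self, widths=(64, 64), **kwargs):
        super().__init__(**kwargs)
        self.time_embedding = keras.layers.Dense(16, activation="silu")
        self.mlp = MLP(list(widths))

    def call(self, x, t, conditions=None, training=False):
        t_emb = self.time_embedding(t)  # (..., 1) -> (..., 16)
        parts = [x, t_emb] if conditions is None else [x, t_emb, conditions]
        return self.mlp(keras.ops.concatenate(parts, axis=-1), training=training)


# With concatenate_subnet_input=False, the subnet receives x, t, and conditions separately.
consistency_model = ConsistencyModel(
    total_steps=100,
    subnet=TimeEmbeddingSubnet(),
    concatenate_subnet_input=False,
)
```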
59 changes: 47 additions & 12 deletions bayesflow/networks/diffusion_model/diffusion_model.py
@@ -85,6 +85,12 @@ def __init__(
Additional keyword arguments passed to the noise schedule constructor. Default is None.
integrate_kwargs : dict[str, any], optional
Configuration dictionary for integration during training or inference. Default is None.
concatenate_subnet_input: bool, optional
Flag for advanced users to control whether all inputs to the subnet should be concatenated
into a single vector or passed as separate arguments. If set to False, the subnet
must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio),
and optional 'conditions'. Default is True.

**kwargs
Additional keyword arguments passed to the base class and internal components.
"""
@@ -116,6 +122,7 @@ def __init__(
if subnet == "mlp":
subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
self.subnet = find_network(subnet, **subnet_kwargs)
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")

@@ -149,6 +156,8 @@ def get_config(self):
"prediction_type": self._prediction_type,
"loss_type": self._loss_type,
"integrate_kwargs": self.integrate_kwargs,
"concatenate_subnet_input": self._concatenate_subnet_input,
# we do not need to store subnet_kwargs
}
return base_config | serialize(config)

@@ -197,6 +206,35 @@ def convert_prediction_to_x(
return (z + sigma_t**2 * pred) / alpha_t
raise ValueError(f"Unknown prediction type {self._prediction_type}.")

def _apply_subnet(
self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None, training: bool = False
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
"""
Prepares the input for the subnet, either by concatenating the latent variable `xz`,
the log signal-to-noise ratio `log_snr`, and optional conditions into a single tensor, or by passing them as separate arguments.

Parameters
----------
xz : Tensor
The noisy input tensor for the diffusion model, typically of shape (..., D), but can vary.
log_snr : Tensor
The log signal-to-noise ratio tensor, typically of shape (..., 1).
conditions : Tensor, optional
The optional conditioning tensor (e.g. parameters).
training : bool, optional
The training mode flag, which can be used to control behavior during training.

Returns
-------
Tensor
The output tensor from the subnet.
"""
if self._concatenate_subnet_input:
xtc = tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
return self.subnet(xtc, training=training)
else:
return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)

def velocity(
self,
xz: Tensor,
@@ -221,7 +259,7 @@ def velocity(
If True, computes the velocity for the stochastic formulation (SDE).
If False, uses the deterministic formulation (ODE).
conditions : Tensor, optional
Optional conditional inputs to the network, such as conditioning variables
Conditional inputs to the network, such as conditioning variables
or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
training : bool, optional
Whether the model is in training mode. Affects behavior of dropout, batch norm,
@@ -238,12 +276,10 @@ def velocity(
log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)

if conditions is None:
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t)], axis=-1)
else:
xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)

pred = self.output_projector(self.subnet(xtc, training=training), training=training)
subnet_out = self._apply_subnet(
xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
)
pred = self.output_projector(subnet_out, training=training)

x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)

@@ -461,11 +497,10 @@ def compute_metrics(
diffused_x = alpha_t * x + sigma_t * eps_t

# calculate output of the network
if conditions is None:
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
else:
xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
pred = self.output_projector(self.subnet(xtc, training=training), training=training)
subnet_out = self._apply_subnet(
diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
)
pred = self.output_projector(subnet_out, training=training)

x_pred = self.convert_prediction_to_x(
pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
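The same flag on `DiffusionModel` lets a subnet treat the (transformed) log signal-to-noise ratio as a separate conditioning signal, for example via FiLM-style modulation instead of plain concatenation. The snippet below is only a sketch under that assumption; `FiLMSubnet` and its sizes are made-up names, not part of this PR.

```python
import keras
from bayesflow.networks import DiffusionModel
from bayesflow.utils.tensor_utils import concatenate_valid


class FiLMSubnet(keras.Layer):
    """Modulates hidden features with scales and shifts computed from the log-SNR input."""

    def __init__(self, width=64, **kwargs):
        super().__init__(**kwargs)
        self.hidden = keras.layers.Dense(width, activation="silu")
        self.scale = keras.layers.Dense(width)
        self.shift = keras.layers.Dense(width)
        self.out = keras.layers.Dense(width, activation="silu")

    def call(self, x, t, conditions=None, training=False):
        # concatenate_valid skips a None conditions tensor, as in the diff above
        h = self.hidden(concatenate_valid([x, conditions], axis=-1))
        h = self.scale(t) * h + self.shift(t)  # FiLM: per-feature affine modulation by t
        return self.out(h)


diffusion_model = DiffusionModel(
    subnet=FiLMSubnet(),
    noise_schedule="edm",
    prediction_type="F",
    concatenate_subnet_input=False,  # subnet sees x, t (log-SNR), and conditions separately
)
```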
49 changes: 42 additions & 7 deletions bayesflow/networks/flow_matching/flow_matching.py
@@ -12,6 +12,7 @@
layer_kwargs,
optimal_transport,
weighted_mean,
tensor_utils,
)
from bayesflow.utils.serialization import serialize, deserialize, serializable
from ..inference_network import InferenceNetwork
@@ -90,6 +91,11 @@ def __init__(
Additional keyword arguments for configuring optimal transport. Default is None.
subnet_kwargs: dict[str, any], optional, deprecated
Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
concatenate_subnet_input: bool, optional
Flag for advanced users to control whether all inputs to the subnet should be concatenated
into a single vector or passed as separate arguments. If set to False, the subnet
must accept three separate inputs: 'x' (noisy parameters), 't' (time),
and optional 'conditions'. Default is True.
**kwargs
Additional keyword arguments passed to the subnet and other components.
"""
@@ -107,6 +113,7 @@
subnet_kwargs = subnet_kwargs or {}
if subnet == "mlp":
subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG | subnet_kwargs
self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)

self.subnet = find_network(subnet, **subnet_kwargs)
self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")
@@ -147,22 +154,50 @@ def get_config(self):
"loss_fn": self.loss_fn,
"integrate_kwargs": self.integrate_kwargs,
"optimal_transport_kwargs": self.optimal_transport_kwargs,
"concatenate_subnet_input": self._concatenate_subnet_input,
# we do not need to store subnet_kwargs
}

return base_config | serialize(config)

def _apply_subnet(
self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
) -> Tensor | tuple[Tensor, Tensor, Tensor]:
"""
Prepares the input for the subnet, either by concatenating the latent variable `x`,
the time `t`, and optional conditions into a single tensor, or by passing them as separate arguments.

Parameters
----------
x : Tensor
The parameter tensor, typically of shape (..., D), but can vary.
t : Tensor
The time tensor, typically of shape (..., 1).
conditions : Tensor, optional
The optional conditioning tensor (e.g. parameters).
training : bool, optional
The training mode flag, which can be used to control behavior during training.

Returns
-------
Tensor
The output tensor from the subnet.
"""
if self._concatenate_subnet_input:
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
return self.subnet(xtc, training=training)
else:
if training is False:
t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
return self.subnet(x=x, t=t, conditions=conditions, training=training)

def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))
time = expand_right_as(time, xz)
time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))

if conditions is None:
xtc = keras.ops.concatenate([xz, time], axis=-1)
else:
xtc = keras.ops.concatenate([xz, time, conditions], axis=-1)

return self.output_projector(self.subnet(xtc, training=training), training=training)
subnet_out = self._apply_subnet(xz, time, conditions, training=training)
return self.output_projector(subnet_out, training=training)

def _velocity_trace(
self, xz: Tensor, time: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
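For `FlowMatching`, the non-concatenated path also forwards `t` separately, so a subnet can, for instance, lift the time onto Fourier features before mixing it in. Again a hypothetical sketch: `FourierTimeSubnet`, the frequency count, and the widths are assumptions; only the `call(x, t, conditions, training)` signature is dictated by `concatenate_subnet_input=False`.

```python
import math

import keras
from bayesflow.networks import MLP, FlowMatching
from bayesflow.utils.tensor_utils import concatenate_valid


class FourierTimeSubnet(keras.Layer):
    """Encodes the time input with fixed random Fourier features."""

    def __init__(self, widths=(64, 64), num_frequencies=8, **kwargs):
        super().__init__(**kwargs)
        self.freqs = self.add_weight(
            shape=(1, num_frequencies), initializer="random_normal", trainable=False, name="freqs"
        )
        self.mlp = MLP(list(widths))

    def call(self, x, t, conditions=None, training=False):
        proj = 2.0 * math.pi * t * self.freqs  # t is (..., 1), freqs is (1, F)
        t_emb = keras.ops.concatenate([keras.ops.sin(proj), keras.ops.cos(proj)], axis=-1)
        return self.mlp(concatenate_valid([x, t_emb, conditions], axis=-1), training=training)


flow_matching = FlowMatching(
    subnet=FourierTimeSubnet(),
    integrate_kwargs={"method": "rk45", "steps": 100},
    concatenate_subnet_input=False,
)
```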
50 changes: 48 additions & 2 deletions tests/test_networks/conftest.py
@@ -1,6 +1,8 @@
import pytest

from bayesflow.networks import MLP
from bayesflow.utils.tensor_utils import concatenate_valid
import keras


@pytest.fixture()
@@ -15,6 +17,30 @@ def diffusion_model_edm_F():
)


class ConcatenateMLP(keras.Layer):
def __init__(self, widths):
super().__init__()
self.widths = widths
self.mlp = MLP(widths)

def call(self, x, t, conditions=None, training=False):
con = concatenate_valid([x, t, conditions], axis=-1)
return self.mlp(con)


@pytest.fixture()
def diffusion_model_edm_F_subnet_concatenate():
from bayesflow.networks import DiffusionModel

return DiffusionModel(
subnet=ConcatenateMLP([8, 8]),
integrate_kwargs={"method": "rk45", "steps": 250},
noise_schedule="edm",
prediction_type="F",
concatenate_subnet_input=False,
)


@pytest.fixture()
def diffusion_model_edm_velocity():
from bayesflow.networks import DiffusionModel
@@ -85,13 +111,29 @@ def flow_matching():
)


@pytest.fixture()
def flow_matching_subnet_concatenate():
from bayesflow.networks import FlowMatching

return FlowMatching(
subnet=ConcatenateMLP([8, 8]), integrate_kwargs={"method": "rk45", "steps": 100}, concatenate_subnet_input=False
)


@pytest.fixture()
def consistency_model():
from bayesflow.networks import ConsistencyModel

return ConsistencyModel(total_steps=100, subnet=MLP([8, 8]))


@pytest.fixture()
def consistency_model_subnet_concatenate():
from bayesflow.networks import ConsistencyModel

return ConsistencyModel(total_steps=100, subnet=ConcatenateMLP([8, 8]), concatenate_subnet_input=False)


@pytest.fixture()
def affine_coupling_flow():
from bayesflow.networks import CouplingFlow
@@ -186,17 +228,21 @@ def inference_network_subnet(request):

@pytest.fixture(
params=[
pytest.param("diffusion_model_edm_F_subnet_concatenate"),
"affine_coupling_flow",
"spline_coupling_flow",
"flow_matching",
pytest.param("flow_matching_subnet_concatenate"),
"free_form_flow",
"consistency_model",
pytest.param("consistency_model_subnet_concatenate"),
pytest.param("diffusion_model_edm_F"),
pytest.param("diffusion_model_edm_F_subnet_concatenate"),
pytest.param(
"diffusion_model_edm_noise",
marks=[
pytest.mark.slow,
pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."),
pytest.mark.skip("noise prediction not testable without prior training for numerical reasons."),
],
),
pytest.param("diffusion_model_cosine_velocity", marks=pytest.mark.slow),
@@ -211,7 +257,7 @@ def inference_network_subnet(request):
"diffusion_model_cosine_noise",
marks=[
pytest.mark.slow,
pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."),
pytest.mark.skip("noise prediction not testable without prior training for numerical reasons."),
],
),
pytest.param(