diff --git a/bayesflow/networks/consistency_models/consistency_model.py b/bayesflow/networks/consistency_models/consistency_model.py
index 74d6acd6a..3e1778e89 100644
--- a/bayesflow/networks/consistency_models/consistency_model.py
+++ b/bayesflow/networks/consistency_models/consistency_model.py
@@ -4,7 +4,7 @@
 import numpy as np
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, layer_kwargs, weighted_mean
+from bayesflow.utils import find_network, layer_kwargs, weighted_mean, tensor_utils, expand_right_as
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..inference_network import InferenceNetwork
@@ -67,6 +67,11 @@ def __init__(
             Final number of discretization steps
         subnet_kwargs: dict[str, any], optional
             Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
+        concatenate_subnet_input: bool, optional
+            Flag for advanced users to control whether all inputs to the subnet should be concatenated
+            into a single vector or passed as separate arguments. If set to False, the subnet
+            must accept three separate inputs: 'x' (noisy parameters), 't' (time),
+            and optional 'conditions'. Default is True.
         **kwargs : dict, optional, default: {}
             Additional keyword arguments
         """
@@ -77,6 +82,7 @@ def __init__(
         subnet_kwargs = subnet_kwargs or {}
         if subnet == "mlp":
             subnet_kwargs = ConsistencyModel.MLP_DEFAULT_CONFIG | subnet_kwargs
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
         self.subnet = find_network(subnet, **subnet_kwargs)
 
         self.output_projector = keras.layers.Dense(
@@ -119,6 +125,7 @@ def get_config(self):
             "eps": self.eps,
             "s0": self.s0,
             "s1": self.s1,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
             # we do not need to store subnet_kwargs
         }
 
@@ -161,18 +168,23 @@ def build(self, xz_shape, conditions_shape=None):
         input_shape = list(xz_shape)
 
-        # time vector
-        input_shape[-1] += 1
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
 
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
-
-        input_shape = tuple(input_shape)
-
-        self.subnet.build(input_shape)
-
-        input_shape = self.subnet.compute_output_shape(input_shape)
-        self.output_projector.build(input_shape)
+            self.subnet.build(input_shape)
+            out_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # Multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (1,)  # same batch/sequence dims, 1 feature
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            out_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
+        self.output_projector.build(out_shape)
 
         # Choose coefficient according to [2] Section 3.3
         self.c_huber = 0.00054 * ops.sqrt(xz_shape[-1])
@@ -256,6 +268,35 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
         x = self.consistency_function(x_n, t, conditions=conditions, training=training)
         return x
 
+    def _apply_subnet(
+        self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
+        """
+        Prepares and passes the input to the subnet either by concatenating the latent variable `x`,
+        the time `t`, and optional conditions or by returning them separately.
+
+        Parameters
+        ----------
+        x : Tensor
+            The parameter tensor, typically of shape (..., D), but can vary.
+        t : Tensor
+            The time tensor, typically of shape (..., 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g. parameters).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor from the subnet.
+        """
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            return self.subnet(x=x, t=t, conditions=conditions, training=training)
+
     def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
         """Compute consistency function.
 
@@ -271,12 +312,8 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
             Whether internal layers (e.g., dropout) should behave in train or inference mode.
         """
 
-        if conditions is not None:
-            xtc = ops.concatenate([x, t, conditions], axis=-1)
-        else:
-            xtc = ops.concatenate([x, t], axis=-1)
-
-        f = self.output_projector(self.subnet(xtc, training=training))
+        subnet_out = self._apply_subnet(x, t, conditions, training=training)
+        f = self.output_projector(subnet_out)
 
         # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
         # Thus, we can do a cross product with the time vector which is (batch_size, 1) for
@@ -316,8 +353,8 @@ def compute_metrics(
         log_p = ops.log(p)
         times = keras.random.categorical(ops.expand_dims(log_p, 0), ops.shape(x)[0], seed=self.seed_generator)[0]
 
-        t1 = ops.take(discretized_time, times)[..., None]
-        t2 = ops.take(discretized_time, times + 1)[..., None]
+        t1 = expand_right_as(ops.take(discretized_time, times), x)
+        t2 = expand_right_as(ops.take(discretized_time, times + 1), x)
 
         # generate noise vector
         noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
diff --git a/bayesflow/networks/diffusion_model/diffusion_model.py b/bayesflow/networks/diffusion_model/diffusion_model.py
index 91a05fbff..744e86b37 100644
--- a/bayesflow/networks/diffusion_model/diffusion_model.py
+++ b/bayesflow/networks/diffusion_model/diffusion_model.py
@@ -85,6 +85,12 @@ def __init__(
             Additional keyword arguments passed to the noise schedule constructor. Default is None.
         integrate_kwargs : dict[str, any], optional
             Configuration dictionary for integration during training or inference. Default is None.
+        concatenate_subnet_input: bool, optional
+            Flag for advanced users to control whether all inputs to the subnet should be concatenated
+            into a single vector or passed as separate arguments. If set to False, the subnet
+            must accept three separate inputs: 'x' (noisy parameters), 't' (log signal-to-noise ratio),
+            and optional 'conditions'. Default is True.
+
         **kwargs
             Additional keyword arguments passed to the base class and internal components.
         """
@@ -116,6 +122,7 @@ def __init__(
         if subnet == "mlp":
             subnet_kwargs = DiffusionModel.MLP_DEFAULT_CONFIG | subnet_kwargs
         self.subnet = find_network(subnet, **subnet_kwargs)
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
 
         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")
 
@@ -128,15 +135,23 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
         self.output_projector.units = xz_shape[-1]
 
         input_shape = list(xz_shape)
-        # construct time vector
-        input_shape[-1] += 1
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
 
-        input_shape = tuple(input_shape)
+            self.subnet.build(input_shape)
+            out_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # Multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (1,)  # same batch/sequence dims, 1 feature
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            out_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
 
-        self.subnet.build(input_shape)
-        out_shape = self.subnet.compute_output_shape(input_shape)
         self.output_projector.build(out_shape)
 
     def get_config(self):
@@ -149,6 +164,8 @@ def get_config(self):
             "prediction_type": self._prediction_type,
             "loss_type": self._loss_type,
             "integrate_kwargs": self.integrate_kwargs,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
+            # we do not need to store subnet_kwargs
         }
 
         return base_config | serialize(config)
@@ -197,6 +214,35 @@ def convert_prediction_to_x(
             return (z + sigma_t**2 * pred) / alpha_t
         raise ValueError(f"Unknown prediction type {self._prediction_type}.")
 
+    def _apply_subnet(
+        self, xz: Tensor, log_snr: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
+        """
+        Prepares and passes the input to the subnet either by concatenating the latent variable `xz`,
+        the signal-to-noise ratio `log_snr`, and optional conditions or by returning them separately.
+
+        Parameters
+        ----------
+        xz : Tensor
+            The noisy input tensor for the diffusion model, typically of shape (..., D), but can vary.
+        log_snr : Tensor
+            The log signal-to-noise ratio tensor, typically of shape (..., 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g. parameters).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor from the subnet.
+        """
+        if self._concatenate_subnet_input:
+            xtc = tensor_utils.concatenate_valid([xz, log_snr, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            return self.subnet(x=xz, t=log_snr, conditions=conditions, training=training)
+
     def velocity(
         self,
         xz: Tensor,
@@ -221,7 +267,7 @@ def velocity(
             If True, computes the velocity for the stochastic formulation (SDE).
             If False, uses the deterministic formulation (ODE).
         conditions : Tensor, optional
-            Optional conditional inputs to the network, such as conditioning variables
+            Conditional inputs to the network, such as conditioning variables
             or encoder outputs. Shape must be broadcastable with `xz`. Default is None.
         training : bool, optional
             Whether the model is in training mode. Affects behavior of dropout, batch norm,
@@ -238,12 +284,10 @@ def velocity(
         log_snr_t = ops.broadcast_to(log_snr_t, ops.shape(xz)[:-1] + (1,))
         alpha_t, sigma_t = self.noise_schedule.get_alpha_sigma(log_snr_t=log_snr_t)
 
-        if conditions is None:
-            xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t)], axis=-1)
-        else:
-            xtc = tensor_utils.concatenate_valid([xz, self._transform_log_snr(log_snr_t), conditions], axis=-1)
-
-        pred = self.output_projector(self.subnet(xtc, training=training), training=training)
+        subnet_out = self._apply_subnet(
+            xz, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
+        )
+        pred = self.output_projector(subnet_out, training=training)
 
         x_pred = self.convert_prediction_to_x(pred=pred, z=xz, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t)
 
@@ -461,11 +505,10 @@ def compute_metrics(
         diffused_x = alpha_t * x + sigma_t * eps_t
 
         # calculate output of the network
-        if conditions is None:
-            xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t)], axis=-1)
-        else:
-            xtc = tensor_utils.concatenate_valid([diffused_x, self._transform_log_snr(log_snr_t), conditions], axis=-1)
-        pred = self.output_projector(self.subnet(xtc, training=training), training=training)
+        subnet_out = self._apply_subnet(
+            diffused_x, self._transform_log_snr(log_snr_t), conditions=conditions, training=training
+        )
+        pred = self.output_projector(subnet_out, training=training)
 
         x_pred = self.convert_prediction_to_x(
             pred=pred, z=diffused_x, alpha_t=alpha_t, sigma_t=sigma_t, log_snr_t=log_snr_t
diff --git a/bayesflow/networks/flow_matching/flow_matching.py b/bayesflow/networks/flow_matching/flow_matching.py
index 645af1822..da4acd321 100644
--- a/bayesflow/networks/flow_matching/flow_matching.py
+++ b/bayesflow/networks/flow_matching/flow_matching.py
@@ -12,6 +12,7 @@
     layer_kwargs,
     optimal_transport,
     weighted_mean,
+    tensor_utils,
 )
 from bayesflow.utils.serialization import serialize, deserialize, serializable
 from ..inference_network import InferenceNetwork
@@ -90,6 +91,11 @@ def __init__(
             Additional keyword arguments for configuring optimal transport. Default is None.
         subnet_kwargs: dict[str, any], optional, deprecated
             Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
+        concatenate_subnet_input: bool, optional
+            Flag for advanced users to control whether all inputs to the subnet should be concatenated
+            into a single vector or passed as separate arguments. If set to False, the subnet
+            must accept three separate inputs: 'x' (noisy parameters), 't' (time),
+            and optional 'conditions'. Default is True.
         **kwargs
             Additional keyword arguments passed to the subnet and other components.
         """
@@ -107,6 +113,7 @@ def __init__(
         subnet_kwargs = subnet_kwargs or {}
         if subnet == "mlp":
             subnet_kwargs = FlowMatching.MLP_DEFAULT_CONFIG | subnet_kwargs
+        self._concatenate_subnet_input = kwargs.get("concatenate_subnet_input", True)
         self.subnet = find_network(subnet, **subnet_kwargs)
         self.output_projector = keras.layers.Dense(units=None, bias_initializer="zeros", name="output_projector")
 
@@ -121,16 +128,25 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
 
         self.output_projector.units = xz_shape[-1]
 
-        # account for concatenating the time and conditions
         input_shape = list(xz_shape)
-        input_shape[-1] += 1
-        if conditions_shape is not None:
-            input_shape[-1] += conditions_shape[-1]
-        input_shape = tuple(input_shape)
+        if self._concatenate_subnet_input:
+            # construct time vector
+            input_shape[-1] += 1
+            if conditions_shape is not None:
+                input_shape[-1] += conditions_shape[-1]
+            input_shape = tuple(input_shape)
+
+            self.subnet.build(input_shape)
+            out_shape = self.subnet.compute_output_shape(input_shape)
+        else:
+            # Multiple separate inputs
+            time_shape = tuple(xz_shape[:-1]) + (1,)  # same batch/sequence dims, 1 feature
+            self.subnet.build(x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape)
+            out_shape = self.subnet.compute_output_shape(
+                x_shape=xz_shape, t_shape=time_shape, conditions_shape=conditions_shape
+            )
 
-        self.subnet.build(input_shape)
-        input_shape = self.subnet.compute_output_shape(input_shape)
-        self.output_projector.build(input_shape)
+        self.output_projector.build(out_shape)
 
     @classmethod
     def from_config(cls, config, custom_objects=None):
@@ -147,22 +163,50 @@ def get_config(self):
             "loss_fn": self.loss_fn,
             "integrate_kwargs": self.integrate_kwargs,
             "optimal_transport_kwargs": self.optimal_transport_kwargs,
+            "concatenate_subnet_input": self._concatenate_subnet_input,
             # we do not need to store subnet_kwargs
         }
 
         return base_config | serialize(config)
 
+    def _apply_subnet(
+        self, x: Tensor, t: Tensor, conditions: Tensor = None, training: bool = False
+    ) -> Tensor | tuple[Tensor, Tensor, Tensor]:
+        """
+        Prepares and passes the input to the subnet either by concatenating the latent variable `x`,
+        the time `t`, and optional conditions or by returning them separately.
+
+        Parameters
+        ----------
+        x : Tensor
+            The parameter tensor, typically of shape (..., D), but can vary.
+        t : Tensor
+            The time tensor, typically of shape (..., 1).
+        conditions : Tensor, optional
+            The optional conditioning tensor (e.g. parameters).
+        training : bool, optional
+            The training mode flag, which can be used to control behavior during training.
+
+        Returns
+        -------
+        Tensor
+            The output tensor from the subnet.
+        """
+        if self._concatenate_subnet_input:
+            t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
+            xtc = tensor_utils.concatenate_valid([x, t, conditions], axis=-1)
+            return self.subnet(xtc, training=training)
+        else:
+            if training is False:
+                t = keras.ops.broadcast_to(t, keras.ops.shape(x)[:-1] + (1,))
+            return self.subnet(x=x, t=t, conditions=conditions, training=training)
+
     def velocity(self, xz: Tensor, time: float | Tensor, conditions: Tensor = None, training: bool = False) -> Tensor:
         time = keras.ops.convert_to_tensor(time, dtype=keras.ops.dtype(xz))
         time = expand_right_as(time, xz)
-        time = keras.ops.broadcast_to(time, keras.ops.shape(xz)[:-1] + (1,))
-
-        if conditions is None:
-            xtc = keras.ops.concatenate([xz, time], axis=-1)
-        else:
-            xtc = keras.ops.concatenate([xz, time, conditions], axis=-1)
 
-        return self.output_projector(self.subnet(xtc, training=training), training=training)
+        subnet_out = self._apply_subnet(xz, time, conditions, training=training)
+        return self.output_projector(subnet_out, training=training)
 
     def _velocity_trace(
         self, xz: Tensor, time: Tensor, conditions: Tensor = None, max_steps: int = None, training: bool = False
diff --git a/tests/test_networks/conftest.py b/tests/test_networks/conftest.py
index cb3f33db0..63bc317ff 100644
--- a/tests/test_networks/conftest.py
+++ b/tests/test_networks/conftest.py
@@ -1,6 +1,9 @@
 import pytest
+from collections.abc import Sequence
 
-from bayesflow.networks import MLP
+from bayesflow.networks import MLP, Sequential
+from bayesflow.utils.tensor_utils import concatenate_valid
+from bayesflow.utils.serialization import serializable, serialize
 
 
 @pytest.fixture()
@@ -15,6 +18,61 @@ def diffusion_model_edm_F():
     )
 
 
+@serializable("test", disable_module_check=True)
+class ConcatenateMLP(Sequential):
+    def __init__(
+        self,
+        widths: Sequence[int] = (256, 256),
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.widths = widths
+        self.mlp = MLP(widths)
+
+    def call(self, x, t, conditions=None, training=False):
+        con = concatenate_valid([x, t, conditions], axis=-1)
+        return self.mlp(con)
+
+    def compute_output_shape(self, x_shape, t_shape, conditions_shape=None):
+        concatenate_input_shapes = tuple(
+            (
+                x_shape[0],
+                x_shape[-1] + t_shape[-1] + (conditions_shape[-1] if conditions_shape is not None else 0),
+            )
+        )
+        return self.mlp.compute_output_shape(concatenate_input_shapes)
+
+    def build(self, x_shape, t_shape, conditions_shape=None):
+        if self.built:
+            return
+
+        concatenate_input_shapes = tuple(
+            (
+                x_shape[0],
+                x_shape[-1] + t_shape[-1] + (conditions_shape[-1] if conditions_shape is not None else 0),
+            )
+        )
+        self.mlp.build(concatenate_input_shapes)
+
+    def get_config(self):
+        config = {"widths": self.widths}
+
+        return serialize(config)
+
+
+@pytest.fixture()
+def diffusion_model_edm_F_subnet_separate_inputs():
+    from bayesflow.networks import DiffusionModel
+
+    return DiffusionModel(
+        subnet=ConcatenateMLP([8, 8]),
+        integrate_kwargs={"method": "rk45", "steps": 4},
+        noise_schedule="edm",
+        prediction_type="F",
+        concatenate_subnet_input=False,
+    )
+
+
 @pytest.fixture()
 def diffusion_model_edm_velocity():
     from bayesflow.networks import DiffusionModel
@@ -85,6 +143,15 @@ def flow_matching():
     )
 
 
+@pytest.fixture()
+def flow_matching_subnet_separate_inputs():
+    from bayesflow.networks import FlowMatching
+
+    return FlowMatching(
+        subnet=ConcatenateMLP([8, 8]), integrate_kwargs={"method": "rk45", "steps": 4}, concatenate_subnet_input=False
+    )
+
+
 @pytest.fixture()
 def consistency_model():
     from bayesflow.networks import ConsistencyModel
@@ -92,6 +159,13 @@ def consistency_model():
     return ConsistencyModel(total_steps=100, subnet=MLP([8, 8]))
 
 
+@pytest.fixture()
+def consistency_model_subnet_separate_inputs():
+    from bayesflow.networks import ConsistencyModel
+
+    return ConsistencyModel(total_steps=4, subnet=ConcatenateMLP([8, 8]), concatenate_subnet_input=False)
+
+
 @pytest.fixture()
 def affine_coupling_flow():
     from bayesflow.networks import CouplingFlow
@@ -196,7 +270,7 @@ def inference_network_subnet(request):
             "diffusion_model_edm_noise",
             marks=[
                 pytest.mark.slow,
-                pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."),
+                pytest.mark.skip("noise prediction not testable without prior training for numerical reasons."),
             ],
         ),
         pytest.param("diffusion_model_cosine_velocity", marks=pytest.mark.slow),
@@ -211,7 +285,7 @@ def inference_network_subnet(request):
             "diffusion_model_cosine_noise",
             marks=[
                 pytest.mark.slow,
-                pytest.mark.skip("noise predicition not testable without prior training for numerical reasons."),
+                pytest.mark.skip("noise prediction not testable without prior training for numerical reasons."),
             ],
         ),
         pytest.param(
@@ -228,6 +302,18 @@ def generative_inference_network(request):
     return request.getfixturevalue(request.param)
 
 
+@pytest.fixture(
+    params=[
+        pytest.param("flow_matching_subnet_separate_inputs"),
+        pytest.param("consistency_model_subnet_separate_inputs"),
+        pytest.param("diffusion_model_edm_F_subnet_separate_inputs"),
+    ],
+    scope="function",
+)
+def inference_network_subnet_separate_inputs(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.fixture(scope="function")
 def time_series_network(summary_dim):
     from bayesflow.networks import TimeSeriesNetwork
diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py
index a16743f69..c0035e0f2 100644
--- a/tests/test_networks/test_inference_networks.py
+++ b/tests/test_networks/test_inference_networks.py
@@ -162,3 +162,16 @@ def test_compute_metrics(inference_network, random_samples, random_conditions):
     metrics = inference_network.compute_metrics(random_samples, conditions=random_conditions)
 
     assert "loss" in metrics
+
+
+def test_subnet_separate_inputs(inference_network_subnet_separate_inputs, random_samples, random_conditions):
+    xz_shape = keras.ops.shape(random_samples)
+    conditions_shape = keras.ops.shape(random_conditions) if random_conditions is not None else None
+    inference_network_subnet_separate_inputs.build(xz_shape, conditions_shape)
+
+    assert inference_network_subnet_separate_inputs.built is True
+
+    # check the model has variables
+    assert inference_network_subnet_separate_inputs.variables, "Model has no variables."
+
+    inference_network_subnet_separate_inputs(random_samples, random_conditions, inverse=True)
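
Usage sketch for the new `concatenate_subnet_input` flag, patterned on the `ConcatenateMLP` test helper above. With `concatenate_subnet_input=False`, the networks call the subnet as `subnet(x=..., t=..., conditions=...)` and build it via `build(x_shape=..., t_shape=..., conditions_shape=...)`, so a custom subnet only needs to accept these keyword arguments. The class name `SeparateInputsMLP`, the layer widths, and the internal concatenation below are illustrative choices; `get_config` (needed for serialization, as in the test helper) is omitted for brevity:

    from collections.abc import Sequence

    import keras

    from bayesflow.networks import MLP, FlowMatching
    from bayesflow.utils.tensor_utils import concatenate_valid


    class SeparateInputsMLP(keras.layers.Layer):
        """Illustrative subnet that receives x, t and optional conditions as separate arguments."""

        def __init__(self, widths: Sequence[int] = (64, 64), **kwargs):
            super().__init__(**kwargs)
            self.mlp = MLP(widths)

        def call(self, x, t, conditions=None, training=False):
            # any fusion of x, t and conditions works here; this sketch simply concatenates internally
            return self.mlp(concatenate_valid([x, t, conditions], axis=-1), training=training)

        def build(self, x_shape, t_shape, conditions_shape=None):
            width = x_shape[-1] + t_shape[-1] + (conditions_shape[-1] if conditions_shape is not None else 0)
            self.mlp.build((x_shape[0], width))
            self.built = True

        def compute_output_shape(self, x_shape, t_shape, conditions_shape=None):
            width = x_shape[-1] + t_shape[-1] + (conditions_shape[-1] if conditions_shape is not None else 0)
            return self.mlp.compute_output_shape((x_shape[0], width))


    # x, t and conditions are now routed to the subnet as keyword arguments
    # instead of one concatenated vector
    network = FlowMatching(subnet=SeparateInputsMLP([64, 64]), concatenate_subnet_input=False)

The same subnet can be passed to DiffusionModel or ConsistencyModel; per the docstrings above, `t` then carries the log signal-to-noise ratio or the consistency-model time, respectively.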