Skip to content

Commit 8d296e9

Browse files
committed
fix: move additional metrics from approximator to networks
Supplying the additional metrics for inference and summary networks via the approximator's `compile` method caused problems during deserialization (#497). This can be resolved nicely by moving the metrics directly to the networks' constructors, analogous to how Keras normally handles custom metrics in layers. As summary networks and inference networks inherit from the respective base classes, this change only requires minor adaptations.
1 parent 1ea451b commit 8d296e9

File tree

10 files changed

+79
-47
lines changed

10 files changed

+79
-47
lines changed

bayesflow/approximators/continuous_approximator.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from collections.abc import Mapping, Sequence, Callable
22

33
import numpy as np
4+
import warnings
45

56
import keras
67

78
from bayesflow.adapters import Adapter
89
from bayesflow.networks import InferenceNetwork, SummaryNetwork
910
from bayesflow.types import Tensor
10-
from bayesflow.utils import filter_kwargs, logging, split_arrays, squeeze_inner_estimates_dict
11+
from bayesflow.utils import filter_kwargs, split_arrays, squeeze_inner_estimates_dict
1112
from bayesflow.utils.serialization import serialize, deserialize, serializable
1213

1314
from .approximator import Approximator
@@ -97,18 +98,21 @@ def build_adapter(
9798
def compile(
9899
self,
99100
*args,
100-
inference_metrics: Sequence[keras.Metric] = None,
101-
summary_metrics: Sequence[keras.Metric] = None,
102101
**kwargs,
103102
):
104-
if inference_metrics:
105-
self.inference_network._metrics = inference_metrics
103+
if "inference_metrics" in kwargs:
104+
warnings.warn(
105+
"Supplying inference metrics to the approximator is no longer supported. "
106+
"Please pass the metrics directly to the network using the metrics parameter.",
107+
DeprecationWarning,
108+
)
106109

107-
if summary_metrics:
108-
if self.summary_network is None:
109-
logging.warning("Ignoring summary metrics because there is no summary network.")
110-
else:
111-
self.summary_network._metrics = summary_metrics
110+
if "summary_metrics" in kwargs:
111+
warnings.warn(
112+
"Supplying summary metrics to the approximator is no longer supported. "
113+
"Please pass the metrics directly to the network using the metrics parameter.",
114+
DeprecationWarning,
115+
)
112116

113117
return super().compile(*args, **kwargs)
114118

@@ -227,16 +231,6 @@ def get_config(self):
227231

228232
return base_config | serialize(config)
229233

230-
def get_compile_config(self):
231-
base_config = super().get_compile_config() or {}
232-
233-
config = {
234-
"inference_metrics": self.inference_network._metrics,
235-
"summary_metrics": self.summary_network._metrics if self.summary_network is not None else None,
236-
}
237-
238-
return base_config | serialize(config)
239-
240234
def estimate(
241235
self,
242236
conditions: Mapping[str, np.ndarray],

bayesflow/experimental/diffusion_model/diffusion_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
expand_right_as,
1111
find_network,
1212
jacobian_trace,
13-
layer_kwargs,
1413
weighted_mean,
1514
integrate,
1615
integrate_stochastic,
@@ -141,7 +140,8 @@ def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
141140

142141
def get_config(self):
143142
base_config = super().get_config()
144-
base_config = layer_kwargs(base_config)
143+
# base distribution is fixed and passed in constructor
144+
base_config.pop("base_distribution")
145145

146146
config = {
147147
"subnet": self.subnet,

bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from bayesflow.types import Tensor
7-
from bayesflow.utils import find_network, layer_kwargs, weighted_mean
7+
from bayesflow.utils import find_network, weighted_mean
88
from bayesflow.utils.serialization import deserialize, serializable, serialize
99

1010
from ..inference_network import InferenceNetwork
@@ -109,7 +109,8 @@ def from_config(cls, config, custom_objects=None):
109109

110110
def get_config(self):
111111
base_config = super().get_config()
112-
base_config = layer_kwargs(base_config)
112+
# base distribution is fixed and passed in constructor
113+
base_config.pop("base_distribution")
113114

114115
config = {
115116
"total_steps": self.total_steps,

bayesflow/networks/coupling_flow/coupling_flow.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from bayesflow.types import Tensor
44
from bayesflow.utils import (
55
find_permutation,
6-
layer_kwargs,
76
weighted_mean,
87
)
98
from bayesflow.utils.serialization import deserialize, serializable, serialize
@@ -131,7 +130,6 @@ def from_config(cls, config, custom_objects=None):
131130

132131
def get_config(self):
133132
base_config = super().get_config()
134-
base_config = layer_kwargs(base_config)
135133

136134
config = {
137135
"subnet": self.subnet,

bayesflow/networks/flow_matching/flow_matching.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
find_network,
1010
integrate,
1111
jacobian_trace,
12-
layer_kwargs,
1312
optimal_transport,
1413
weighted_mean,
1514
)
@@ -138,7 +137,6 @@ def from_config(cls, config, custom_objects=None):
138137

139138
def get_config(self):
140139
base_config = super().get_config()
141-
base_config = layer_kwargs(base_config)
142140

143141
config = {
144142
"subnet": self.subnet,

bayesflow/networks/inference_network.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import keras
2+
from collections.abc import Sequence
23

34
from bayesflow.types import Shape, Tensor
45
from bayesflow.utils import layer_kwargs, find_distribution
56
from bayesflow.utils.decorators import allow_batch_size
7+
from bayesflow.utils.serialization import deserialize, serializable, serialize
68

79

10+
@serializable("bayesflow.networks")
811
class InferenceNetwork(keras.Layer):
9-
def __init__(self, base_distribution: str = "normal", **kwargs):
12+
def __init__(self, base_distribution: str = "normal", *, metrics: Sequence[keras.Metric] = None, **kwargs):
13+
self.custom_metrics = metrics
1014
super().__init__(**layer_kwargs(kwargs))
1115
self.base_distribution = find_distribution(base_distribution)
1216

@@ -72,3 +76,13 @@ def compute_metrics(
7276
metrics[metric.name] = metric(samples, x)
7377

7478
return metrics
79+
80+
def get_config(self):
81+
base_config = super().get_config()
82+
base_config = layer_kwargs(base_config)
83+
config = {"metrics": self.custom_metrics, "base_distribution": self.base_distribution}
84+
return base_config | serialize(config)
85+
86+
@classmethod
87+
def from_config(cls, config, custom_objects=None):
88+
return cls(**deserialize(config, custom_objects=custom_objects))

bayesflow/networks/summary_network.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
import keras
2+
from collections.abc import Sequence
23

34
from bayesflow.metrics.functional import maximum_mean_discrepancy
45
from bayesflow.types import Tensor
56
from bayesflow.utils import layer_kwargs, find_distribution
67
from bayesflow.utils.decorators import sanitize_input_shape
7-
from bayesflow.utils.serialization import deserialize
8+
from bayesflow.utils.serialization import deserialize, serializable, serialize
89

910

11+
@serializable("bayesflow.networks")
1012
class SummaryNetwork(keras.Layer):
11-
def __init__(self, base_distribution: str = None, **kwargs):
13+
def __init__(self, base_distribution: str = None, *, metrics: Sequence[keras.Metric] = None, **kwargs):
14+
self.custom_metrics = metrics
1215
super().__init__(**layer_kwargs(kwargs))
1316
self.base_distribution = find_distribution(base_distribution)
1417

@@ -17,7 +20,7 @@ def build(self, input_shape):
1720
x = keras.ops.zeros(input_shape)
1821
z = self.call(x)
1922

20-
if self.base_distribution is not None:
23+
if self.base_distribution is not None and not self.base_distribution.built:
2124
self.base_distribution.build(keras.ops.shape(z))
2225

2326
@sanitize_input_shape
@@ -51,6 +54,12 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[
5154

5255
return metrics
5356

57+
def get_config(self):
58+
base_config = super().get_config()
59+
base_config = layer_kwargs(base_config)
60+
config = {"base_distribution": self.base_distribution, "metrics": self.custom_metrics}
61+
return base_config | serialize(config)
62+
5463
@classmethod
5564
def from_config(cls, config, custom_objects=None):
5665
return cls(**deserialize(config, custom_objects=custom_objects))

tests/test_networks/conftest.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
from bayesflow.networks import MLP
4+
from bayesflow.metrics import RootMeanSquaredError
45

56

67
@pytest.fixture()
@@ -12,6 +13,7 @@ def diffusion_model_edm_F():
1213
integrate_kwargs={"method": "rk45", "steps": 250},
1314
noise_schedule="edm",
1415
prediction_type="F",
16+
metrics=[RootMeanSquaredError()],
1517
)
1618

1719

@@ -82,22 +84,32 @@ def flow_matching():
8284
return FlowMatching(
8385
subnet=MLP([8, 8]),
8486
integrate_kwargs={"method": "rk45", "steps": 100},
87+
metrics=[RootMeanSquaredError()],
8588
)
8689

8790

8891
@pytest.fixture()
8992
def consistency_model():
9093
from bayesflow.networks import ConsistencyModel
9194

92-
return ConsistencyModel(total_steps=100, subnet=MLP([8, 8]))
95+
return ConsistencyModel(
96+
total_steps=100,
97+
subnet=MLP([8, 8]),
98+
metrics=[RootMeanSquaredError()],
99+
)
93100

94101

95102
@pytest.fixture()
96103
def affine_coupling_flow():
97104
from bayesflow.networks import CouplingFlow
98105

99106
return CouplingFlow(
100-
depth=2, subnet="mlp", subnet_kwargs=dict(widths=[8, 8]), transform="affine", transform_kwargs=dict(clamp=1.8)
107+
depth=2,
108+
subnet="mlp",
109+
subnet_kwargs=dict(widths=[8, 8]),
110+
transform="affine",
111+
transform_kwargs=dict(clamp=1.8),
112+
metrics=[RootMeanSquaredError()],
101113
)
102114

103115

@@ -106,15 +118,24 @@ def spline_coupling_flow():
106118
from bayesflow.networks import CouplingFlow
107119

108120
return CouplingFlow(
109-
depth=2, subnet="mlp", subnet_kwargs=dict(widths=[8, 8]), transform="spline", transform_kwargs=dict(bins=8)
121+
depth=2,
122+
subnet="mlp",
123+
subnet_kwargs=dict(widths=[8, 8]),
124+
transform="spline",
125+
transform_kwargs=dict(bins=8),
126+
metrics=[RootMeanSquaredError()],
110127
)
111128

112129

113130
@pytest.fixture()
114131
def free_form_flow():
115132
from bayesflow.experimental import FreeFormFlow
116133

117-
return FreeFormFlow(encoder_subnet=MLP([16, 16]), decoder_subnet=MLP([16, 16]))
134+
return FreeFormFlow(
135+
encoder_subnet=MLP([16, 16]),
136+
decoder_subnet=MLP([16, 16]),
137+
metrics=[RootMeanSquaredError()],
138+
)
118139

119140

120141
@pytest.fixture()
@@ -236,35 +257,35 @@ def generative_inference_network(request):
236257
def time_series_network(summary_dim):
237258
from bayesflow.networks import TimeSeriesNetwork
238259

239-
return TimeSeriesNetwork(summary_dim=summary_dim)
260+
return TimeSeriesNetwork(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
240261

241262

242263
@pytest.fixture(scope="function")
243264
def time_series_transformer(summary_dim):
244265
from bayesflow.networks import TimeSeriesTransformer
245266

246-
return TimeSeriesTransformer(summary_dim=summary_dim)
267+
return TimeSeriesTransformer(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
247268

248269

249270
@pytest.fixture(scope="function")
250271
def fusion_transformer(summary_dim):
251272
from bayesflow.networks import FusionTransformer
252273

253-
return FusionTransformer(summary_dim=summary_dim)
274+
return FusionTransformer(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
254275

255276

256277
@pytest.fixture(scope="function")
257278
def set_transformer(summary_dim):
258279
from bayesflow.networks import SetTransformer
259280

260-
return SetTransformer(summary_dim=summary_dim)
281+
return SetTransformer(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
261282

262283

263284
@pytest.fixture(scope="function")
264285
def deep_set(summary_dim):
265286
from bayesflow.networks import DeepSet
266287

267-
return DeepSet(summary_dim=summary_dim)
288+
return DeepSet(summary_dim=summary_dim, metrics=[RootMeanSquaredError()])
268289

269290

270291
@pytest.fixture(

tests/test_two_moons/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
@pytest.fixture()
55
def inference_network():
66
from bayesflow.networks import CouplingFlow
7+
from bayesflow.metrics import MaximumMeanDiscrepancy
78

8-
return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)))
9+
return CouplingFlow(depth=2, subnet="mlp", subnet_kwargs=dict(widths=(32, 32)), metrics=[MaximumMeanDiscrepancy()])
910

1011

1112
@pytest.fixture()

tests/test_two_moons/test_two_moons.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,9 @@ def test_compile(approximator, random_samples, jit_compile):
1313

1414

1515
def test_fit(approximator, train_dataset, validation_dataset, batch_size):
16-
from bayesflow.metrics import MaximumMeanDiscrepancy
1716
from bayesflow.networks import PointInferenceNetwork
1817

19-
inference_metrics = []
20-
if not isinstance(approximator.inference_network, PointInferenceNetwork):
21-
inference_metrics += [MaximumMeanDiscrepancy()]
22-
approximator.compile(inference_metrics=inference_metrics)
18+
approximator.compile()
2319

2420
mock_data = train_dataset[0]
2521
mock_data = keras.tree.map_structure(keras.ops.convert_to_tensor, mock_data)

0 commit comments

Comments
 (0)