From 326f05aa22c4e73625230a2a8fcdd5054420c4fb Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Mon, 4 Aug 2025 10:04:25 +0000 Subject: [PATCH 1/3] fix: optimizer was not used in workflow with multiple fits For the optimizer to be used, the approximator.compile function has to be called. This was not the case. I adapted the `build_optimizer` function to match the description in its docstring, and made the compilation conditional on its output. The output indicates whether a new optimizer was configured. --- bayesflow/workflows/basic_workflow.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index c30271eb1..34fa03794 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -914,6 +914,7 @@ def build_optimizer(self, epochs: int, num_batches: int, strategy: str) -> keras self.optimizer = keras.optimizers.Adam(learning_rate, clipnorm=1.5) else: self.optimizer = keras.optimizers.AdamW(learning_rate, weight_decay=5e-3, clipnorm=1.5) + return self.optimizer def _fit( self, @@ -955,9 +956,10 @@ def _fit( else: kwargs["callbacks"] = [model_checkpoint_callback] - self.build_optimizer(epochs, dataset.num_batches, strategy=strategy) - - if not self.approximator.built: + # returns None if no new optimizer was built and assigned to self.optimizer, which indicates we do not have + # to (re)compile the approximator. + optimizer = self.build_optimizer(epochs, dataset.num_batches, strategy=strategy) + if optimizer is not None: self.approximator.compile(optimizer=self.optimizer, metrics=kwargs.pop("metrics", None)) try: From 3d391d83a70e5f573a9d4838a1b17c397be37286 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 5 Aug 2025 06:44:13 +0000 Subject: [PATCH 2/3] fix: remove extra deserialize call for SummaryNetwork The extra call causes the DTypePolicy to be deserialized. 
This is then passed as a class, and cannot be handled by autoconf, leading to the error discussed in https://github.com/bayesflow-org/bayesflow/pull/549 --- bayesflow/networks/summary_network.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index e821be3f3..06bc3f935 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -4,7 +4,6 @@ from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import sanitize_input_shape -from bayesflow.utils.serialization import deserialize class SummaryNetwork(keras.Layer): @@ -50,7 +49,3 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[ metrics[metric.name] = metric(outputs, samples) return metrics - - @classmethod - def from_config(cls, config, custom_objects=None): - return cls(**deserialize(config, custom_objects=custom_objects)) From 952862c644198bc298538d346f6fe442a2630c75 Mon Sep 17 00:00:00 2001 From: Valentin Pratz Date: Tue, 5 Aug 2025 07:01:40 +0000 Subject: [PATCH 3/3] Compatibility: deserialize when get_config was overridden --- bayesflow/networks/summary_network.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/bayesflow/networks/summary_network.py b/bayesflow/networks/summary_network.py index 06bc3f935..d7df0b476 100644 --- a/bayesflow/networks/summary_network.py +++ b/bayesflow/networks/summary_network.py @@ -4,6 +4,7 @@ from bayesflow.types import Tensor from bayesflow.utils import layer_kwargs, find_distribution from bayesflow.utils.decorators import sanitize_input_shape +from bayesflow.utils.serialization import deserialize class SummaryNetwork(keras.Layer): @@ -49,3 +50,9 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[ metrics[metric.name] = metric(outputs, samples) return metrics + + @classmethod + def from_config(cls, config, 
custom_objects=None): + if hasattr(cls.get_config, "_is_default") and cls.get_config._is_default: + return cls(**config) + return cls(**deserialize(config, custom_objects=custom_objects))