Skip to content

Commit 952862c

Browse files
committed
Compatibility: deserialize when get_config was overridden
1 parent 3d391d8 commit 952862c

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

bayesflow/networks/summary_network.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from bayesflow.types import Tensor
55
from bayesflow.utils import layer_kwargs, find_distribution
66
from bayesflow.utils.decorators import sanitize_input_shape
7+
from bayesflow.utils.serialization import deserialize
78

89

910
class SummaryNetwork(keras.Layer):
@@ -49,3 +50,9 @@ def compute_metrics(self, x: Tensor, stage: str = "training", **kwargs) -> dict[
4950
metrics[metric.name] = metric(outputs, samples)
5051

5152
return metrics
53+
54+
@classmethod
55+
def from_config(cls, config, custom_objects=None):
56+
if hasattr(cls.get_config, "_is_default") and cls.get_config._is_default:
57+
return cls(**config)
58+
return cls(**deserialize(config, custom_objects=custom_objects))

0 commit comments

Comments
 (0)