From 07fe7748e46b5ecbd844542968b37f50cf2f428a Mon Sep 17 00:00:00 2001
From: Valentin Pratz
Date: Sat, 21 Jun 2025 11:03:15 +0000
Subject: [PATCH] fix trainable parameters in distributions

---
 bayesflow/distributions/diagonal_normal.py    | 13 +++++-----
 bayesflow/distributions/diagonal_student_t.py | 25 ++++++++++---------
 bayesflow/distributions/mixture.py            |  2 +-
 3 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/bayesflow/distributions/diagonal_normal.py b/bayesflow/distributions/diagonal_normal.py
index f8d93b945..6b64445c7 100644
--- a/bayesflow/distributions/diagonal_normal.py
+++ b/bayesflow/distributions/diagonal_normal.py
@@ -58,7 +58,6 @@ def __init__(
         self.seed_generator = seed_generator or keras.random.SeedGenerator()
 
         self.dim = None
-        self.log_normalization_constant = None
         self._mean = None
         self._std = None
 
@@ -71,17 +70,18 @@ def build(self, input_shape: Shape) -> None:
         self.mean = ops.cast(ops.broadcast_to(self.mean, (self.dim,)), "float32")
         self.std = ops.cast(ops.broadcast_to(self.std, (self.dim,)), "float32")
 
-        self.log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self.std))
-
         if self.trainable_parameters:
             self._mean = self.add_weight(
                 shape=ops.shape(self.mean),
-                initializer=keras.initializers.get(self.mean),
+                initializer=keras.initializers.get(keras.ops.copy(self.mean)),
                 dtype="float32",
                 trainable=True,
             )
             self._std = self.add_weight(
-                shape=ops.shape(self.std), initializer=keras.initializers.get(self.std), dtype="float32", trainable=True
+                shape=ops.shape(self.std),
+                initializer=keras.initializers.get(keras.ops.copy(self.std)),
+                dtype="float32",
+                trainable=True,
             )
         else:
             self._mean = self.mean
@@ -91,7 +91,8 @@
     def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
         result = -0.5 * ops.sum((samples - self._mean) ** 2 / self._std**2, axis=-1)
 
         if normalize:
-            result += self.log_normalization_constant
+            log_normalization_constant = -0.5 * self.dim * math.log(2.0 * math.pi) - ops.sum(ops.log(self._std))
+            result += log_normalization_constant
 
         return result
diff --git a/bayesflow/distributions/diagonal_student_t.py b/bayesflow/distributions/diagonal_student_t.py
index 98e3fb7eb..9b02ee821 100644
--- a/bayesflow/distributions/diagonal_student_t.py
+++ b/bayesflow/distributions/diagonal_student_t.py
@@ -63,7 +63,6 @@
 
         self.seed_generator = seed_generator or keras.random.SeedGenerator()
 
-        self.log_normalization_constant = None
         self.dim = None
         self._loc = None
         self._scale = None
@@ -78,21 +77,16 @@ def build(self, input_shape: Shape) -> None:
         self.loc = ops.cast(ops.broadcast_to(self.loc, (self.dim,)), "float32")
         self.scale = ops.cast(ops.broadcast_to(self.scale, (self.dim,)), "float32")
 
-        self.log_normalization_constant = (
-            -0.5 * self.dim * math.log(self.df)
-            - 0.5 * self.dim * math.log(math.pi)
-            - math.lgamma(0.5 * self.df)
-            + math.lgamma(0.5 * (self.df + self.dim))
-            - ops.sum(keras.ops.log(self.scale))
-        )
-
         if self.trainable_parameters:
             self._loc = self.add_weight(
-                shape=ops.shape(self.loc), initializer=keras.initializers.get(self.loc), dtype="float32", trainable=True
+                shape=ops.shape(self.loc),
+                initializer=keras.initializers.get(keras.ops.copy(self.loc)),
+                dtype="float32",
+                trainable=True,
             )
             self._scale = self.add_weight(
                 shape=ops.shape(self.scale),
-                initializer=keras.initializers.get(self.scale),
+                initializer=keras.initializers.get(keras.ops.copy(self.scale)),
                 dtype="float32",
                 trainable=True,
             )
@@ -105,7 +99,14 @@
     def log_prob(self, samples: Tensor, *, normalize: bool = True) -> Tensor:
         result = -0.5 * (self.df + self.dim) * ops.log1p(mahalanobis_term / self.df)
 
         if normalize:
-            result += self.log_normalization_constant
+            log_normalization_constant = (
+                -0.5 * self.dim * math.log(self.df)
+                - 0.5 * self.dim * math.log(math.pi)
+                - math.lgamma(0.5 * self.df)
+                + math.lgamma(0.5 * (self.df + self.dim))
+                - ops.sum(keras.ops.log(self._scale))
+            )
+            result += log_normalization_constant
 
         return result
diff --git a/bayesflow/distributions/mixture.py b/bayesflow/distributions/mixture.py
index a7bf2ea27..e1f04e88f 100644
--- a/bayesflow/distributions/mixture.py
+++ b/bayesflow/distributions/mixture.py
@@ -144,7 +144,7 @@ def build(self, input_shape: Shape) -> None:
 
         self._mixture_logits = self.add_weight(
             shape=(len(self.distributions),),
-            initializer=keras.initializers.get(self.mixture_logits),
+            initializer=keras.initializers.get(keras.ops.copy(self.mixture_logits)),
            dtype="float32",
             trainable=self.trainable_mixture,
         )