Skip to content
This repository was archived by the owner on Jul 9, 2025. It is now read-only.
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions tensorflow_similarity/models/similarity_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,46 @@ class SimilarityModel(tf.keras.Model):
"""

def __init__(self, *args, **kwargs):
self.batch_compute_gradient_portion = float(kwargs.pop('batch_compute_gradient_portion', 1))
self.batch_random_permutation = bool(kwargs.pop('batch_random_permutation', False))

assert 0. < self.batch_compute_gradient_portion <= 1.
assert self.batch_random_permutation in [True, False]

super().__init__(*args, **kwargs)

def train_step(self, data):
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

if self.batch_random_permutation:
indices = tf.range(start=0, limit=tf.shape(x)[0], dtype=tf.int32)
shuffled_indices = tf.random.shuffle(indices)

x = tf.gather(x, shuffled_indices)
y = tf.gather(y, shuffled_indices)
if sample_weight is not None:
sample_weight = tf.gather(sample_weight, shuffled_indices)

l = tf.cast(tf.shape(x)[0], tf.float32)
k = tf.cast(self.batch_compute_gradient_portion * l, tf.int32)

# Run forward pass.
y_pred_without_gradient = self(x[k:], training=True)

with tf.GradientTape() as tape:
y_pred_with_gradient = self(x[:k], training=True)

y_pred = tf.concat([y_pred_with_gradient, y_pred_without_gradient], axis=0)

loss = self.compute_loss(x, y, y_pred, sample_weight)

self._validate_target_and_loss(y, loss)

# Run backwards pass.
self.optimizer.minimize(loss, self.trainable_variables, tape=tape)

return self.compute_metrics(x, y, y_pred, sample_weight)

def compile(
self,
optimizer: Optimizer | str | Mapping | Sequence = "rmsprop",
Expand Down