Skip to content

Commit 63388e9

Browse files
Add ListMLE Loss
1 parent 0ef7408 commit 63388e9

File tree

3 files changed

+285
-0
lines changed

3 files changed

+285
-0
lines changed

keras_rs/api/losses/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
since your modifications would be overwritten.
55
"""
66

7+
from keras_rs.src.losses.list_mle_loss import (
8+
ListMLELoss as ListMLELoss,
9+
)
10+
711
from keras_rs.src.losses.pairwise_hinge_loss import (
812
PairwiseHingeLoss as PairwiseHingeLoss,
913
)

keras_rs/src/losses/list_mle_loss.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
from typing import Any
2+
3+
import keras
4+
from keras import ops
5+
6+
from keras_rs.src import types
7+
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
8+
from keras_rs.src.api_export import keras_rs_export
9+
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
10+
11+
12+
@keras_rs_export("keras_rs.losses.ListMLELoss")
13+
class ListMLELoss(keras.losses.Loss):
14+
"""Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.
15+
16+
ListMLE loss is a listwise ranking loss that maximizes the likelihood of
17+
the ground truth ranking. It works by:
18+
1. Sorting items by their relevance scores (labels)
19+
2. Computing the probability of observing this ranking given the
20+
predicted scores
21+
3. Maximizing this likelihood (minimizing negative log-likelihood)
22+
23+
The loss is computed as the negative log-likelihood of the ground truth
24+
ranking given the predicted scores:
25+
26+
```
27+
loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
28+
```
29+
30+
where s_i is the predicted score for item i in the sorted order.
31+
32+
Args:
33+
temperature: Temperature parameter for scaling logits. Higher values
34+
make the probability distribution more uniform. Defaults to 1.0.
35+
reduction: Type of reduction to apply to the loss. In almost all cases
36+
this should be `"sum_over_batch_size"`. Supported options are
37+
`"sum"`, `"sum_over_batch_size"`, `"mean"`,
38+
`"mean_with_sample_weight"` or `None`. Defaults to
39+
`"sum_over_batch_size"`.
40+
name: Optional name for the loss instance.
41+
dtype: The dtype of the loss's computations. Defaults to `None`.
42+
43+
Examples:
44+
```python
45+
# Basic usage
46+
loss_fn = ListMLELoss()
47+
48+
# With temperature scaling
49+
loss_fn = ListMLELoss(temperature=0.5)
50+
51+
# Example with synthetic data
52+
y_true = [[3, 2, 1, 0]] # Relevance scores
53+
y_pred = [[0.8, 0.6, 0.4, 0.2]] # Predicted scores
54+
loss = loss_fn(y_true, y_pred)
55+
```
56+
"""
57+
58+
def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
59+
super().__init__(**kwargs)
60+
61+
if temperature <= 0.0:
62+
raise ValueError(
63+
f"`temperature` should be a positive float. Received: "
64+
f"`temperature` = {temperature}."
65+
)
66+
67+
self.temperature = temperature
68+
self._epsilon = 1e-10
69+
70+
def compute_unreduced_loss(
71+
self,
72+
labels: types.Tensor,
73+
logits: types.Tensor,
74+
mask: types.Tensor | None = None,
75+
) -> tuple[types.Tensor, types.Tensor]:
76+
"""Compute the unreduced ListMLE loss.
77+
78+
Args:
79+
labels: Ground truth relevance scores of
80+
shape [batch_size,list_size].
81+
logits: Predicted scores of shape [batch_size, list_size].
82+
mask: Optional mask of shape [batch_size, list_size].
83+
84+
Returns:
85+
Tuple of (losses, weights) where losses has shape [batch_size, 1]
86+
and weights has the same shape.
87+
"""
88+
89+
valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))
90+
91+
if mask is not None:
92+
valid_mask = ops.logical_and(valid_mask, ops.cast(mask, dtype="bool"))
93+
94+
num_valid_items = ops.sum(ops.cast(valid_mask, dtype=labels.dtype),
95+
axis=1, keepdims=True)
96+
97+
batch_has_valid_items = ops.greater(num_valid_items, 0.0)
98+
99+
100+
labels_for_sorting = ops.where(valid_mask, labels, ops.full_like(labels, -1e9))
101+
logits_masked = ops.where(valid_mask, logits, ops.full_like(logits, -1e9))
102+
103+
sorted_logits, sorted_valid_mask = sort_by_scores(
104+
tensors_to_sort=[logits_masked, valid_mask],
105+
scores=labels_for_sorting,
106+
mask=None,
107+
shuffle_ties=False,
108+
seed=None
109+
)
110+
111+
sorted_logits = ops.divide(
112+
sorted_logits,
113+
ops.cast(self.temperature, dtype=sorted_logits.dtype)
114+
)
115+
116+
valid_logits_for_max = ops.where(sorted_valid_mask, sorted_logits,
117+
ops.full_like(sorted_logits, -1e9))
118+
raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
119+
raw_max = ops.where(batch_has_valid_items, raw_max, ops.zeros_like(raw_max))
120+
sorted_logits = sorted_logits - raw_max
121+
122+
exp_logits = ops.exp(sorted_logits)
123+
exp_logits = ops.where(sorted_valid_mask, exp_logits, ops.zeros_like(exp_logits))
124+
125+
reversed_exp = ops.flip(exp_logits, axis=1)
126+
reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
127+
cumsum_from_right = ops.flip(reversed_cumsum, axis=1)
128+
129+
log_normalizers = ops.log(cumsum_from_right + self._epsilon)
130+
log_probs = sorted_logits - log_normalizers
131+
132+
log_probs = ops.where(sorted_valid_mask, log_probs, ops.zeros_like(log_probs))
133+
134+
negative_log_likelihood = -ops.sum(log_probs, axis=1, keepdims=True)
135+
136+
negative_log_likelihood = ops.where(batch_has_valid_items, negative_log_likelihood,
137+
ops.zeros_like(negative_log_likelihood))
138+
139+
weights = ops.ones_like(negative_log_likelihood)
140+
141+
return negative_log_likelihood, weights
142+
143+
def call(
144+
self,
145+
y_true: types.Tensor,
146+
y_pred: types.Tensor,
147+
) -> types.Tensor:
148+
"""Compute the ListMLE loss.
149+
150+
Args:
151+
y_true: tensor or dict. Ground truth values. If tensor, of shape
152+
`(list_size)` for unbatched inputs or `(batch_size, list_size)`
153+
for batched inputs. If an item has a label of -1, it is ignored
154+
in loss computation. If it is a dictionary, it should have two
155+
keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
156+
elements in loss computation.
157+
y_pred: tensor. The predicted values, of shape `(list_size)` for
158+
unbatched inputs or `(batch_size, list_size)` for batched
159+
inputs. Should be of the same shape as `y_true`.
160+
161+
Returns:
162+
The loss tensor of shape [batch_size].
163+
"""
164+
mask = None
165+
if isinstance(y_true, dict):
166+
if "labels" not in y_true:
167+
raise ValueError(
168+
'`"labels"` should be present in `y_true`. Received: '
169+
f"`y_true` = {y_true}"
170+
)
171+
172+
mask = y_true.get("mask", None)
173+
y_true = y_true["labels"]
174+
175+
y_true = ops.convert_to_tensor(y_true)
176+
y_pred = ops.convert_to_tensor(y_pred)
177+
if mask is not None:
178+
mask = ops.convert_to_tensor(mask)
179+
180+
y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
181+
y_true, y_pred, mask
182+
)
183+
184+
losses, weights = self.compute_unreduced_loss(
185+
labels=y_true, logits=y_pred, mask=mask
186+
)
187+
losses = ops.multiply(losses, weights)
188+
losses = ops.squeeze(losses, axis=-1)
189+
return losses
190+
191+
def get_config(self) -> dict[str, Any]:
192+
config: dict[str, Any] = super().get_config()
193+
config.update({"temperature": self.temperature})
194+
return config
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import keras
2+
from absl.testing import parameterized
3+
from keras import ops
4+
from keras.losses import deserialize
5+
from keras.losses import serialize
6+
7+
from keras_rs.src import testing
8+
from keras_rs.src.losses.list_mle_loss import ListMLELoss
9+
10+
class ListMLELossTest(testing.TestCase, parameterized.TestCase):
11+
def setUp(self):
12+
self.unbatched_scores = ops.array([1.0, 3.0, 2.0, 4.0, 0.8])
13+
self.unbatched_labels = ops.array([1.0, 0.0, 1.0, 3.0, 2.0])
14+
15+
self.batched_scores = ops.array(
16+
[[1.0, 3.0, 2.0, 4.0, 0.8], [1.0, 1.8, 2.0, 3.0, 2.0]]
17+
)
18+
self.batched_labels = ops.array(
19+
[[1.0, 0.0, 1.0, 3.0, 2.0], [0.0, 1.0, 2.0, 3.0, 1.5]]
20+
)
21+
self.expected_output = ops.array([6.865693, 3.088192])
22+
23+
def test_unbatched_input(self):
24+
loss = ListMLELoss(reduction="none")
25+
output = loss(
26+
y_true=self.unbatched_labels, y_pred=self.unbatched_scores
27+
)
28+
self.assertEqual(output.shape, (1,))
29+
self.assertTrue(ops.convert_to_numpy(output[0]) > 0)
30+
self.assertAllClose(output, [self.expected_output[0]], atol=1e-5)
31+
32+
def test_batched_input(self):
33+
loss = ListMLELoss(reduction="none")
34+
output = loss(y_true=self.batched_labels, y_pred=self.batched_scores)
35+
self.assertEqual(output.shape, (2,))
36+
self.assertTrue(ops.convert_to_numpy(output[0]) > 0)
37+
self.assertTrue(ops.convert_to_numpy(output[1]) > 0)
38+
self.assertAllClose(output, self.expected_output, atol=1e-5)
39+
40+
def test_temperature(self):
41+
42+
loss_temp = ListMLELoss(temperature=0.5, reduction="none")
43+
output_temp = loss_temp(y_true=self.batched_labels, y_pred=self.batched_scores)
44+
45+
self.assertAllClose(output_temp,[10.969891,2.1283305],atol=1e-5,
46+
)
47+
48+
def test_invalid_input_rank(self):
49+
rank_1_input = ops.ones((2, 3, 4))
50+
51+
loss = ListMLELoss()
52+
with self.assertRaises(ValueError):
53+
loss(y_true=rank_1_input, y_pred=rank_1_input)
54+
55+
def test_loss_reduction(self):
56+
loss = ListMLELoss(reduction="sum_over_batch_size")
57+
output = loss(y_true=self.batched_labels, y_pred=self.batched_scores)
58+
59+
self.assertAlmostEqual(ops.convert_to_numpy(output), 4.9769425, places=5)
60+
61+
def test_scalar_sample_weight(self):
62+
sample_weight = ops.array(5.0)
63+
loss = ListMLELoss(reduction="none")
64+
65+
output = loss(
66+
y_true=self.batched_labels,
67+
y_pred=self.batched_scores,
68+
sample_weight=sample_weight,
69+
)
70+
71+
self.assertAllClose(output, self.expected_output * sample_weight, atol=1e-5)
72+
73+
def test_model_fit(self):
74+
inputs = keras.Input(shape=(20,), dtype="float32")
75+
outputs = keras.layers.Dense(5)(inputs)
76+
model = keras.Model(inputs=inputs, outputs=outputs)
77+
78+
model.compile(loss=ListMLELoss(), optimizer="adam")
79+
model.fit(
80+
x=keras.random.normal((2, 20)),
81+
y=keras.random.randint((2, 5), minval=0, maxval=2),
82+
)
83+
84+
def test_serialization(self):
85+
loss = ListMLELoss(temperature=0.8)
86+
restored = deserialize(serialize(loss))
87+
self.assertDictEqual(loss.get_config(), restored.get_config())

0 commit comments

Comments
 (0)