from typing import Any

import keras
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.metrics.ranking_metrics_utils import sort_by_scores
from keras_rs.src.metrics.utils import standardize_call_inputs_ranks
@keras_rs_export("keras_rs.losses.ListMLELoss")
class ListMLELoss(keras.losses.Loss):
    """Implements ListMLE (Maximum Likelihood Estimation) loss for ranking.

    ListMLE loss is a listwise ranking loss that maximizes the likelihood of
    the ground truth ranking. It works by:
    1. Sorting items by their relevance scores (labels)
    2. Computing the probability of observing this ranking given the
       predicted scores
    3. Maximizing this likelihood (minimizing negative log-likelihood)

    The loss is computed as the negative log-likelihood of the ground truth
    ranking given the predicted scores:

    ```
    loss = -sum(log(exp(s_i) / sum(exp(s_j) for j >= i)))
    ```

    where s_i is the predicted score for item i in the sorted order.

    Args:
        temperature: Temperature parameter for scaling logits. Higher values
            make the probability distribution more uniform. Defaults to 1.0.
        reduction: Type of reduction to apply to the loss. In almost all cases
            this should be `"sum_over_batch_size"`. Supported options are
            `"sum"`, `"sum_over_batch_size"`, `"mean"`,
            `"mean_with_sample_weight"` or `None`. Defaults to
            `"sum_over_batch_size"`.
        name: Optional name for the loss instance.
        dtype: The dtype of the loss's computations. Defaults to `None`.

    Examples:
        ```python
        # Basic usage
        loss_fn = ListMLELoss()

        # With temperature scaling
        loss_fn = ListMLELoss(temperature=0.5)

        # Example with synthetic data
        y_true = [[3, 2, 1, 0]]  # Relevance scores
        y_pred = [[0.8, 0.6, 0.4, 0.2]]  # Predicted scores
        loss = loss_fn(y_true, y_pred)
        ```
    """

    def __init__(self, temperature: float = 1.0, **kwargs: Any) -> None:
        super().__init__(**kwargs)

        if temperature <= 0.0:
            raise ValueError(
                "`temperature` should be a positive float. Received: "
                f"`temperature` = {temperature}."
            )

        self.temperature = temperature
        # Small constant added before log() to avoid log(0).
        self._epsilon = 1e-10

    def compute_unreduced_loss(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        mask: types.Tensor | None = None,
    ) -> tuple[types.Tensor, types.Tensor]:
        """Compute the unreduced ListMLE loss.

        Args:
            labels: Ground truth relevance scores of shape
                `[batch_size, list_size]`. Items with a label below 0 are
                treated as invalid and excluded from the loss.
            logits: Predicted scores of shape `[batch_size, list_size]`.
            mask: Optional mask of shape `[batch_size, list_size]`; entries
                that are falsy are excluded from the loss.

        Returns:
            Tuple of (losses, weights) where losses has shape [batch_size, 1]
            and weights has the same shape.
        """
        # Negative labels mark invalid/padded items by convention.
        valid_mask = ops.greater_equal(labels, ops.cast(0.0, labels.dtype))

        if mask is not None:
            valid_mask = ops.logical_and(
                valid_mask, ops.cast(mask, dtype="bool")
            )

        num_valid_items = ops.sum(
            ops.cast(valid_mask, dtype=labels.dtype), axis=1, keepdims=True
        )
        # Lists with no valid items must contribute exactly zero loss.
        batch_has_valid_items = ops.greater(num_valid_items, 0.0)

        # Push invalid entries to the end of the descending sort by giving
        # them a very large negative score.
        labels_for_sorting = ops.where(
            valid_mask, labels, ops.full_like(labels, -1e9)
        )
        logits_masked = ops.where(
            valid_mask, logits, ops.full_like(logits, -1e9)
        )

        # Reorder the logits (and the validity mask) by ground-truth
        # relevance, most relevant first.
        sorted_logits, sorted_valid_mask = sort_by_scores(
            tensors_to_sort=[logits_masked, valid_mask],
            scores=labels_for_sorting,
            mask=None,
            shuffle_ties=False,
            seed=None,
        )

        # Temperature scaling: larger temperature flattens the distribution.
        sorted_logits = ops.divide(
            sorted_logits,
            ops.cast(self.temperature, dtype=sorted_logits.dtype),
        )

        # Subtract the per-list max before exp() for numerical stability;
        # invalid positions are excluded from the max.
        valid_logits_for_max = ops.where(
            sorted_valid_mask,
            sorted_logits,
            ops.full_like(sorted_logits, -1e9),
        )
        raw_max = ops.max(valid_logits_for_max, axis=1, keepdims=True)
        raw_max = ops.where(
            batch_has_valid_items, raw_max, ops.zeros_like(raw_max)
        )
        sorted_logits = sorted_logits - raw_max

        exp_logits = ops.exp(sorted_logits)
        exp_logits = ops.where(
            sorted_valid_mask, exp_logits, ops.zeros_like(exp_logits)
        )

        # Suffix sum of exp(logits): for position i this is
        # sum(exp(s_j) for j >= i), the ListMLE normalizer.
        reversed_exp = ops.flip(exp_logits, axis=1)
        reversed_cumsum = ops.cumsum(reversed_exp, axis=1)
        cumsum_from_right = ops.flip(reversed_cumsum, axis=1)

        log_normalizers = ops.log(cumsum_from_right + self._epsilon)
        log_probs = sorted_logits - log_normalizers

        # Invalid positions contribute zero to the log-likelihood.
        log_probs = ops.where(
            sorted_valid_mask, log_probs, ops.zeros_like(log_probs)
        )

        negative_log_likelihood = -ops.sum(log_probs, axis=1, keepdims=True)
        negative_log_likelihood = ops.where(
            batch_has_valid_items,
            negative_log_likelihood,
            ops.zeros_like(negative_log_likelihood),
        )

        weights = ops.ones_like(negative_log_likelihood)

        return negative_log_likelihood, weights

    def call(
        self,
        y_true: types.Tensor,
        y_pred: types.Tensor,
    ) -> types.Tensor:
        """Compute the ListMLE loss.

        Args:
            y_true: tensor or dict. Ground truth values. If tensor, of shape
                `(list_size)` for unbatched inputs or `(batch_size, list_size)`
                for batched inputs. If an item has a label of -1, it is ignored
                in loss computation. If it is a dictionary, it should have two
                keys: `"labels"` and `"mask"`. `"mask"` can be used to ignore
                elements in loss computation.
            y_pred: tensor. The predicted values, of shape `(list_size)` for
                unbatched inputs or `(batch_size, list_size)` for batched
                inputs. Should be of the same shape as `y_true`.

        Returns:
            The loss tensor of shape [batch_size].
        """
        mask = None
        if isinstance(y_true, dict):
            if "labels" not in y_true:
                raise ValueError(
                    '`"labels"` should be present in `y_true`. Received: '
                    f"`y_true` = {y_true}"
                )

            mask = y_true.get("mask", None)
            y_true = y_true["labels"]

        y_true = ops.convert_to_tensor(y_true)
        y_pred = ops.convert_to_tensor(y_pred)
        if mask is not None:
            mask = ops.convert_to_tensor(mask)

        # Normalize unbatched/batched inputs to rank-2 tensors.
        y_true, y_pred, mask, _ = standardize_call_inputs_ranks(
            y_true, y_pred, mask
        )

        losses, weights = self.compute_unreduced_loss(
            labels=y_true, logits=y_pred, mask=mask
        )
        losses = ops.multiply(losses, weights)
        # Collapse the trailing singleton dim -> shape [batch_size].
        losses = ops.squeeze(losses, axis=-1)
        return losses

    def get_config(self) -> dict[str, Any]:
        """Return the serializable config, including `temperature`."""
        config: dict[str, Any] = super().get_config()
        config.update({"temperature": self.temperature})
        return config
0 commit comments