Skip to content

Commit 6651b2a

Browse files
authored
Fix loss computation for CenterPillar when batch_size > 1 (keras-team#2056)
* Fix larger batch sizes for CenterPillar * Another fix
1 parent 1902b90 commit 6651b2a

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

keras_cv/models/object_detection_3d/center_pillar.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -137,35 +137,32 @@ def compute_loss(self, x, y, y_pred, sample_weight=None, **kwargs):
137137

138138
# TODO(ianstenbit): loss heatmap threshold should be configurable.
139139
box_regression_mask = (
140-
ops.squeeze(
141-
ops.take(
142-
ops.reshape(heatmap, (ops.shape(heatmap)[0], -1)),
143-
index[..., 0] * ops.shape(heatmap)[1] + index[..., 1],
144-
axis=1,
145-
),
146-
axis=0,
140+
ops.take_along_axis(
141+
ops.reshape(heatmap, (heatmap.shape[0], -1)),
142+
index[..., 0] * heatmap.shape[1] + index[..., 1],
143+
axis=1,
147144
)
148145
> 0.95
149146
)
150147

151-
box = ops.squeeze(
152-
ops.take(
153-
ops.reshape(box, (ops.shape(box)[0], -1, 7)),
154-
index[..., 0] * ops.shape(box)[1] + index[..., 1],
155-
axis=1,
148+
box = ops.take_along_axis(
149+
ops.reshape(box, (ops.shape(box)[0], -1, 7)),
150+
ops.expand_dims(
151+
index[..., 0] * ops.shape(box)[1] + index[..., 1], axis=-1
156152
),
157-
axis=0,
153+
axis=1,
158154
)
159-
box_pred = ops.squeeze(
160-
ops.take(
161-
ops.reshape(
162-
box_pred,
163-
(ops.shape(box_pred)[0], -1, ops.shape(box_pred)[-1]),
164-
),
155+
156+
box_pred = ops.take_along_axis(
157+
ops.reshape(
158+
box_pred,
159+
(ops.shape(box_pred)[0], -1, ops.shape(box_pred)[-1]),
160+
),
161+
ops.expand_dims(
165162
index[..., 0] * ops.shape(box_pred)[1] + index[..., 1],
166-
axis=1,
163+
axis=-1,
167164
),
168-
axis=0,
165+
axis=1,
169166
)
170167

171168
box_center_mask = heatmap > 0.99

0 commit comments

Comments
 (0)